Get forwarded IP address

merge-requests/30/head
Bob Mottram 2021-06-09 15:01:26 +01:00
parent fde879b998
commit f4b0491c34
2 changed files with 42 additions and 23 deletions

View File

@ -206,6 +206,7 @@ from shares import addShare
from shares import removeShare from shares import removeShare
from shares import expireShares from shares import expireShares
from categories import setHashtagCategory from categories import setHashtagCategory
from utils import isLocalNetworkAddress
from utils import permittedDir from utils import permittedDir
from utils import isAccountDir from utils import isAccountDir
from utils import getOccupationSkills from utils import getOccupationSkills
@ -1437,7 +1438,10 @@ class PubServer(BaseHTTPRequestHandler):
return return
authHeader = \ authHeader = \
createBasicAuthHeader(loginNickname, loginPassword) createBasicAuthHeader(loginNickname, loginPassword)
ipAddress = self.client_address[0] if self.headers.get('X-Forwarded-For'):
ipAddress = self.headers['X-Forwarded-For']
else:
ipAddress = self.client_address[0]
print('Login attempt from IP: ' + str(ipAddress)) print('Login attempt from IP: ' + str(ipAddress))
if not authorizeBasic(baseDir, '/users/' + if not authorizeBasic(baseDir, '/users/' +
loginNickname + '/outbox', loginNickname + '/outbox',
@ -1446,28 +1450,33 @@ class PubServer(BaseHTTPRequestHandler):
self._clearLoginDetails(loginNickname, callingDomain) self._clearLoginDetails(loginNickname, callingDomain)
failTime = int(time.time()) failTime = int(time.time())
self.server.lastLoginFailure = failTime self.server.lastLoginFailure = failTime
if not self.server.loginFailureCount.get(ipAddress): if not isLocalNetworkAddress(ipAddress):
while len(self.server.loginFailureCount.items()) > 100: countDict = self.server.loginFailureCount
oldestTime = 0 if not countDict.get(ipAddress):
oldestIP = None while len(countDict.items()) > 100:
for ipAddr, ipItem in self.server.loginFailureCount: oldestTime = 0
if oldestTime == 0 or ipItem['time'] < oldestTime: oldestIP = None
oldestTime = ipItem['time'] for ipAddr, ipItem in countDict.items():
oldestIP = ipAddr if oldestTime == 0 or \
if oldestTime > 0: ipItem['time'] < oldestTime:
del self.server.loginFailureCount[oldestIP] oldestTime = ipItem['time']
self.server.loginFailureCount[ipAddress] = { oldestIP = ipAddr
"count": 1, if oldestTime > 0:
"time": failTime del countDict[oldestIP]
} countDict[ipAddress] = {
else: "count": 1,
self.server.loginFailureCount[ipAddress]['count'] += 1 "time": failTime
failCount = \ }
self.server.loginFailureCount[ipAddress]['count'] else:
if failCount > 4: countDict[ipAddress]['count'] += 1
print('WARN: ' + str(ipAddress) + failCount = \
' failed to log in ' + str(failCount) + ' times') countDict[ipAddress]['count']
self.server.loginFailureCount[ipAddress]['time'] = failTime if failCount > 4:
print('WARN: ' + str(ipAddress) +
' failed to log in ' +
str(failCount) + ' times')
countDict[ipAddress]['time'] = \
failTime
self.server.POSTbusy = False self.server.POSTbusy = False
return return
else: else:

View File

@ -669,6 +669,16 @@ def getLocalNetworkAddresses() -> []:
return ('localhost', '127.0.', '192.168', '10.0.') return ('localhost', '127.0.', '192.168', '10.0.')
def isLocalNetworkAddress(ipAddress: str) -> bool:
"""
"""
localIPs = getLocalNetworkAddresses()
for ipAddr in localIPs:
if ipAddress.startswith(ipAddr):
return True
return False
def dangerousMarkup(content: str, allowLocalNetworkAccess: bool) -> bool: def dangerousMarkup(content: str, allowLocalNetworkAccess: bool) -> bool:
"""Returns true if the given content contains dangerous html markup """Returns true if the given content contains dangerous html markup
""" """