Get forwarded IP address

merge-requests/30/head
Bob Mottram 2021-06-09 15:01:26 +01:00
parent fde879b998
commit f4b0491c34
2 changed files with 42 additions and 23 deletions

View File

@ -206,6 +206,7 @@ from shares import addShare
from shares import removeShare from shares import removeShare
from shares import expireShares from shares import expireShares
from categories import setHashtagCategory from categories import setHashtagCategory
from utils import isLocalNetworkAddress
from utils import permittedDir from utils import permittedDir
from utils import isAccountDir from utils import isAccountDir
from utils import getOccupationSkills from utils import getOccupationSkills
@ -1437,6 +1438,9 @@ class PubServer(BaseHTTPRequestHandler):
return return
authHeader = \ authHeader = \
createBasicAuthHeader(loginNickname, loginPassword) createBasicAuthHeader(loginNickname, loginPassword)
if self.headers.get('X-Forwarded-For'):
ipAddress = self.headers['X-Forwarded-For']
else:
ipAddress = self.client_address[0] ipAddress = self.client_address[0]
print('Login attempt from IP: ' + str(ipAddress)) print('Login attempt from IP: ' + str(ipAddress))
if not authorizeBasic(baseDir, '/users/' + if not authorizeBasic(baseDir, '/users/' +
@ -1446,28 +1450,33 @@ class PubServer(BaseHTTPRequestHandler):
self._clearLoginDetails(loginNickname, callingDomain) self._clearLoginDetails(loginNickname, callingDomain)
failTime = int(time.time()) failTime = int(time.time())
self.server.lastLoginFailure = failTime self.server.lastLoginFailure = failTime
if not self.server.loginFailureCount.get(ipAddress): if not isLocalNetworkAddress(ipAddress):
while len(self.server.loginFailureCount.items()) > 100: countDict = self.server.loginFailureCount
if not countDict.get(ipAddress):
while len(countDict.items()) > 100:
oldestTime = 0 oldestTime = 0
oldestIP = None oldestIP = None
for ipAddr, ipItem in self.server.loginFailureCount: for ipAddr, ipItem in countDict.items():
if oldestTime == 0 or ipItem['time'] < oldestTime: if oldestTime == 0 or \
ipItem['time'] < oldestTime:
oldestTime = ipItem['time'] oldestTime = ipItem['time']
oldestIP = ipAddr oldestIP = ipAddr
if oldestTime > 0: if oldestTime > 0:
del self.server.loginFailureCount[oldestIP] del countDict[oldestIP]
self.server.loginFailureCount[ipAddress] = { countDict[ipAddress] = {
"count": 1, "count": 1,
"time": failTime "time": failTime
} }
else: else:
self.server.loginFailureCount[ipAddress]['count'] += 1 countDict[ipAddress]['count'] += 1
failCount = \ failCount = \
self.server.loginFailureCount[ipAddress]['count'] countDict[ipAddress]['count']
if failCount > 4: if failCount > 4:
print('WARN: ' + str(ipAddress) + print('WARN: ' + str(ipAddress) +
' failed to log in ' + str(failCount) + ' times') ' failed to log in ' +
self.server.loginFailureCount[ipAddress]['time'] = failTime str(failCount) + ' times')
countDict[ipAddress]['time'] = \
failTime
self.server.POSTbusy = False self.server.POSTbusy = False
return return
else: else:

View File

@ -669,6 +669,16 @@ def getLocalNetworkAddresses() -> []:
return ('localhost', '127.0.', '192.168', '10.0.') return ('localhost', '127.0.', '192.168', '10.0.')
def isLocalNetworkAddress(ipAddress: str) -> bool:
"""
"""
localIPs = getLocalNetworkAddresses()
for ipAddr in localIPs:
if ipAddress.startswith(ipAddr):
return True
return False
def dangerousMarkup(content: str, allowLocalNetworkAccess: bool) -> bool: def dangerousMarkup(content: str, allowLocalNetworkAccess: bool) -> bool:
"""Returns true if the given content contains dangerous html markup """Returns true if the given content contains dangerous html markup
""" """