Get forwarded IP address

main
Bob Mottram 2021-06-09 15:01:26 +01:00
parent fde879b998
commit f4b0491c34
2 changed files with 42 additions and 23 deletions

View File

@ -206,6 +206,7 @@ from shares import addShare
from shares import removeShare
from shares import expireShares
from categories import setHashtagCategory
from utils import isLocalNetworkAddress
from utils import permittedDir
from utils import isAccountDir
from utils import getOccupationSkills
@ -1437,7 +1438,10 @@ class PubServer(BaseHTTPRequestHandler):
return
authHeader = \
createBasicAuthHeader(loginNickname, loginPassword)
ipAddress = self.client_address[0]
if self.headers.get('X-Forwarded-For'):
ipAddress = self.headers['X-Forwarded-For']
else:
ipAddress = self.client_address[0]
print('Login attempt from IP: ' + str(ipAddress))
if not authorizeBasic(baseDir, '/users/' +
loginNickname + '/outbox',
@ -1446,28 +1450,33 @@ class PubServer(BaseHTTPRequestHandler):
self._clearLoginDetails(loginNickname, callingDomain)
failTime = int(time.time())
self.server.lastLoginFailure = failTime
if not self.server.loginFailureCount.get(ipAddress):
while len(self.server.loginFailureCount.items()) > 100:
oldestTime = 0
oldestIP = None
for ipAddr, ipItem in self.server.loginFailureCount:
if oldestTime == 0 or ipItem['time'] < oldestTime:
oldestTime = ipItem['time']
oldestIP = ipAddr
if oldestTime > 0:
del self.server.loginFailureCount[oldestIP]
self.server.loginFailureCount[ipAddress] = {
"count": 1,
"time": failTime
}
else:
self.server.loginFailureCount[ipAddress]['count'] += 1
failCount = \
self.server.loginFailureCount[ipAddress]['count']
if failCount > 4:
print('WARN: ' + str(ipAddress) +
' failed to log in ' + str(failCount) + ' times')
self.server.loginFailureCount[ipAddress]['time'] = failTime
if not isLocalNetworkAddress(ipAddress):
countDict = self.server.loginFailureCount
if not countDict.get(ipAddress):
while len(countDict.items()) > 100:
oldestTime = 0
oldestIP = None
for ipAddr, ipItem in countDict.items():
if oldestTime == 0 or \
ipItem['time'] < oldestTime:
oldestTime = ipItem['time']
oldestIP = ipAddr
if oldestTime > 0:
del countDict[oldestIP]
countDict[ipAddress] = {
"count": 1,
"time": failTime
}
else:
countDict[ipAddress]['count'] += 1
failCount = \
countDict[ipAddress]['count']
if failCount > 4:
print('WARN: ' + str(ipAddress) +
' failed to log in ' +
str(failCount) + ' times')
countDict[ipAddress]['time'] = \
failTime
self.server.POSTbusy = False
return
else:

View File

@ -669,6 +669,16 @@ def getLocalNetworkAddresses() -> []:
return ('localhost', '127.0.', '192.168', '10.0.')
def isLocalNetworkAddress(ipAddress: str) -> bool:
"""
"""
localIPs = getLocalNetworkAddresses()
for ipAddr in localIPs:
if ipAddress.startswith(ipAddr):
return True
return False
def dangerousMarkup(content: str, allowLocalNetworkAccess: bool) -> bool:
"""Returns true if the given content contains dangerous html markup
"""