Unit test for user agent domain

merge-requests/30/head
Bob Mottram 2021-06-20 16:45:29 +01:00
parent af556dc134
commit 0068b2b8cd
3 changed files with 42 additions and 27 deletions

View File

@ -207,6 +207,7 @@ from shares import addShare
from shares import removeShare from shares import removeShare
from shares import expireShares from shares import expireShares
from categories import setHashtagCategory from categories import setHashtagCategory
from utils import userAgentDomain
from utils import isLocalNetworkAddress from utils import isLocalNetworkAddress
from utils import permittedDir from utils import permittedDir
from utils import isAccountDir from utils import isAccountDir
@ -452,35 +453,13 @@ class PubServer(BaseHTTPRequestHandler):
else: else:
print('ERROR: unable to create vote') print('ERROR: unable to create vote')
def _userAgentDomain(self) -> str:
"""Returns the domain specified within User-Agent header
"""
if not self.headers.get('User-Agent'):
return None
agentStr = self.headers.get('User-Agent')
if '+http' not in agentStr:
return None
agentDomain = agentStr.split('+http')[1].strip()
if '://' in agentDomain:
agentDomain = agentDomain.split('://')[1]
if '/' in agentDomain:
agentDomain = agentDomain.split('/')[0]
if ')' in agentDomain:
agentDomain = agentDomain.split(')')[0].strip()
if ' ' in agentDomain:
agentDomain = agentDomain.replace(' ', '')
if ';' in agentDomain:
agentDomain = agentDomain.replace(';', '')
if '.' not in agentDomain:
return None
if self.server.debug:
print('User-Agent Domain: ' + agentDomain)
return agentDomain
def _blockedUserAgent(self, callingDomain: str) -> bool: def _blockedUserAgent(self, callingDomain: str) -> bool:
"""Should a GET or POST be blocked based upon its user agent? """Should a GET or POST be blocked based upon its user agent?
""" """
agentDomain = self._userAgentDomain() agentDomain = None
if self.headers.get('User-Agent'):
agentDomain = userAgentDomain(self.headers['User-Agent'],
self.server.debug)
blockedUA = False blockedUA = False
if not agentDomain: if not agentDomain:
if self.server.userAgentDomainRequired: if self.server.userAgentDomainRequired:

View File

@ -37,13 +37,14 @@ from follow import clearFollows
from follow import clearFollowers from follow import clearFollowers
from follow import sendFollowRequestViaServer from follow import sendFollowRequestViaServer
from follow import sendUnfollowRequestViaServer from follow import sendUnfollowRequestViaServer
from siteactive import siteIsActive
from utils import userAgentDomain
from utils import camelCaseSplit from utils import camelCaseSplit
from utils import decodedHost from utils import decodedHost
from utils import getFullDomain from utils import getFullDomain
from utils import validNickname from utils import validNickname
from utils import firstParagraphFromString from utils import firstParagraphFromString
from utils import removeIdEnding from utils import removeIdEnding
from siteactive import siteIsActive
from utils import updateRecentPostsCache from utils import updateRecentPostsCache
from utils import followPerson from utils import followPerson
from utils import getNicknameFromActor from utils import getNicknameFromActor
@ -3938,10 +3939,21 @@ def _testRoles() -> None:
assert not actorHasRole(actorJson, "artist") assert not actorHasRole(actorJson, "artist")
def _testUserAgentDomain() -> None:
print('testUserAgentDomain')
userAgent = \
'http.rb/4.4.1 (Mastodon/9.10.11; +https://mastodon.something/)'
assert userAgentDomain(userAgent, False) == 'mastodon.something'
userAgent = \
'Mozilla/70.0 (X11; Linux x86_64; rv:1.0) Gecko/20450101 Firefox/1.0'
assert userAgentDomain(userAgent, False) is None
def runAllTests(): def runAllTests():
print('Running tests...') print('Running tests...')
updateDefaultThemesList(os.getcwd()) updateDefaultThemesList(os.getcwd())
_testFunctions() _testFunctions()
_testUserAgentDomain()
_testRoles() _testRoles()
_testSkills() _testSkills()
_testSpoofGeolocation() _testSpoofGeolocation()

View File

@ -2433,3 +2433,27 @@ def permittedDir(path: str) -> bool:
path.startswith('/accounts'): path.startswith('/accounts'):
return False return False
return True return True
def userAgentDomain(userAgent: str, debug: bool) -> str:
"""If the User-Agent string contains a domain
then return it
"""
if '+http' not in userAgent:
return None
agentDomain = userAgent.split('+http')[1].strip()
if '://' in agentDomain:
agentDomain = agentDomain.split('://')[1]
if '/' in agentDomain:
agentDomain = agentDomain.split('/')[0]
if ')' in agentDomain:
agentDomain = agentDomain.split(')')[0].strip()
if ' ' in agentDomain:
agentDomain = agentDomain.replace(' ', '')
if ';' in agentDomain:
agentDomain = agentDomain.replace(';', '')
if '.' not in agentDomain:
return None
if debug:
print('User-Agent Domain: ' + agentDomain)
return agentDomain