diff --git a/daemon.py b/daemon.py index edc88ef85..5f150d154 100644 --- a/daemon.py +++ b/daemon.py @@ -245,6 +245,7 @@ from languages import set_actor_languages from languages import get_understood_languages from like import update_likes_collection from reaction import update_reaction_collection +from utils import get_domain_from_url_in_string from utils import local_network_host from utils import undo_reaction_collection_entry from utils import get_new_post_endpoints @@ -13461,13 +13462,11 @@ class PubServer(BaseHTTPRequestHandler): """ referer_domain = None if self.headers.get('referer'): - referer_domain, referer_port = \ - get_domain_from_actor(self.headers['referer']) - referer_domain = get_full_domain(referer_domain, referer_port) + referer_domain = \ + get_domain_from_url_in_string(self.headers['referer']) elif self.headers.get('Referer'): - referer_domain, referer_port = \ - get_domain_from_actor(self.headers['Referer']) - referer_domain = get_full_domain(referer_domain, referer_port) + referer_domain = \ + get_domain_from_url_in_string(self.headers['Referer']) elif self.headers.get('Signature'): if 'keyId="' in self.headers['Signature']: referer_domain = self.headers['Signature'].split('keyId="')[1] @@ -13478,18 +13477,7 @@ class PubServer(BaseHTTPRequestHandler): elif '"' in referer_domain: referer_domain = referer_domain.split('"')[0] elif ua_str: - if 'https://' in ua_str: - referer_domain = ua_str.split('https://')[1] - if '/' in referer_domain: - referer_domain = referer_domain.split('/')[0] - elif ')' in referer_domain: - referer_domain = referer_domain.split(')')[0] - elif 'http://' in ua_str: - referer_domain = ua_str.split('http://')[1] - if '/' in referer_domain: - referer_domain = referer_domain.split('/')[0] - elif ')' in referer_domain: - referer_domain = referer_domain.split(')')[0] + referer_domain = get_domain_from_url_in_string(ua_str) return referer_domain def _get_user_agent(self) -> str: diff --git a/utils.py b/utils.py index 5f89e315c..58638e408 100644 --- a/utils.py +++ b/utils.py @@ -3323,3 +3323,22 @@ def valid_hash_tag(hashtag: str) -> bool: if _is_valid_language(hashtag): return True return False + + +def get_domain_from_url_in_string(text: str) -> str: + """Returns the domain from within a string if it exists + """ + domain_str = '' + if 'https://' in text: + domain_str = text.split('https://')[1] + if '/' in domain_str: + domain_str = domain_str.split('/')[0] + elif ')' in domain_str: + domain_str = domain_str.split(')')[0] + elif 'http://' in text: + domain_str = text.split('http://')[1] + if '/' in domain_str: + domain_str = domain_str.split('/')[0] + elif ')' in domain_str: + domain_str = domain_str.split(')')[0] + return domain_str