Commit ea4a31c7 authored by John Davis's avatar John Davis
Browse files

Normalize email based on new config

parent d0ad647c
Loading
Loading
Loading
Loading
+69 −20
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ user inputs - so these methods do not need to be escaped.
import logging
import re
from typing import (
    Dict,
    List,
    Optional,
)
@@ -88,7 +89,7 @@ def validate_email(trans, email, user=None, check_dup=True, allow_empty=False, v
        message = validate_email_domain_name(domain)

    if not message:
        if is_email_banned(email, trans.app.config.email_ban_file):
        if is_email_banned(email, trans.app.config.email_ban_file, trans.app.config.canonical_email_rules):
            message = "This email address has been banned."

    stmt = select(trans.app.model.User).filter(func.lower(trans.app.model.User.email) == email.lower()).limit(1)
@@ -175,33 +176,81 @@ def validate_preferred_object_store_id(
    return object_store.validate_selected_object_store_id(trans.user, preferred_object_store_id) or ""


# Report whether ``email`` matches an entry of the ban list stored at ``filepath``.
# Returns False without reading anything when ``filepath`` is empty or None.
# NOTE(review): this span comes from a diff view, so every changed line appears twice:
# the pre-change line (based on _make_canonical_email) directly followed by its
# replacement (based on EmailAddressNormalizer). Only one line of each pair belongs
# in the real file - confirm against the repository before editing.
def is_email_banned(email: str, filepath: Optional[str]) -> bool:
def is_email_banned(email: str, filepath: Optional[str], canonical_email_rules: Optional[Dict]) -> bool:
    if not filepath:
        return False
    email = _make_canonical_email(email)
    # The candidate address and every ban-list entry go through the same normalizer,
    # so the comparison below is between canonical forms.
    normalizer = EmailAddressNormalizer(canonical_email_rules)
    email = normalizer.normalize(email)
    banned_emails = _read_email_ban_list(filepath)
    for address in banned_emails:
        if email == _make_canonical_email(address):
        if email == normalizer.normalize(address):
            return True
    return False


def _make_canonical_email(email: str) -> str:
    """
    Transform to canonical representation:
    - lowercase
    - gmail: drop periods in local-part
    - gmail: drop plus suffixes in local-part

    Raises ValueError if ``email`` does not contain an "@".
    """
    # Split on the LAST "@": a quoted local-part may itself contain "@", whereas
    # a domain never does. A plain split("@") fails to unpack such addresses.
    localpart, domain = email.lower().rsplit("@", 1)
    if domain == "gmail.com":
        # partition() returns the whole string when there is no "+", so no
        # separate membership test is needed.
        localpart = localpart.replace(".", "").partition("+")[0]
    return f"{localpart}@{domain}"


def _read_email_ban_list(filepath: str) -> List[str]:
    """
    Read the banned email addresses listed in the file at ``filepath``.

    Each line is stripped of surrounding whitespace. Blank lines and comment
    lines (first non-blank character is "#") are skipped, so every returned
    entry is a non-empty candidate address.
    """
    with open(filepath) as f:
        # Strip first, then filter: this drops empty entries (which cannot be
        # split into local-part and domain by the caller) and recognizes
        # comments that are preceded by whitespace.
        stripped_lines = (line.strip() for line in f)
        return [address for address in stripped_lines if address and not address.startswith("#")]


class EmailAddressNormalizer:
    """
    Transform email addresses to a canonical form based on configurable rules.

    ``canonical_email_rules`` maps either the special key ``"all"`` or the
    primary domain of an email service to a dict of rules:

    - ``ignore_case``: lowercase the local-part
    - ``ignore_dots``: drop periods from the local-part
    - ``sub_addressing``: drop the local-part from the first delimiter onwards
    - ``sub_addressing_delim``: delimiter used for sub-addressing (default "+")
    - ``aliases``: (service entries) other domains of the same service; a
      matching address is rewritten to the service's primary domain

    Service domains and aliases are compared case-insensitively. Without
    rules, only the domain part of an address is lowercased.
    """

    IGNORE_CASE_RULE = "ignore_case"
    IGNORE_DOTS_RULE = "ignore_dots"
    SUB_ADDRESSING_RULE = "sub_addressing"
    SUB_ADDRESSING_DELIM = "sub_addressing_delim"
    SUB_ADDRESSING_DELIM_DEFAULT = "+"
    ALIASES = "aliases"
    ALL = "all"

    def __init__(self, canonical_email_rules: Optional[Dict]) -> None:
        self.config = canonical_email_rules

    def normalize(self, email: str) -> str:
        """
        Transform email to its canonical form.

        Raises ValueError if ``email`` does not contain an "@".
        """
        # Split on the last "@": a quoted local-part may contain "@", a domain cannot.
        email_localpart, email_domain = email.rsplit("@", 1)
        # the domain part of an email address is case-insensitive (RFC1035)
        email_domain = email_domain.lower()

        # Step 1: If no rules are set, do not modify local-part
        if not self.config:
            return f"{email_localpart}@{email_domain}"

        # Step 2: Apply rules defined for all services before applying rules defined for specific services
        if self.ALL in self.config:
            email_localpart = self._apply_rules(email_localpart, self.ALL)

        # Step 3: Apply rules defined for each email service if email matches service
        for service in self.config:
            if service == self.ALL:
                continue
            # Compare in lowercase, but keep the key exactly as configured for
            # lookups: indexing the config with the lowercased name raises
            # KeyError whenever the configured key contains uppercase letters.
            service_domain = service.lower()
            matches_service = email_domain == service_domain

            if not matches_service:
                service_rules = self.config[service] or {}
                aliases = service_rules.get(self.ALIASES) or ()
                if email_domain in {alias.lower() for alias in aliases}:
                    # email domain is an alias of the service. Change it to the service's primary domain name.
                    email_domain = service_domain
                    matches_service = True

            if matches_service:
                email_localpart = self._apply_rules(email_localpart, service)

        return f"{email_localpart}@{email_domain}"

    def _apply_rules(self, email_localpart: str, service: str) -> str:
        """
        Apply the rules configured under the key ``service`` to the local-part.

        ``service`` must be a key of the configuration exactly as it is
        spelled there (either ``"all"`` or a service domain).
        """
        # Narrows Optional[Dict] for type checkers; normalize() only calls this with rules set.
        assert self.config
        # An entry without any rules (empty/None value) leaves the local-part untouched.
        config = self.config[service] or {}

        if config.get(self.IGNORE_CASE_RULE, False):
            email_localpart = email_localpart.lower()
        if config.get(self.IGNORE_DOTS_RULE, False):
            email_localpart = email_localpart.replace(".", "")
        if config.get(self.SUB_ADDRESSING_RULE, False):
            # Fall back to the default for a missing, None or empty delimiter: an
            # empty delimiter would otherwise match at position 0 and wipe out
            # the whole local-part.
            delim = config.get(self.SUB_ADDRESSING_DELIM) or self.SUB_ADDRESSING_DELIM_DEFAULT
            # partition() returns the full string when the delimiter is absent.
            email_localpart = email_localpart.partition(delim)[0]

        return email_localpart