Commit 1ada04a4 authored by John Davis's avatar John Davis
Browse files

Factor out username_from_email; fix bug; add tests

The bug was introduced when moving this code to SQLAlchemy 2.0
The statement in generate_next_available_username should be evaluated at
each step of the iteration: only then will the new value of `i` result
in a new statement.
parent cf1fc553
Loading
Loading
Loading
Loading
+37 −13
Original line number Diff line number Diff line
@@ -664,23 +664,11 @@ class UserManager(base.ModelManager, deletable.PurgableManagerMixin):
                    self.app.security_agent.user_set_default_permissions(user)
                    self.app.security_agent.user_set_default_permissions(user, history=True, dataset=True)
        elif user is None:
            username = remote_user_email.split("@", 1)[0].lower()
            random.seed()
            user = self.app.model.User(email=remote_user_email)
            user.set_random_password(length=12)
            user.external = True
            # Replace invalid characters in the username
            for char in [x for x in username if x not in f"{string.ascii_lowercase + string.digits}-."]:
                username = username.replace(char, "-")
            # Find a unique username - user can change it later
            stmt = select(self.app.model.User).filter_by(username=username).limit(1)
            if self.session().scalars(stmt).first():
                i = 1
                stmt = select(self.app.model.User).filter_by(username=f"{username}-{str(i)}").limit(1)
                while self.session().scalars(stmt).first():
                    i += 1
                username += f"-{str(i)}"
            user.username = username
            user.username = username_from_email(self.session(), remote_user_email, self.app.model.User)
            self.session().add(user)
            with transaction(self.session()):
                self.session().commit()
@@ -896,3 +884,39 @@ def get_user_by_email(session, email: str, model_class=User, case_sensitive=True
def get_user_by_username(session, username: str, model_class=User):
    stmt = select(model_class).filter(model_class.username == username).limit(1)
    return session.scalars(stmt).first()


def username_from_email(session, email, model_class=User):
    """Get next available username generated based on email"""
    engine = session.bind
    with engine.connect() as connection:
        return username_from_email_with_connection(connection, email, model_class)


def username_from_email_with_connection(connection, email, model_class=User):
    # This function is also called from database revision scripts, which do not provide a session.
    username = email.split("@", 1)[0].lower()
    username = filter_out_invalid_username_characters(username)
    if username_exists(connection, username, model_class):
        username = generate_next_available_username(connection, username, model_class)
    return username


def filter_out_invalid_username_characters(username):
    """Replace invalid characters in username"""
    for char in [x for x in username if x not in f"{string.ascii_lowercase + string.digits}-."]:
        username = username.replace(char, "-")
    return username


def username_exists(connection, username: str, model_class=User):
    stmt = select(model_class).filter(model_class.username == username).limit(1)
    return bool(connection.execute(stmt).first())


def generate_next_available_username(connection, username, model_class=User):
    """Generate unique username; user can change it later"""
    i = 1
    while connection.execute(select(model_class).where(model_class.username == f"{username}-{i}")).first():
        i += 1
    return f"{username}-{i}"
+34 −0
Original line number Diff line number Diff line
from galaxy.managers.users import (
    filter_out_invalid_username_characters,
    username_exists,
    username_from_email,
)
from galaxy.model import User


def test_username_from_email(session, make_user):
    make_user(username="foo")
    next_username = username_from_email(session, "foo@.foo.com", User)
    assert next_username == "foo-1"  # because foo exists

    make_user(username="foo-1")
    next_username = username_from_email(session, "foo@.foo.com", User)
    assert next_username == "foo-2"  # because foo and foo-1 exist

    make_user(username="foo-2")
    next_username = username_from_email(session, "foo@.foo.com", User)
    assert next_username == "foo-3"  # because foo and foo-1 and foo-2 exist

    next_username = username_from_email(session, "bar@.foo.com", User)
    assert next_username == "bar"  # no change


def test_filter_out_invalid_username_characters():
    username = "abcDCE123$%^-."
    assert filter_out_invalid_username_characters(username) == "abc---123----."


def test_username_exists(session, make_user):
    make_user(username="foo", email="foo@foo.com")
    assert username_exists(session, "foo")
    assert not username_exists(session, "bar")