Commit 98b052fc authored by Huihui, Jonathan's avatar Huihui, Jonathan
Browse files

Merge branch 'multi_database_with_auth' into 'develop'

Multi database with auth

See merge request !36
parents ad5fcd61 325e87d6
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ verify_ssl = true
name = "pypi"

[packages]
common = {editable = true, directory = "./src"}
common-package = {editable = true, path = "./src"}
twine = "*"

[dev-packages]

src/common/auth.py

0 → 100644
+77 −0
Original line number Diff line number Diff line
"""Provide authentication methods."""

import hashlib
import os
from typing import Any, Callable, Dict
from urllib.parse import urlencode, urlparse

from common.env import check_environment as ce

try:
    from ldap3 import ALL, SUBTREE, Connection, Server
    from ldap3.core.exceptions import LDAPBindError, LDAPException
except ImportError:
    import sys

    from common.logz import create_logger

    log = create_logger()
    log.warn("To use this module, install common-package[auth] extra.")
    sys.exit(1)


def authenticate_ldap_user(uid, password):
    """
    Authenticates a user against an LDAP server using their user ID and password.

    This function retrieves the necessary LDAP configuration from environment variables,
    establishes a connection to the LDAP server, and attempts to bind with the provided
    credentials. If the binding is successful, it searches for the user's entry and returns it.

    Parameters:
    - uid (str): The user ID of the LDAP account to authenticate.
    - password (str): The password for the LDAP account.

    Returns:
    - ldap3.Entry: The LDAP entry of the authenticated user if successful.
    - None: If the authentication fails (e.g., incorrect credentials or issues with server connection).

    Environment Variables:
    - LDAP_SERVER: URL of the LDAP server. Default is "ldaps://ldapx.ornl.gov".
    - LDAP_ROOT_DN: The root distinguished name (DN) for LDAP queries. Default is "dc=xcams,dc=ornl,dc=gov".
    - LDAP_USER_DN: Template for constructing the user's DN. Default is "uid={uid},ou=Users".
    - LDAP_USER_SEARCH_FILTER: LDAP search filter to find the user. Default is "(uid={uid})".

    Example:
    To authenticate a user with ID 'jdoe' and password 'securepassword', you can call:
    authenticate_ldap_user('jdoe', 'securepassword')

    Raises:
    - ldap3.core.exceptions.LDAPException: If there is an issue connecting to the LDAP server or during the search.
    """
    ldap_server = ce("LDAP_SERVER", "ldaps://ldapx.ornl.gov")
    root_dn = ce("LDAP_ROOT_DN", "dc=xcams,dc=ornl,dc=gov")
    user_dn = ce("LDAP_USER_DN", f"uid={uid},ou=Users")
    user_search_filter = ce("LDAP_USER_SEARCH_FILTER", f"(uid={uid})")
    dn = f"{user_dn},{root_dn}"
    server = Server(ldap_server, get_info=ALL)
    connection = Connection(server, user=dn, password=password)
    # check if binding to the connection works
    if not connection.bind():
        return None
    connection.search(root_dn, user_search_filter, attributes=["*"])
    return connection.entries[0]


def generate_salt() -> str:
    return str(os.urandom(32)).replace("\\", "").replace("b", "")


def hash_password(
    password: str,
    salt: str,
    hash_algorithm: Callable[..., Any] = hashlib.pbkdf2_hmac,
    *args: Any,
    **kwargs: Any,
) -> str:
    return hash_algorithm(password.encode(), salt.encode(), *args, **kwargs).hex()
+8 −24
Original line number Diff line number Diff line
@@ -76,8 +76,7 @@ class CRUDTable(ABC):
        # assure that all columns are defined
        # FIXME - asserts should only be used inside tests
        assert all(arg in self.columns.keys() for arg in kwargs.keys()), (
            "Must supply values for all columns to create an entry. "
            + f"Columns: {self.columns}"
            "Must supply values for all columns to create an entry. " + f"Columns: {self.columns}"
        )
        # get a 'pretty' string of column names
        column_names = str(list(kwargs.keys()))[1:-1].replace("'", "")
@@ -88,9 +87,7 @@ class CRUDTable(ABC):
            f"VALUES ({', '.join(['%s']*len(kwargs.keys()))})"
        )
        # tell the user that we are executing an insert query
        self.logger.info(
            f"Executing query: {query} " + f"params: {tuple(kwargs.values())}"
        )
        self.logger.info(f"Executing query: {query} " + f"params: {tuple(kwargs.values())}")
        try:
            # if the db is not open..
            if not self.db.is_open():
@@ -169,17 +166,13 @@ class CRUDTable(ABC):
            else:
                # raise an exception as we do not know what to do
                raise TypeError(
                    "column argument should be of type list or"
                    + f" str not {type(columns)}"
                    "column argument should be of type list or" + f" str not {type(columns)}"
                )
        # if the where clause was not set/specified
        if where_clause is None:
            try:
                # make a query to select the columns with no where clause
                query = (
                    f"SELECT {select_clause} "
                    + f"FROM {self.schema}.{self.__class__.__name__}"
                )
                query = f"SELECT {select_clause} " + f"FROM {self.schema}.{self.__class__.__name__}"
                # inform the user we are executing the query..
                self.logger.info(f"Executing query: {query}")
                # if the db is not already opened..
@@ -191,9 +184,7 @@ class CRUDTable(ABC):
                # execute the query
                curr.execute(query)
            except Exception as err:
                self.logger.error(
                    "Exception occured when trying to execute " + f"query: {query}"
                )
                self.logger.error("Exception occured when trying to execute " + f"query: {query}")
                self.logger.error(f"Exception Message: {err}")

        # if there is a where clause...
@@ -206,9 +197,7 @@ class CRUDTable(ABC):
                    + f"{where_clause[0]}"
                )
                # tell the user we are executing their query
                self.logger.info(
                    f"Executing query: {query} " + f"params: {where_clause[1]}"
                )
                self.logger.info(f"Executing query: {query} " + f"params: {where_clause[1]}")
                # if the db is not already opened..
                if not self.db.is_open():
                    # open the connection to the database
@@ -315,14 +304,9 @@ class CRUDTable(ABC):
            # construct the where clause with the conversion method
            where_clause = convert_to_where(kwargs)
            # build a delete query with the specified values
            query = (
                f"DELETE FROM {self.schema}.{self.__class__.__name__} "
                + f"{where_clause[0]}"
            )
            query = f"DELETE FROM {self.schema}.{self.__class__.__name__} " + f"{where_clause[0]}"
            # tell the user that we are executing their query
            self.logger.info(
                f"Executing query: {query}, " + f"params: {where_clause[1]}"
            )
            self.logger.info(f"Executing query: {query}, " + f"params: {where_clause[1]}")
            # if the db is not open..
            if not self.db.is_open():
                # open a connection to the database
+25 −2
Original line number Diff line number Diff line
@@ -65,8 +65,7 @@ class Database(ABC):
    DEFAULT_LOG_ENCODING = ce("DATABASE_LOG_ENCODING", "utf-8")
    # define a URI string if URI is perferred to connect
    DEFAULT_URI = (
        f"{DEFAULT_ENGINE}://{DEFAULT_USER}:{str(DEFAULT_PW)}"
        + f"@{DEFAULT_HOST}/{DEFAULT_DB}"
        f"{DEFAULT_ENGINE}://{DEFAULT_USER}:{str(DEFAULT_PW)}" + f"@{DEFAULT_HOST}/{DEFAULT_DB}"
    )

    def __init__(
@@ -179,3 +178,27 @@ class Database(ABC):
        :param value: the value for the connection variable
        """
        self.connection_info[variable] = value

    @staticmethod
    def create_connection_info(
            db_name=None,
            db_user=None,
            db_password=None,
            db_host=None,
            db_port=None,
            db_schema=None,
            db_engine=None,
            db_timeout=None
    ):
        connection_info = {
            "dbName": db_name if db_name is not None else Database.DEFAULT_DB,
            "dbUser": db_user if db_user is not None else Database.DEFAULT_USER,
            "dbPassword": db_password if db_password is not None else Database.DEFAULT_PW,
            "dbHost": db_host if db_host is not None else Database.DEFAULT_HOST,
            "dbPort": db_port if db_port is not None else Database.DEFAULT_PORT,
            "dbTimeout": db_timeout if db_timeout is not None else Database.DEFAULT_TIMEOUT,
            "dbSchema": db_schema if db_schema is not None else Database.DEFAULT_SCHEMA,
            "dbEngine": db_engine if db_engine is not None else Database.DEFAULT_ENGINE,
            "uri": f"{db_engine if db_engine is not None else Database.DEFAULT_ENGINE}://{db_user if db_user is not None else Database.DEFAULT_USER}:{db_password if db_password is not None else Database.DEFAULT_PW}@{db_host if db_host is not None else Database.DEFAULT_HOST}/{db_name if db_name is not None else Database.DEFAULT_DB}",
        }
        return connection_info
+2 −3
Original line number Diff line number Diff line
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
"""
This module defines a data class 'ErrorCodes' that instantiates various error codes used throughout the application.
It categorizes error codes into distinct sections for database operations, scraping processes,
templating issues, and provides a default error code for general use. Each error type is associated with specific
integer values, making it easier to manage and identify errors consistently across different components of the application.
'''
"""
from dataclasses import dataclass


@dataclass
class ErrorCodes:

    # Database Errors
    DB_CONNECTION_FAILED: int = 1001
    DB_TIMEOUT: int = 1002
Loading