Commit f22ad5e1 authored by Grant's avatar Grant
Browse files

some changes

parents 1af577d6 5a590dc4
Loading
Loading
Loading
Loading
+184 −52
Original line number Diff line number Diff line
"""Provide authentication methods."""
from typing import Dict, Any, Callable
from urllib.parse import urlparse, urlencode
import hashlib
import os

from common.mail import send_email
from common.env import check_environment as ce
from common.database import Database

try:
    from onelogin.saml2.auth import OneLogin_Saml2_Auth
    from ldap3 import Server, Connection, ALL, SUBTREE
    from ldap3.core.exceptions import LDAPException, LDAPBindError
    from flask import jsonify

except ImportError:
    import sys
    from common.logz import create_logger
@@ -59,81 +66,206 @@ def authenticate_ldap_user(uid, password):
    return connection.entries[0]


def authenticate_saml(
    scheme,
    host,
    url,
    method,
    args,
    form,
    process_success=lambda _: None,
    process_failure=lambda _: None,
    saml_path=None,
def authenticate_saml_user(
    request_scheme: str,
    request_host: str,
    request_url: str,
    request_method: str,
    request_args: dict,
    request_form: dict,
    success_handler: Callable = lambda _: None,
    failure_handler: Callable = lambda _: None,
    saml_path: str = None,
    logger=None,
):
    """
    Handles SAML authentication by processing the necessary components extracted from a web request.

    This function manages SAML Single Sign-On (SSO), Single Log-Out (SLO), and Assertion Consumer Service (ACS)
    endpoints. It accepts individual components typically found in a web request and utilizes them to process
    SAML requests and responses, delegating outcome handling to provided callback functions.

    Parameters:
    - scheme (str): The request scheme ('http' or 'https').
    - host (str): The request host.
    - url (str): The full URL of the request.
    - method (str): The HTTP method of the request.
    - args (dict): The query parameters of the request as a dictionary.
    - form (dict): The form data of the request as a dictionary.
    - process_success (Callable): Function to call after successful authentication or logout to handle the response.
    - process_failure (Callable): Function to call when an error occurs or no relevant SAML action is found.
    - request_scheme (str): The request scheme ('http' or 'https').
    - request_host (str): The request host.
    - request_url (str): The full URL of the request.
    - request_method (str): The HTTP method of the request.
    - request_args (dict): The query parameters of the request as a dictionary.
    - request_form (dict): The form data of the request as a dictionary.
    - success_handler (Callable, optional): Function to handle successful SAML authentication.
      Defaults to a no-op function.
    - failure_handler (Callable, optional): Function to handle SAML authentication failure.
      Defaults to a no-op function.
    - saml_path (str, optional): Custom path for SAML configuration files. If not provided, environment variable
      'SAML_PATH' is used.
    - logger (Logger, optional): Logger object for logging. If not provided, no logging is performed.

    Returns:
    - Result of the `process_success` or `process_failure` callable depending on the SAML processing outcome.

    Raises:
    - Exception: Descriptive exceptions can be raised depending on the SAML processing errors or misconfigurations.

    Usage:
    This function is designed to be framework-agnostic and should be integrated into any web application by adapting
    the request handling to extract necessary components.
    Example:
    # Extract necessary components from a Flask or Django request and pass them to this function.
    result = authenticate_saml(request.scheme, request.host, request.url, request.method, request.args, request.form,
                               handle_success, handle_failure)
    - Result of the `success_handler` or `failure_handler` callable depending on the SAML processing outcome.
    """
    # Prepare request data for SAML auth
    url_data = urlparse(url)
    url_data = urlparse(request_url)
    saml_request = {
        "https": "on" if scheme == "https" else "off",
        "http_host": host,
        "https": "on" if request_scheme == "https" else "off",
        "http_host": request_host,
        "server_port": url_data.port,
        "script_name": url_data.path,
        "get_data": args,
        "post_data": form,
        "query_string": urlencode(args) if method == "GET" else "",
        "get_data": request_args,
        "post_data": request_form,
        "query_string": urlencode(request_args) if request_method == "GET" else "",
    }

    # Initialize SAML auth
    saml_path = saml_path or ce("SAML_PATH")
    auth = OneLogin_Saml2_Auth(saml_request, custom_base_path=saml_path)

    # SAML Action Handling
    try:
        if "sso" in args:
            return process_success(auth.login())
        elif "slo" in args:
            return process_success(auth.logout())
        elif "acs" in args:
        if "sso" in request_args:
            return success_handler(auth.login())
        elif "slo" in request_args:
            return success_handler(auth.logout())
        elif "acs" in request_args:
            auth.process_response(
                request_id=None
            )  # Assuming request carries all needed info
            if auth.is_authenticated():
                return process_success(None)  # pass user details / attributes
                return success_handler(None)  # pass user details / attributes
            else:
                return process_failure("Authentication failed or errors occurred")
                return failure_handler("SAML authentication failed")
        else:
            return failure_handler("No valid SAML action found")
    except Exception as e:
        error_msg = f"SAML authentication error: {str(e)}"
        if logger:
            logger.error(error_msg)
        return failure_handler(error_msg)


def generate_salt() -> str:
    return str(os.urandom(32)).replace("\\", "").replace("b", "")


def hash_password(
    password: str,
    salt: str,
    hash_algorithm: Callable[..., Any] = hashlib.pbkdf2_hmac,
    *args: Any,
    **kwargs: Any,
) -> str:
    return hash_algorithm(password.encode(), salt.encode(), *args, **kwargs).hex()


def login(
    username: str,
    password: str,
    db: str,
    table_name: str,
    username_column: str,
    password_column: str,
    salt_column: str,
) -> dict:
    user_data = db.query(
        f"SELECT {password_column}, {salt_column} FROM {table_name} WHERE {username_column} = '{username}'"
    )
    if len(user_data) == 0:
        return {"msg": "Username does not exist."}, 401

    user_pass, salt = user_data[0]
    pass_to_check = hash_password(password, salt)

    if pass_to_check == user_pass:
        # User authentication successful
        return {"msg": "Login successful!"}
    else:
            return process_failure("No SAML action found")
        return {"msg": "Invalid uid or password"}, 401


def register_user(
    username: str,
    password: str,
    confirm_pw: str,
    first_name: str,
    last_name: str,
    email: str,
    work_sector: str,
    user_type: str,
    reason: str,
    db: UNDB,
    table_name: str,
    username_column: str,
    password_column: str,
    salt_column: str,
) -> dict:
    # Check if username already exists
    if check_username_exists(username, db, table_name, username_column):
        return {"msg": "Username already exists."}, 401

    # Check if passwords match
    if password != confirm_pw:
        return {"msg": "Passwords do not match."}, 401

    # Generate salt and hash the password
    salt = generate_salt()
    hashed_pw = hash_password(password, salt)

    try:
        # Insert user data into the database
        db.cursor.execute(
            f"""INSERT INTO {table_name} ({username_column}, first_name, last_name, email, work_sector, 
                             user_type, {password_column}, {salt_column}, last_login_date, enabled) VALUES 
                             ('{username}', '{first_name}', '{last_name}', '{email}', '{work_sector}', 
                             '{user_type}', '{hashed_pw}', '{salt}', CURRENT_TIMESTAMP, false);"""
        )
        db.commit()
    except Exception as e:
        return {"msg": "Error registering new user account.", "error": str(e)}, 401

    # Send email notification
    os.environ["EMAIL_RECIPIENTS"] = "plattmw@ornl.gov, burdetteja@ornl.gov"
    os.environ["EMAIL_SENDER"] = "psplatial@ornl.gov"
    subject = "PS Platial Account Request"
    msg = f"{last_name}, {first_name} at {email} has requested a PlanetSense Platial account for the following reason: {reason}."
    send_email(subject, msg)

    return {
        "msg": "Thank you for submitting an account request. Our team will review your request and respond as soon as possible."
    }


def reset_password(
    username: str,
    token: str,
    new_password: str,
    confirm_new_password: str,
    db: Database,
    table_name: str,
    username_column: str,
    salt_column: str,
) -> dict:
    # Retrieve user's salt from the database
    user_salt = db.query(
        f"SELECT {salt_column} FROM {table_name} WHERE {username_column} = '{username}';"
    )
    if not user_salt:
        return {"msg": "Username not found."}, 401

    user_salt = str(user_salt[0][0]).replace("b", "")

    # Validate token
    token_to_match = db.query(
        f"SELECT email_reset_code FROM {table_name} WHERE {username_column} = '{username}';"
    )
    if token != token_to_match[0][0]:
        return {"msg": "Token does not match."}, 401

    # Check if passwords match
    if new_password != confirm_new_password:
        return {"msg": "Passwords do not match."}, 401

    # Generate new hashed password and update in the database
    new_hashed_pw = hash_password(new_password, user_salt)
    try:
        db.cursor.execute(
            f"UPDATE {table_name} SET {password_column} = '{new_hashed_pw}' WHERE {username_column} = '{username}';"
        )
        db.commit()
    except Exception as e:
        return process_failure(str(e))
        return {"msg": "Failed to update password.", "error": str(e)}, 401

    return {
        "msg": "Password updated successfully! Redirecting you back to login page now..."
    }