Commit c43695e9 authored by Grant
Browse files

add a MultipleDatabase class

parent 433b0d97
Loading
Loading
Loading
Loading
+28 −3
Original line number Diff line number Diff line
@@ -2,6 +2,10 @@
# -*- coding: utf-8 -*-
"""Provide common utilities for checking the environment."""
import os
from contextlib import contextmanager

TRUE_SET = {1, "1", "TRUE", "True", "true", True, "yes", "y", "T", "t"}
FALSE_SET = {0, "0", "FALSE", "False", "false", False, "no", "n", "F", "f"}


def boolify(var):
@@ -9,11 +13,11 @@ def boolify(var):

    :param var: the variable to check to see if it can be converted to bool
    """
    if var in [0, "0", "FALSE", "False", "false", False]:
    if var in FALSE_SET:
        return False
    if var in [1, "1", "TRUE", "True", "true", True]:
    if var in TRUE_SET:
        return True
    raise TypeError("unable to evaluate expected boolean")
    raise TypeError(f"unable to evaluate expected boolean value: {var}")


def check_environment(env_var, default=None):
@@ -60,3 +64,24 @@ def check_multi_environment(env_var_multi, multi_value, env_var, default=None):
        return check_environment(env_var_multi, multi_value)
    # set environment for vanilla and return vanilla
    return check_environment(env_var, default)


@contextmanager
def mock_env_vars(temp_vars: dict):
    """Temporarily override environment variables inside a ``with`` block.

    On exit, every variable named in ``temp_vars`` is put back to the value it
    had on entry, or removed if it was not set on entry.

    :param temp_vars: a dictionary of the temporary variables in the form of
        name: value (``os.environ`` only accepts string names and values)
    """
    # snapshot the current values; None marks a variable that was not set
    original = {var: os.environ.get(var) for var in temp_vars}
    try:
        # apply inside the try so that a failure part-way through the update
        # (e.g. a non-string value) still rolls back what was already applied
        os.environ.update(temp_vars)
        yield
    finally:
        # restore original values
        for var, value in original.items():
            if value is None:
                # pop rather than del: the with-body may already have removed it
                os.environ.pop(var, None)
            else:
                os.environ[var] = value  # restore the original value
+131 −0
Original line number Diff line number Diff line
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This is a MultiDatabase class using multiple Database Mixins
"""
from typing import Type, Callable, Optional
from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor

from common.env import mock_env_vars
from common.database import Database
from common.logz import create_logger
from common.mixins.postgres import PostgresMixin
from common.mixins.mssql import MSSQLMixin
from common.mixins.influx import InfluxMixin


class MultiDatabase:
    """Hold several database connections, keyed by id, and fan calls out to them.

    Each connection object is an instance of a class assembled at runtime from
    a backend mixin (chosen from the config's ``"type"``) and ``Database``.
    """

    def __init__(self, configs: list[dict]):
        """Build one database object per config dict.

        Each config is applied to the process environment (via
        ``mock_env_vars``) only while its database object is being
        constructed. Configs whose ``"type"`` is missing or not recognised are
        skipped with a warning.

        Args:
            configs (list[dict]): one dict per database. ``"type"`` selects the
                backend mixin; ``"id"`` is the key under which the instance is
                stored in ``self.databases`` (a random UUID string is used when
                it is absent).
        """
        self.databases = {}
        self.logger = create_logger()
        for config in configs:
            db_type: Optional[str] = config.get("type")
            db_id: str = config.get("id", str(uuid4()))
            self.logger.debug(f"Creating {db_type} database with id {db_id}")
            # resolve the mixin first so unknown types never touch the environment
            mixin = self._get_mixin_for_type(db_type=db_type)
            if mixin is None:
                self.logger.warning(f"Skipping unknown Database type: {db_type}")
                continue
            with mock_env_vars(config):
                # build e.g. "PostgresDB(PostgresMixin, Database)" and instantiate
                # it while the config is visible in the environment
                db_class = type(db_type.capitalize() + "DB", (mixin, Database), {})
                self.databases[db_id] = db_class()

    @staticmethod
    def _get_mixin_for_type(db_type: Optional[str]) -> Optional[Type]:
        """Return the mixin class matching a database type name.

        Matching is case-insensitive and by prefix.

        Args:
            db_type (str): the configured type name, e.g. ``"postgres"``,
                ``"pg"``, ``"mssql"`` or ``"influx"``.

        Returns:
            The mixin class (``PostgresMixin``, ``MSSQLMixin`` or
            ``InfluxMixin``), or ``None`` when the type is empty or unknown.
        """
        if not db_type:
            return None
        db_type = db_type.lower()
        if db_type.startswith(("postgres", "pg")):
            return PostgresMixin
        if db_type.startswith("mssql"):
            return MSSQLMixin
        if db_type.startswith("influx"):
            return InfluxMixin
        # add more database Mixins here as they become available
        return None

    def _call_on_all(self, method_name: str) -> None:
        """Call a no-argument method on every database concurrently.

        Failures are logged per database and are not raised.
        """
        if not self.databases:
            # ThreadPoolExecutor rejects max_workers=0
            return
        with ThreadPoolExecutor(max_workers=len(self.databases)) as executor:
            future_to_db_id = {
                executor.submit(getattr(db, method_name)): db_id
                for db_id, db in self.databases.items()
            }
            for future, db_id in future_to_db_id.items():
                error = future.exception()  # blocks until the call has finished
                if error is not None:
                    self.logger.error(
                        f"Failed to {method_name} database {db_id}: {error}"
                    )

    def open(self):
        """Call ``open()`` on every database concurrently; failures are logged."""
        self._call_on_all("open")

    def close(self):
        """Call ``close()`` on every database concurrently; failures are logged."""
        self._call_on_all("close")

    def query(self, query):
        """Run the same query on every database concurrently.

        Args:
            query: the query passed unchanged to each database's ``query``.

        Returns:
            dict: database id -> that database's query result. Empty when no
            databases are configured.

        Raises:
            Exception: whatever any single database's ``query`` raised.
        """
        results = {}
        if not self.databases:
            # ThreadPoolExecutor rejects max_workers=0
            return results
        with ThreadPoolExecutor(max_workers=len(self.databases)) as executor:
            future_to_db_id = {
                executor.submit(db.query, query): db_id
                for db_id, db in self.databases.items()
            }
            for future, db_id in future_to_db_id.items():
                results[db_id] = future.result()
        return results

    def transform_and_transfer(
        self,
        from_db_id: str,
        to_db_id: str,
        select_query: str,
        insert_query_template: str,
        transform_func: Optional[Callable] = None,
    ):
        """
        Copies rows from one database to another, optionally transforming each row.

        Args:
        from_db_id (str): The ID of the source database.
        to_db_id (str): The ID of the destination database.
        select_query (str): Query run on the source database to select the rows.
        insert_query_template (str): ``str.format`` template filled positionally
            with each (transformed) row to produce one insert statement per row.
        transform_func (Optional[Callable]): Called with each source row; its
            return value is used in place of the row. Rows are used as-is when
            omitted.

        Returns:
        bool: True if the transfer was successful, False if the source returned
        no rows or any step raised.
        """
        try:
            # Execute the select query on the source database
            source_results = self.databases[from_db_id].query(select_query)
            if not source_results:
                self.logger.info(f"No data found to convert from {from_db_id}.")
                return False

            # NOTE(review): values are interpolated into the SQL text with
            # str.format, so they are neither escaped nor parameterized. Only
            # safe for trusted data; consider a parameterized insert if the
            # Database API offers one.
            operations = [
                insert_query_template.format(
                    *(transform_func(row) if transform_func else row)
                )
                for row in source_results
            ]

            # Insert data into the destination database
            self._execute_parallel_db_operations(to_db_id, operations)

            self.logger.info(
                f"Data conversion from {from_db_id} to {to_db_id} completed successfully."
            )
            return True
        except Exception as e:
            # boundary handler: report the failure through the return value
            self.logger.error(
                f"Failed to convert data from {from_db_id} to {to_db_id}: {e}"
            )
            return False

    def _execute_parallel_db_operations(self, db_id: str, operations: list):
        """
        Executes database operations on one database through a worker thread.

        The pool has a single worker, so the operations run one at a time in
        the order given.

        Args:
            db_id (str): the ID of the database to execute operations in SQL
            operations (list): A list of SQL operations to execute

        Raises:
            Exception: the first failure raised by any operation.
        """
        with ThreadPoolExecutor(max_workers=1) as executor:
            futures = [
                executor.submit(self.databases[db_id].query, op) for op in operations
            ]
            for future in futures:
                future.result()  # this will raise an exception if any op fails, may want to catch and report