Commit 1fddf6cd authored by MacFarland's avatar MacFarland
Browse files

fixing uri creation in database

parent b4c67bdd
Loading
Loading
Loading
Loading
+37 −1
Original line number Diff line number Diff line
@@ -187,9 +187,44 @@ class Database(ABC):
            db_host=None,
            db_port=None,
            db_schema=None,
            db_type=None,
            db_engine=None,
            db_timeout=None
    ):
        """
        Creates the complete connection_info dictionary to use for database connection.

        This method generates a dictionary containing all the necessary information for
        establishing a connection to a database. It constructs the connection URI based
        on the provided parameters and defaults to certain values if parameters are not
        provided.

        Parameters:
        db_name (str, optional): The name of the database. Defaults to Database.DEFAULT_DB.
        db_user (str, optional): The database user. Defaults to Database.DEFAULT_USER.
        db_password (str, optional): The password for the database user. Defaults to Database.DEFAULT_PW.
        db_host (str, optional): The host where the database is located. Defaults to Database.DEFAULT_HOST.
        db_port (int, optional): The port on which the database is listening. Defaults to Database.DEFAULT_PORT.
        db_schema (str, optional): The schema to use within the database. Defaults to Database.DEFAULT_SCHEMA.
        db_type (str, optional): The type of the database (e.g., 'postgres', 'mssql'). Used to infer db_engine.
        db_engine (str, optional): The SQLAlchemy database engine string (e.g., 'postgresql', 'mssql+pymssql').
                                   If not provided, it is inferred from db_type or defaults to Database.DEFAULT_ENGINE.
        db_timeout (int, optional): The timeout setting for the database connection. Defaults to Database.DEFAULT_TIMEOUT.
        """
        if db_type and not db_engine:
            if db_type.startswith("postgres") or db_type.startswith("pg"):
                db_engine = 'postgresql'
            elif db_type.startswith("mssql"):
                db_engine = 'mssql+pymssql'
            elif db_type.startswith("influx"):
                db_engine = 'influxdb'
            elif db_type.startswith("sqlite"):
                db_engine = 'sqlite'
        if db_engine is not None and db_engine.startswith('sqlite'):
            uri = f'{db_engine}://{db_name}'
        else:
            uri = f"{db_engine if db_engine is not None else Database.DEFAULT_ENGINE}://{db_user if db_user is not None else Database.DEFAULT_USER}:{db_password if db_password is not None else Database.DEFAULT_PW}@{db_host if db_host is not None else Database.DEFAULT_HOST}:{db_port if db_port is not None else Database.DEFAULT_PORT}/{db_name if db_name is not None else Database.DEFAULT_DB}"

        connection_info = {
            "dbName": db_name if db_name is not None else Database.DEFAULT_DB,
            "dbUser": db_user if db_user is not None else Database.DEFAULT_USER,
@@ -199,6 +234,7 @@ class Database(ABC):
            "dbTimeout": db_timeout if db_timeout is not None else Database.DEFAULT_TIMEOUT,
            "dbSchema": db_schema if db_schema is not None else Database.DEFAULT_SCHEMA,
            "dbEngine": db_engine if db_engine is not None else Database.DEFAULT_ENGINE,
            "uri": f"{db_engine if db_engine is not None else Database.DEFAULT_ENGINE}://{db_user if db_user is not None else Database.DEFAULT_USER}:{db_password if db_password is not None else Database.DEFAULT_PW}@{db_host if db_host is not None else Database.DEFAULT_HOST}/{db_name if db_name is not None else Database.DEFAULT_DB}",
            "uri": uri,
        }
        return connection_info
+1 −0
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ class MultiDatabase:
                        db_password=config.get('DATABASE_PW', None),
                        db_host=config.get('DATABASE_HOST', None),
                        db_port=config.get('DATABASE_PORT', None),
                        db_type=db_type
                    )
                    db_instance = type(db_type.capitalize() + "DB", (mixin, Database), {})(connection_info=conn_info)
                    self.databases[db_id] = db_instance
+28 −0
Original line number Diff line number Diff line
import unittest
from common.mixins.multiple import MultiDatabase

class TestMultiDatabase(unittest.TestCase):
    def setUp(self):
        self.infl_id = "influx_id"
        self.influx = {"type": "influx", "id": self.infl_id, "host": "influx", "user": 'influx',
                       "password": 'secret', "database": "something", "port": '8086'}

        self.pg_id = "pg_db"
        self.postgres = {"type": "postgres", "id": self.pg_id, "host": "localhost", "user": 'me',
                         "password": 'nevergonnaguess', "database": "db", "port": '5435'}

        self.configs = [self.influx, self.postgres]
        self.multi = MultiDatabase(self.configs)

    def test_postgres_uri(self):
        pg_expected_uri = 'postgresql://me:nevergonnaguess@localhost:5435/db'
        actual_pg_uri = self.multi.databases[self.pg_id].connection_info["uri"]
        self.assertEqual(actual_pg_uri, pg_expected_uri)

    def test_influx_uri(self):
        influx_expected_uri = 'influxdb://influx:secret@influx:8086/something'
        actual_influx_uri = self.multi.databases[self.infl_id].connection_info["uri"]
        self.assertEqual(actual_influx_uri, influx_expected_uri)

if __name__ == "__main__":
    unittest.main()