Commit a7b9f1fb authored by Grant, Josh's avatar Grant, Josh
Browse files

Merge branch 'fix_db_uris' into 'develop'

Fix db uris

See merge request !40
parents b4c67bdd 683aa515
Loading
Loading
Loading
Loading
+37 −1
Original line number Diff line number Diff line
@@ -187,9 +187,44 @@ class Database(ABC):
            db_host=None,
            db_port=None,
            db_schema=None,
            db_type=None,
            db_engine=None,
            db_timeout=None
    ):
        """
        Creates the complete connection_info dictionary to use for database connection.

        This method generates a dictionary containing all the necessary information for
        establishing a connection to a database. It constructs the connection URI based
        on the provided parameters and defaults to certain values if parameters are not
        provided.

        Parameters:
        db_name (str, optional): The name of the database. Defaults to Database.DEFAULT_DB.
        db_user (str, optional): The database user. Defaults to Database.DEFAULT_USER.
        db_password (str, optional): The password for the database user. Defaults to Database.DEFAULT_PW.
        db_host (str, optional): The host where the database is located. Defaults to Database.DEFAULT_HOST.
        db_port (int, optional): The port on which the database is listening. Defaults to Database.DEFAULT_PORT.
        db_schema (str, optional): The schema to use within the database. Defaults to Database.DEFAULT_SCHEMA.
        db_type (str, optional): The type of the database (e.g., 'postgres', 'mssql'). Used to infer db_engine.
        db_engine (str, optional): The SQLAlchemy database engine string (e.g., 'postgresql', 'mssql+pymssql').
                                   If not provided, it is inferred from db_type or defaults to Database.DEFAULT_ENGINE.
        db_timeout (int, optional): The timeout setting for the database connection. Defaults to Database.DEFAULT_TIMEOUT.
        """
        if db_type and not db_engine:
            if db_type.startswith("postgres") or db_type.startswith("pg"):
                db_engine = 'postgresql'
            elif db_type.startswith("mssql"):
                db_engine = 'mssql+pymssql'
            elif db_type.startswith("influx"):
                db_engine = 'influxdb'
            elif db_type.startswith("sqlite"):
                db_engine = 'sqlite'
        if db_engine is not None and db_engine.startswith('sqlite'):
            uri = f'{db_engine}://{db_name}'
        else:
            uri = f"{db_engine if db_engine is not None else Database.DEFAULT_ENGINE}://{db_user if db_user is not None else Database.DEFAULT_USER}:{db_password if db_password is not None else Database.DEFAULT_PW}@{db_host if db_host is not None else Database.DEFAULT_HOST}:{db_port if db_port is not None else Database.DEFAULT_PORT}/{db_name if db_name is not None else Database.DEFAULT_DB}"

        connection_info = {
            "dbName": db_name if db_name is not None else Database.DEFAULT_DB,
            "dbUser": db_user if db_user is not None else Database.DEFAULT_USER,
@@ -199,6 +234,7 @@ class Database(ABC):
            "dbTimeout": db_timeout if db_timeout is not None else Database.DEFAULT_TIMEOUT,
            "dbSchema": db_schema if db_schema is not None else Database.DEFAULT_SCHEMA,
            "dbEngine": db_engine if db_engine is not None else Database.DEFAULT_ENGINE,
            "uri": f"{db_engine if db_engine is not None else Database.DEFAULT_ENGINE}://{db_user if db_user is not None else Database.DEFAULT_USER}:{db_password if db_password is not None else Database.DEFAULT_PW}@{db_host if db_host is not None else Database.DEFAULT_HOST}/{db_name if db_name is not None else Database.DEFAULT_DB}",
            "uri": uri,
        }
        return connection_info
+1 −0
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ class MultiDatabase:
                        db_password=config.get('DATABASE_PW', None),
                        db_host=config.get('DATABASE_HOST', None),
                        db_port=config.get('DATABASE_PORT', None),
                        db_type=db_type
                    )
                    db_instance = type(db_type.capitalize() + "DB", (mixin, Database), {})(connection_info=conn_info)
                    self.databases[db_id] = db_instance
+31 −0
Original line number Diff line number Diff line
import pytest
from common.mixins.multiple import MultiDatabase

@pytest.fixture
def multi_database_fixture():
    infl_id = "influx_id"
    influx = {"type": "influx", "id": infl_id, "host": "influx", "user": 'influx',
              "password": 'secret', "database": "something", "port": '8086'}

    pg_id = "pg_db"
    postgres = {"type": "postgres", "id": pg_id, "host": "localhost", "user": 'me',
                "password": 'nevergonnaguess', "database": "db", "port": '5435'}

    configs = [influx, postgres]
    multi = MultiDatabase(configs)
    return multi, pg_id, infl_id

def test_postgres_uri(multi_database_fixture):
    multi, pg_id, infl_id = multi_database_fixture
    pg_expected_uri = 'postgresql://me:nevergonnaguess@localhost:5435/db'
    actual_pg_uri = multi.databases[pg_id].connection_info["uri"]
    assert actual_pg_uri == pg_expected_uri

def test_influx_uri(multi_database_fixture):
    multi, pg_id, infl_id = multi_database_fixture
    influx_expected_uri = 'influxdb://influx:secret@influx:8086/something'
    actual_influx_uri = multi.databases[infl_id].connection_info["uri"]
    assert actual_influx_uri == influx_expected_uri

if __name__ == "__main__":
    pytest.main()