Skip to content

Commit fd3aabb

Browse files
authored
Allow override of db name (#89)
1 parent 3669be4 commit fd3aabb

File tree

2 files changed

+65
-3
lines changed

2 files changed

+65
-3
lines changed

druzhba/table.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,10 @@ class TableConfig(object):
174174
include_comments : boolean, optional
175175
flag to specify whether or not to ingest table and column comments
176176
when building or rebuilding the target table.
177+
override_db_name : str, optional
178+
override database name to use in the index tracking table instead of
179+
the actual database name. This allows tracking multiple databases or
180+
environments under a common name for index purposes.
177181
178182
Attributes
179183
----------
@@ -241,6 +245,7 @@ def __init__(
241245
include_comments=True,
242246
monitor_tables_config=None,
243247
lookback_value=0,
248+
override_db_name=None,
244249
):
245250
self.database_alias = database_alias
246251
self.db_host = db_connection_params.host
@@ -282,6 +287,7 @@ def __init__(
282287
self.monitor_tables_config = monitor_tables_config
283288
self.lookback_value = lookback_value
284289
self._lookback_index_value = "notset"
290+
self.override_db_name = override_db_name
285291

286292
self.date_key = datetime.datetime.strftime(
287293
datetime.datetime.utcnow(), "%Y%m%dT%H%M%S"
@@ -301,6 +307,14 @@ def __init__(
301307
self.logger = logging.getLogger(f"druzhba.{database_alias}.{source_table_name}")
302308
self.s3 = Session().client("s3")
303309

310+
@property
311+
def index_db_name(self):
312+
"""Returns the database name to use for index tracking operations.
313+
314+
Returns override_db_name if set, otherwise returns the actual db_name.
315+
"""
316+
return self.override_db_name if self.override_db_name else self.db_name
317+
304318
@classmethod
305319
def _clean_type_map(cls, type_map):
306320
if not type_map:
@@ -579,7 +593,7 @@ def _load_old_index_value(self):
579593
self.logger.debug("Querying Redshift for last updated index")
580594
with get_redshift().cursor() as cur:
581595
cur.execute(
582-
query, (self.database_alias, self.db_name, self.source_table_name)
596+
query, (self.database_alias, self.index_db_name, self.source_table_name)
583597
)
584598
index_value = cur.fetchone()
585599

@@ -627,7 +641,7 @@ def _load_lookback_index_value(self):
627641
self.logger.debug("Querying Redshift for nth last updated index to lookback")
628642
with get_redshift().cursor() as cur:
629643
cur.execute(
630-
query, (self.database_alias, self.db_name, self.source_table_name)
644+
query, (self.database_alias, self.index_db_name, self.source_table_name)
631645
)
632646
index_value = cur.fetchone()
633647

@@ -849,7 +863,7 @@ def set_last_updated_index(self):
849863
with get_redshift().cursor() as cur:
850864
args = (
851865
self.database_alias,
852-
self.db_name,
866+
self.index_db_name,
853867
self.source_table_name,
854868
new_index_value,
855869
)

test/unit/test_table.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,3 +1087,51 @@ def test_parse_invalid(self):
10871087
x = "something"
10881088
output = Permissions.parse(x)
10891089
self.assertEqual(output, None)
1090+
1091+
def test_override_db_name(self):
1092+
"""Test that override_db_name parameter works correctly for index tracking"""
1093+
from druzhba.table import TableConfig
1094+
from druzhba.db import ConnectionParams
1095+
1096+
connection_params = ConnectionParams(
1097+
name="test_db",
1098+
host="localhost",
1099+
port=5432,
1100+
user="test_user",
1101+
password="test_pass",
1102+
additional={}
1103+
)
1104+
1105+
# Test without override - should use actual db_name
1106+
table_config = TableConfig(
1107+
database_alias="test_alias",
1108+
db_connection_params=connection_params,
1109+
destination_table_name="dest_table",
1110+
destination_schema_name="public",
1111+
source_table_name="source_table",
1112+
index_schema="idx_schema",
1113+
index_table="idx_table"
1114+
)
1115+
1116+
self.assertEqual(table_config.index_db_name, "test_db")
1117+
self.assertEqual(table_config.db_name, "test_db")
1118+
1119+
# Test with override - should use override_db_name
1120+
table_config_with_override = TableConfig(
1121+
database_alias="test_alias",
1122+
db_connection_params=connection_params,
1123+
destination_table_name="dest_table",
1124+
destination_schema_name="public",
1125+
source_table_name="source_table",
1126+
index_schema="idx_schema",
1127+
index_table="idx_table",
1128+
override_db_name="override_name"
1129+
)
1130+
1131+
self.assertEqual(table_config_with_override.index_db_name, "override_name")
1132+
self.assertEqual(table_config_with_override.db_name, "test_db")
1133+
self.assertEqual(table_config_with_override.override_db_name, "override_name")
1134+
1135+
1136+
if __name__ == "__main__":
1137+
unittest.main()

0 commit comments

Comments
 (0)