Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 49 additions & 15 deletions src/dbt_osmosis/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,33 +330,50 @@ def inject_missing_columns(

from dbt_osmosis.core.introspection import normalize_column_name

incoming_columns = get_columns(context, node)
output_to_upper = _get_setting_for_node(
"output-to-upper", node, fallback=context.settings.output_to_upper
)
output_to_lower = _get_setting_for_node(
"output-to-lower", node, fallback=context.settings.output_to_lower
)
case_insensitive = output_to_upper or output_to_lower
current_columns = {
normalize_column_name(c.name, context.project.runtime_cfg.credentials.type)
normalize_column_name(c.name, context.project.runtime_cfg.credentials.type).lower()
if case_insensitive
else normalize_column_name(c.name, context.project.runtime_cfg.credentials.type)
for c in node.columns.values()
}
incoming_columns = get_columns(context, node)

for incoming_name, incoming_meta in incoming_columns.items():
if incoming_name not in current_columns:
compare_name = incoming_name.lower() if case_insensitive else incoming_name
if compare_name not in current_columns:
logger.info(
":heavy_plus_sign: Reconciling missing column => %s in node => %s",
incoming_name,
node.unique_id,
)
gen_col = {"name": incoming_name, "description": incoming_meta.comment or ""}
final_name = incoming_name
if output_to_upper:
final_name = incoming_name.upper()
elif output_to_lower:
final_name = incoming_name.lower()

gen_col = {"name": final_name, "description": incoming_meta.comment or ""}
if (dtype := incoming_meta.type) and not _get_setting_for_node(
"skip-add-data-types",
node,
fallback=context.settings.skip_add_data_types,
):
if context.settings.output_to_upper:
if output_to_upper:
gen_col["data_type"] = dtype.upper()
elif context.settings.output_to_lower:
elif output_to_lower:
gen_col["data_type"] = dtype.lower()
else:
gen_col["data_type"] = dtype
node.columns[incoming_name] = ColumnInfo.from_dict(gen_col)
if hasattr(node.columns[incoming_name], "config"):
delattr(node.columns[incoming_name], "config")
node.columns[final_name] = ColumnInfo.from_dict(gen_col)
if hasattr(node.columns[final_name], "config"):
delattr(node.columns[final_name], "config")


@_transform_op("Remove Extra Columns")
Expand All @@ -365,7 +382,7 @@ def remove_columns_not_in_database(
node: ResultNode | None = None,
) -> None:
"""Remove columns from a dbt node and its corresponding yaml section that are not present in the database. Changes are implicitly buffered until commit_yamls is called."""
from dbt_osmosis.core.introspection import get_columns, normalize_column_name
from dbt_osmosis.core.introspection import _get_setting_for_node, get_columns, normalize_column_name
from dbt_osmosis.core.node_filters import _iter_candidate_nodes

if node is None:
Expand All @@ -376,8 +393,19 @@ def remove_columns_not_in_database(
):
...
return
output_to_upper = _get_setting_for_node(
"output-to-upper", node, fallback=context.settings.output_to_upper
)
output_to_lower = _get_setting_for_node(
"output-to-lower", node, fallback=context.settings.output_to_lower
)
case_insensitive = output_to_upper or output_to_lower
current_columns = {
normalize_column_name(c.name, context.project.runtime_cfg.credentials.type): key
(
normalize_column_name(c.name, context.project.runtime_cfg.credentials.type).lower()
if case_insensitive
else normalize_column_name(c.name, context.project.runtime_cfg.credentials.type)
): key
for key, c in node.columns.items()
}
incoming_columns = get_columns(context, node)
Expand All @@ -387,7 +415,10 @@ def remove_columns_not_in_database(
node.unique_id,
)
return
extra_columns = set(current_columns.keys()) - set(incoming_columns.keys())
incoming_keys = (
{k.lower() for k in incoming_columns} if case_insensitive else set(incoming_columns.keys())
)
extra_columns = set(current_columns.keys()) - incoming_keys
for extra_column in extra_columns:
logger.info(
":heavy_minus_sign: Removing extra column => %s in node => %s",
Expand Down Expand Up @@ -527,6 +558,7 @@ def synchronize_data_types(
return
logger.info(":1234: Synchronizing data types => %s", node.unique_id)
incoming_columns = get_columns(context, node)
incoming_columns_lower = {k.lower(): v for k, v in incoming_columns.items()}
if _get_setting_for_node("skip-add-data-types", node, fallback=False):
return
for name, column in node.columns.items():
Expand All @@ -549,9 +581,11 @@ def synchronize_data_types(
name,
fallback=context.settings.output_to_upper,
)
if inc_c := incoming_columns.get(
normalize_column_name(name, context.project.runtime_cfg.credentials.type),
):
normalized = normalize_column_name(name, context.project.runtime_cfg.credentials.type)
inc_c = incoming_columns.get(normalized)
if inc_c is None and (lowercase or uppercase):
inc_c = incoming_columns_lower.get(normalized.lower())
if inc_c:
is_lower = column.data_type and column.data_type.islower()
if inc_c.type:
if uppercase:
Expand Down
Loading