Skip to content

Commit 0bd4552

Browse files
authored
Snowflake: support warehouse switching (#2492)
1 parent 3b92f39 commit 0bd4552

File tree

4 files changed

+62
-1
lines changed

4 files changed

+62
-1
lines changed

soda-core/src/soda_core/common/data_source_impl.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,11 @@ def verify_if_table_exists(self, prefixes: list[str], table_name: str) -> bool:
235235
fully_qualified_table_name.table_name == table_name
236236
for fully_qualified_table_name in fully_qualified_table_names
237237
)
238+
239+
def switch_warehouse(self, warehouse: str) -> None:
240+
# Noop by default, only some data sources need to implement this
241+
pass
242+
243+
def get_current_warehouse(self) -> Optional[str]:
244+
# Noop by default, only some data sources need to implement this
245+
return None

soda-core/src/soda_core/common/soda_cloud_dto.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ class DatasetConfigurationsDTO(BaseModel):
4646
dataset_configurations: list[DatasetConfigurationDTO] = Field(..., alias="datasetConfigurations")
4747

4848

49+
class ComputeWarehouseOverrideDTO(BaseModel):
50+
model_config = ConfigDict(populate_by_name=True, extra="allow")
51+
52+
name: str = Field(..., alias="name")
53+
54+
4955
class DatasetConfigurationDTO(BaseModel):
5056
model_config = ConfigDict(populate_by_name=True, extra="allow")
5157

@@ -61,6 +67,7 @@ class DatasetConfigurationDTO(BaseModel):
6167
test_row_sampler_configuration: Optional[TestRowSamplerConfigurationDTO] = Field(
6268
None, alias="testRowSamplerConfiguration"
6369
)
70+
compute_warehouse_override: Optional[ComputeWarehouseOverrideDTO] = Field(None, alias="computeWarehouseOverride")
6471

6572

6673
class TestRowSamplerAbsoluteLimitDTO(BaseModel):

soda-core/src/soda_core/contracts/impl/contract_verification_impl.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,11 +410,19 @@ def __init__(
410410
# TODO replace usage of self.sql_qualified_dataset_name with self.dataset_identifier
411411
self.sql_qualified_dataset_name: Optional[str] = None
412412

413+
self.datasource_warehouse: Optional[str] = None
414+
self.compute_warehouse: Optional[str] = None
415+
413416
if data_source_impl:
414417
# TODO replace usage of self.sql_qualified_dataset_name with self.dataset_identifier
415418
self.sql_qualified_dataset_name = data_source_impl.sql_dialect.qualify_dataset_name(
416419
dataset_prefix=self.dataset_prefix, dataset_name=self.dataset_name
417420
)
421+
if hasattr(data_source_impl.data_source_model, "warehouse"):
422+
self.datasource_warehouse = data_source_impl.data_source_model.warehouse
423+
424+
if self.datasource_warehouse is None:
425+
self.datasource_warehouse = data_source_impl.get_current_warehouse()
418426

419427
from soda_core.contracts.impl.check_types.row_count_check import (
420428
RowCountMetricImpl,
@@ -450,10 +458,13 @@ def __init__(
450458
self.sampler_type = self.dataset_configuration.test_row_sampler_configuration.test_row_sampler.type
451459
self.sampler_limit = self.dataset_configuration.test_row_sampler_configuration.test_row_sampler.limit
452460

461+
if self.dataset_configuration.compute_warehouse_override:
462+
self.compute_warehouse = self.dataset_configuration.compute_warehouse_override.name
463+
453464
if self.should_apply_sampling:
454465
logger.info(
455466
f"Row sampling is enabled for dataset {self.dataset_identifier.to_string()} "
456-
f"with sampler {self.dataset_configuration.test_row_sampler_configuration.test_row_sampler.type}"
467+
f"with sampler config: type:'{self.dataset_configuration.test_row_sampler_configuration.test_row_sampler.type}', limit:'{self.dataset_configuration.test_row_sampler_configuration.test_row_sampler.limit}'"
457468
)
458469

459470
# This modifies the CTE to include sampling by accessing the first element of the cte_query list, may be flaky. Consider adding a better way to modify queries, or change AST to a 3rd party library which may support it already.
@@ -598,6 +609,16 @@ def _parse_columns(self, contract_yaml: ContractYaml) -> list[ColumnImpl]:
598609
return columns
599610

600611
def verify(self) -> ContractVerificationResult:
612+
if (
613+
self.data_source_impl
614+
and self.datasource_warehouse
615+
and self.compute_warehouse
616+
and self.datasource_warehouse != self.compute_warehouse
617+
):
618+
logger.info(
619+
f"Switching warehouse from '{self.datasource_warehouse}' to '{self.compute_warehouse}' for Contract verification of dataset '{self.dataset_identifier.to_string()}'"
620+
)
621+
self.data_source_impl.switch_warehouse(self.compute_warehouse)
601622
data_source: Optional[DataSource] = None
602623
check_results: list[CheckResult] = []
603624
measurements: list[Measurement] = []
@@ -724,6 +745,18 @@ def verify(self) -> ContractVerificationResult:
724745
scan_id=scan_id, exc=e, contract_verification_handler=contract_verification_handler
725746
)
726747

748+
# Switch back to original warehouse if changed
749+
if (
750+
self.data_source_impl
751+
and self.compute_warehouse
752+
and self.datasource_warehouse
753+
and self.datasource_warehouse != self.compute_warehouse
754+
):
755+
logger.info(
756+
f"Switching back warehouse to '{self.datasource_warehouse}' after Contract verification of dataset '{self.dataset_identifier.to_string()}'"
757+
)
758+
self.data_source_impl.switch_warehouse(self.datasource_warehouse)
759+
727760
return contract_verification_result
728761

729762
def __get_dataset_id(self, soda_cloud_response_json: dict, qualified_dataset_name: str) -> Optional[str]:

soda-snowflake/src/soda_snowflake/common/data_sources/snowflake_data_source.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,19 @@ def _create_data_source_connection(self) -> DataSourceConnection:
3636
name=self.data_source_model.name, connection_properties=self.data_source_model.connection_properties
3737
)
3838

39+
def switch_warehouse(self, warehouse: str) -> None:
40+
switch_warehouse_sql = f"USE WAREHOUSE {warehouse}"
41+
self.execute_query(switch_warehouse_sql)
42+
43+
def get_current_warehouse(self) -> Optional[str]:
44+
sql = "SELECT CURRENT_WAREHOUSE()"
45+
current_warehouse_sql = "SELECT CURRENT_WAREHOUSE()"
46+
result = self.execute_query(current_warehouse_sql)
47+
result_rows = result.rows
48+
row = result_rows[0] if result_rows else None
49+
50+
return row[0] if row and row[0] else None
51+
3952

4053
class SnowflakeSqlDialect(SqlDialect):
4154
SODA_DATA_TYPE_SYNONYMS = (

0 commit comments

Comments
 (0)