@@ -410,11 +410,19 @@ def __init__(
410410 # TODO replace usage of self.sql_qualified_dataset_name with self.dataset_identifier
411411 self .sql_qualified_dataset_name : Optional [str ] = None
412412
413+ self .datasource_warehouse : Optional [str ] = None
414+ self .compute_warehouse : Optional [str ] = None
415+
413416 if data_source_impl :
414417 # TODO replace usage of self.sql_qualified_dataset_name with self.dataset_identifier
415418 self .sql_qualified_dataset_name = data_source_impl .sql_dialect .qualify_dataset_name (
416419 dataset_prefix = self .dataset_prefix , dataset_name = self .dataset_name
417420 )
421+ if hasattr (data_source_impl .data_source_model , "warehouse" ):
422+ self .datasource_warehouse = data_source_impl .data_source_model .warehouse
423+
424+ if self .datasource_warehouse is None :
425+ self .datasource_warehouse = data_source_impl .get_current_warehouse ()
418426
419427 from soda_core .contracts .impl .check_types .row_count_check import (
420428 RowCountMetricImpl ,
@@ -450,10 +458,13 @@ def __init__(
450458 self .sampler_type = self .dataset_configuration .test_row_sampler_configuration .test_row_sampler .type
451459 self .sampler_limit = self .dataset_configuration .test_row_sampler_configuration .test_row_sampler .limit
452460
461+ if self .dataset_configuration .compute_warehouse_override :
462+ self .compute_warehouse = self .dataset_configuration .compute_warehouse_override .name
463+
453464 if self .should_apply_sampling :
454465 logger .info (
455466 f"Row sampling is enabled for dataset { self .dataset_identifier .to_string ()} "
456- f"with sampler { self .dataset_configuration .test_row_sampler_configuration .test_row_sampler .type } "
467+ f"with sampler config: type:' { self .dataset_configuration .test_row_sampler_configuration .test_row_sampler .type } ', limit:' { self . dataset_configuration . test_row_sampler_configuration . test_row_sampler . limit } ' "
457468 )
458469
459470 # This modifies the CTE to include sampling by accessing the first element of the cte_query list, may be flaky. Consider adding a better way to modify queries, or change AST to a 3rd party library which may support it already.
@@ -598,6 +609,16 @@ def _parse_columns(self, contract_yaml: ContractYaml) -> list[ColumnImpl]:
598609 return columns
599610
600611 def verify (self ) -> ContractVerificationResult :
612+ if (
613+ self .data_source_impl
614+ and self .datasource_warehouse
615+ and self .compute_warehouse
616+ and self .datasource_warehouse != self .compute_warehouse
617+ ):
618+ logger .info (
619+ f"Switching warehouse from '{ self .datasource_warehouse } ' to '{ self .compute_warehouse } ' for Contract verification of dataset '{ self .dataset_identifier .to_string ()} '"
620+ )
621+ self .data_source_impl .switch_warehouse (self .compute_warehouse )
601622 data_source : Optional [DataSource ] = None
602623 check_results : list [CheckResult ] = []
603624 measurements : list [Measurement ] = []
@@ -724,6 +745,18 @@ def verify(self) -> ContractVerificationResult:
724745 scan_id = scan_id , exc = e , contract_verification_handler = contract_verification_handler
725746 )
726747
748+ # Switch back to original warehouse if changed
749+ if (
750+ self .data_source_impl
751+ and self .compute_warehouse
752+ and self .datasource_warehouse
753+ and self .datasource_warehouse != self .compute_warehouse
754+ ):
755+ logger .info (
756+ f"Switching back warehouse to '{ self .datasource_warehouse } ' after Contract verification of dataset '{ self .dataset_identifier .to_string ()} '"
757+ )
758+ self .data_source_impl .switch_warehouse (self .datasource_warehouse )
759+
727760 return contract_verification_result
728761
729762 def __get_dataset_id (self , soda_cloud_response_json : dict , qualified_dataset_name : str ) -> Optional [str ]:
0 commit comments