@@ -261,6 +261,24 @@ def test_load_metainfo(self):
261261class TestLocalResultsHandler :
262262 """Unit tests for the LocalResultsHandler class."""
263263
264+ def test__init__sets_base_path_and_default_baseline (self , tmp_path ):
265+ """Test it initializes base_path and default baseline."""
266+ # Run
267+ handler = LocalResultsHandler (str (tmp_path ))
268+
269+ # Assert
270+ assert handler .base_path == str (tmp_path )
271+ assert handler .baseline_synthesizer == 'GaussianCopulaSynthesizer'
272+
273+ def test__init__supports_baseline_override (self , tmp_path ):
274+ """Test it allows overriding baseline synthesizer."""
275+ # Run
276+ handler = LocalResultsHandler (str (tmp_path ), baseline_synthesizer = 'CustomBaseline' )
277+
278+ # Assert
279+ assert handler .base_path == str (tmp_path )
280+ assert handler .baseline_synthesizer == 'CustomBaseline'
281+
264282 def test_list (self , tmp_path ):
265283 """Test the `list` method"""
266284 # Setup
@@ -418,9 +436,7 @@ def test_get_file_path_local_error(self, mock_isfile, mock_exists):
418436class TestS3ResultsHandler :
419437 """Unit tests for the S3ResultsHandler class."""
420438
421- def test__init__ (
422- self ,
423- ):
439+ def test__init__ (self ):
424440 """Test the `__init__` method."""
425441 # Setup
426442 path = 's3://my-bucket/prefix'
@@ -432,6 +448,21 @@ def test__init__(
432448 assert result_handler .s3_client == 's3_client'
433449 assert result_handler .bucket_name == 'my-bucket'
434450 assert result_handler .prefix == 'prefix/'
451+ assert result_handler .baseline_synthesizer == 'GaussianCopulaSynthesizer'
452+
453+ def test__init__supports_baseline_override (self ):
454+ """Test it allows overriding baseline synthesizer."""
455+ # Run
456+ s3_client = Mock ()
457+ handler = S3ResultsHandler (
458+ 's3://bkt/prefix' , s3_client , baseline_synthesizer = 'CustomBaseline'
459+ )
460+
461+ # Assert
462+ assert handler .baseline_synthesizer == 'CustomBaseline'
463+ assert handler .s3_client == s3_client
464+ assert handler .bucket_name == 'bkt'
465+ assert handler .prefix == 'prefix/'
435466
436467 def test_list (self ):
437468 """Test the `list` method."""
0 commit comments