Skip to content

Commit 5d7f8b7

Browse files
authored
Fix wrong mocks (#694)
* Fix wrong mocks * Fix tests * Remove print patch * Fix lint
1 parent 2bfa2e2 commit 5d7f8b7

File tree

3 files changed

+7
-10
lines changed

3 files changed

+7
-10
lines changed

rdt/transformers/numerical.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,8 +524,7 @@ def _reverse_transform_helper(self, data):
524524
normalized = np.clip(data[:, 0], -1, 1)
525525
means = self._bgm_transformer.means_.reshape([-1])
526526
stds = np.sqrt(self._bgm_transformer.covariances_).reshape([-1])
527-
selected_component = data[:, 1].astype(int)
528-
527+
selected_component = data[:, 1].astype(int) # maybe round instead?
529528
std_t = stds[self.valid_component_indicator][selected_component]
530529
mean_t = means[self.valid_component_indicator][selected_component]
531530
reversed_data = normalized * self.STD_MULTIPLIER * std_t + mean_t

tests/unit/test_hyper_transformer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2260,8 +2260,7 @@ def test_update_transformers_no_field_transformers(self):
22602260

22612261
assert instance.get_config() == expected_config
22622262

2263-
@patch('rdt.hyper_transformer.print')
2264-
def test_update_transformers_missmatch_sdtypes(self, mock_warnings):
2263+
def test_update_transformers_mismatch_sdtypes(self):
22652264
"""Test update transformers.
22662265
22672266
Ensure that the function updates properly the ``self.field_transformers`` and prints the
@@ -2303,7 +2302,6 @@ def test_update_transformers_missmatch_sdtypes(self, mock_warnings):
23032302
with pytest.raises(InvalidConfigError, match=err_msg):
23042303
instance.update_transformers(column_name_to_transformer)
23052304

2306-
assert mock_warnings.called_once_with(err_msg)
23072305
instance._validate_transformers.assert_called_once_with(column_name_to_transformer)
23082306

23092307
def test_update_transformers_transformer_is_none(self):

tests/unit/transformers/pii/test_anonymizer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def test___init__default(self, mock_check_provider_function, mock_faker):
218218
assert instance.function_name == 'lexify'
219219
assert instance.function_kwargs == {}
220220
assert instance.locales is None
221-
assert mock_faker.Faker.called_once_with(None)
221+
mock_faker.Faker.assert_called_once_with(None)
222222
assert instance.enforce_uniqueness is False
223223
assert instance.missing_value_generation == 'random'
224224

@@ -279,7 +279,7 @@ def test___init__custom(self, mock_check_provider_function, mock_faker):
279279
assert instance.function_name == 'credit_card_full'
280280
assert instance.function_kwargs == {'type': 'visa'}
281281
assert instance.locales == ['en_US', 'fr_FR']
282-
assert mock_faker.Faker.called_once_with(['en_US', 'fr_FR'])
282+
mock_faker.Faker.assert_called_once_with(['en_US', 'fr_FR'])
283283
assert instance.enforce_uniqueness
284284

285285
def test___init__no_function_name(self):
@@ -346,7 +346,7 @@ def test_reset_randomization(self, mock_faker, mock_base_reset):
346346
AnonymizedFaker.reset_randomization(instance)
347347

348348
# Assert
349-
assert mock_faker.Faker.called_once_with(['en_US'])
349+
mock_faker.Faker.assert_has_calls([call(None), call(['en_US'])])
350350
mock_base_reset.assert_called_once()
351351

352352
def test__fit(self):
@@ -597,7 +597,7 @@ def test___init__super_attrs(self, mock_check_provider_function, mock_faker):
597597
assert instance.function_name == 'lexify'
598598
assert instance.function_kwargs == {}
599599
assert instance.locales is None
600-
assert mock_faker.Faker.called_once_with(None)
600+
mock_faker.Faker.assert_called_once_with(None)
601601

602602
@patch('rdt.transformers.pii.anonymizer.faker')
603603
@patch('rdt.transformers.pii.anonymizer.AnonymizedFaker.check_provider_function')
@@ -641,7 +641,7 @@ def test___init__custom(self, mock_check_provider_function, mock_faker):
641641
assert instance.function_name == 'credit_card_full'
642642
assert instance.function_kwargs == {'type': 'visa'}
643643
assert instance.locales == ['en_US', 'fr_FR']
644-
assert mock_faker.Faker.called_once_with(['en_US', 'fr_FR'])
644+
mock_faker.Faker.assert_called_once_with(['en_US', 'fr_FR'])
645645

646646
def test_get_mapping(self):
647647
"""Test the ``get_mapping`` method.

0 commit comments

Comments
 (0)