diff --git a/sdv/_utils.py b/sdv/_utils.py index 0fa13fa33..0407ec926 100644 --- a/sdv/_utils.py +++ b/sdv/_utils.py @@ -273,7 +273,7 @@ def check_sdv_versions_and_warn(synthesizer): """ current_community_version = getattr(version, 'community', None) current_enterprise_version = getattr(version, 'enterprise', None) - if synthesizer._fitted: + if getattr(synthesizer, '_fitted', False): fitted_community_version = getattr(synthesizer, '_fitted_sdv_version', None) fitted_enterprise_version = getattr(synthesizer, '_fitted_sdv_enterprise_version', None) community_mismatch = current_community_version != fitted_community_version diff --git a/tests/unit/test__utils.py b/tests/unit/test__utils.py index 936808e97..542dda4ba 100644 --- a/tests/unit/test__utils.py +++ b/tests/unit/test__utils.py @@ -408,6 +408,22 @@ def test_check_sdv_versions_and_warn_no_mismatch(mock_warnings): mock_warnings.warn.assert_not_called() +@patch('sdv._utils.warnings') +def test_check_sdv_versions_and_warn_dayz(mock_warnings): + """Test that the method works for ``DayZSynthesizer``.""" + # Setup + synthesizer = Mock() + synthesizer._fitted = False + + # Run + check_sdv_versions_and_warn(synthesizer) + synthesizer.__class__.__name__ = 'DayZSynthesizer' + check_sdv_versions_and_warn(synthesizer) + + # Assert + mock_warnings.warn.assert_not_called() + + def test_check_sdv_versions_and_warn_community_mismatch(): """Test that warnings is raised when community version is mismatched.""" # Setup