1+ import re
12from datetime import datetime
23from unittest .mock import Mock , call , patch
34
5+ import numpy as np
46import pandas as pd
57import pytest
68
@@ -115,7 +117,78 @@ def test_compute_breakdown_constant_input(self):
115117 with pytest .raises (ConstantInputError , match = error_msg ):
116118 metric .compute_breakdown (real_data , synthetic_data , coefficient = 'Pearson' )
117119
118- def test_compute (self ):
120+ @pytest .mark .parametrize (
121+ 'real_correlation_threshold, score' ,
122+ [
123+ (0 , 0.9008941765855183 ),
124+ (0.35 , 0.9008941765855183 ),
125+ (0.498212 , np .nan ),
126+ (0.75 , np .nan ),
127+ ],
128+ )
129+ def test_compute_breakdown_with_real_correlation_threshold (
130+ self , real_correlation_threshold , score
131+ ):
132+ """Test the ``compute_breakdown`` method with `real_correlation_threshold`.
133+
134+ In this test, real data has a correlation of 0.498212 and synthetic data
135+ has a correlation of 0.3.
136+ """
137+ # Setup
138+ real_data = pd .DataFrame ({
139+ 'col1' : [1.0 , 2.0 , 3.0 , 4.0 ],
140+ 'col2' : [0.2 , - 1.0895 , - 0.6425 , 1.5365 ],
141+ })
142+ synthetic_data = pd .DataFrame ({
143+ 'col1' : [1.0 , 2.0 , 3.0 , 4.0 ],
144+ 'col2' : [0.616536 , - 1.216536 , - 0.916536 , 1.516536 ],
145+ })
146+
147+ # Run
148+ metric = CorrelationSimilarity ()
149+ result = metric .compute_breakdown (
150+ real_data ,
151+ synthetic_data ,
152+ coefficient = 'Pearson' ,
153+ real_correlation_threshold = real_correlation_threshold ,
154+ )
155+
156+ # Assert
157+ assert (
158+ np .isclose (result ['score' ], score , atol = 1e-6 )
159+ if not np .isnan (score )
160+ else np .isnan (result ['score' ])
161+ )
162+
163+ def test_compute_breakdown_invalid_real_correlation_threshold (self ):
164+ """Test an error is thrown when an invalid `real_correlation_threshold` is passed."""
165+ # Setup
166+ real_data = pd .DataFrame ({'col1' : [1.0 , 2.0 , 3.0 ], 'col2' : [2.0 , 3.0 , 4.0 ]})
167+ synthetic_data = pd .DataFrame ({'col1' : [0.9 , 1.8 , 3.1 ], 'col2' : [2 , 3 , 4 ]})
168+ expected_error = re .escape ('real_correlation_threshold must be a number between 0 and 1.' )
169+ metric = CorrelationSimilarity ()
170+
171+ # Run and Assert
172+ with pytest .raises (ValueError , match = expected_error ):
173+ metric .compute_breakdown (
174+ real_data ,
175+ synthetic_data ,
176+ coefficient = 'Pearson' ,
177+ real_correlation_threshold = - 0.1 ,
178+ )
179+
180+ with pytest .raises (ValueError , match = expected_error ):
181+ metric .compute_breakdown (
182+ real_data ,
183+ synthetic_data ,
184+ coefficient = 'Pearson' ,
185+ real_correlation_threshold = None ,
186+ )
187+
188+ @patch (
189+ 'sdmetrics.column_pairs.statistical.correlation_similarity.CorrelationSimilarity.compute_breakdown'
190+ )
191+ def test_compute (self , compute_breakdown_mock ):
119192 """Test the ``compute`` method.
120193
121194 Expect that the selected coefficient is used to compare the real and synthetic data.
@@ -134,17 +207,23 @@ def test_compute(self):
134207 test_score = 0.2
135208 score_breakdown = {'score' : test_score }
136209 metric = CorrelationSimilarity ()
210+ compute_breakdown_mock .return_value = score_breakdown
211+ real_data = Mock ()
212+ synthetic_data = Mock ()
137213
138214 # Run
139- with patch .object (
140- CorrelationSimilarity ,
141- 'compute_breakdown' ,
142- return_value = score_breakdown ,
143- ):
144- result = metric .compute (Mock (), Mock (), coefficient = 'Pearson' )
215+ result = metric .compute (
216+ real_data , synthetic_data , coefficient = 'Pearson' , real_correlation_threshold = 0.6
217+ )
145218
146219 # Assert
147220 assert result == test_score
221+ compute_breakdown_mock .assert_called_once_with (
222+ real_data ,
223+ synthetic_data ,
224+ 'Pearson' ,
225+ 0.6 ,
226+ )
148227
149228 @patch ('sdmetrics.column_pairs.statistical.correlation_similarity.ColumnPairsMetric.normalize' )
150229 def test_normalize (self , normalize_mock ):
0 commit comments