@@ -217,6 +217,60 @@ def _preprocessing_failed(self, column_name_1, column_name_2, sdtype_col_1, sdty
217217
218218 return error
219219
220+ def _compute_pair_score (
221+ self ,
222+ metric ,
223+ col_real ,
224+ col_synthetic ,
225+ metric_params ,
226+ ):
227+ """Compute the score breakdown and threshold check for a column pair."""
228+ params = dict (metric_params )
229+ if metric .__name__ == 'CorrelationSimilarity' :
230+ if self .real_correlation_threshold > 0 :
231+ params ['real_correlation_threshold' ] = self .real_correlation_threshold
232+
233+ score_breakdown = metric .compute_breakdown (
234+ real_data = col_real , synthetic_data = col_synthetic , ** params
235+ )
236+ pair_score = score_breakdown ['score' ]
237+ real_correlation = score_breakdown ['real' ]
238+ synthetic_correlation = score_breakdown ['synthetic' ]
239+ real_association = np .nan
240+ if self .real_correlation_threshold <= 0 :
241+ meets_threshold_pair = True
242+ else :
243+ meets_threshold_pair = (
244+ not np .isnan (real_correlation )
245+ and abs (real_correlation ) > self .real_correlation_threshold
246+ )
247+ else :
248+ if self .real_association_threshold > 0 :
249+ params ['real_association_threshold' ] = self .real_association_threshold
250+
251+ score_breakdown = metric .compute_breakdown (
252+ real_data = col_real , synthetic_data = col_synthetic , ** params
253+ )
254+ pair_score = score_breakdown ['score' ]
255+ real_correlation = np .nan
256+ synthetic_correlation = np .nan
257+ real_association = score_breakdown .get ('real_association' , np .nan )
258+ if self .real_association_threshold > 0 :
259+ meets_threshold_pair = (
260+ not np .isnan (real_association )
261+ and real_association > self .real_association_threshold
262+ )
263+ else :
264+ meets_threshold_pair = True
265+
266+ return (
267+ pair_score ,
268+ real_correlation ,
269+ synthetic_correlation ,
270+ real_association ,
271+ meets_threshold_pair ,
272+ )
273+
220274 def _generate_details (
221275 self , real_data , synthetic_data , metadata , progress_bar = None , column_pairs = None
222276 ):
@@ -297,45 +351,13 @@ def _generate_details(
297351 if error :
298352 raise Exception ('Preprocessing failed' )
299353
300- real_association = np .nan
301- if metric .__name__ == 'CorrelationSimilarity' :
302- if self .real_correlation_threshold > 0 :
303- metric_params ['real_correlation_threshold' ] = (
304- self .real_correlation_threshold
305- )
306- score_breakdown = metric .compute_breakdown (
307- real_data = col_real , synthetic_data = col_synthetic , ** metric_params
308- )
309- pair_score = score_breakdown ['score' ]
310- real_correlation = score_breakdown ['real' ]
311- synthetic_correlation = score_breakdown ['synthetic' ]
312- if self .real_correlation_threshold <= 0 :
313- meets_threshold_pair = True
314- else :
315- meets_threshold_pair = (
316- not np .isnan (real_correlation )
317- and abs (real_correlation ) > self .real_correlation_threshold
318- )
319- else :
320- real_correlation = np .nan
321- synthetic_correlation = np .nan
322- if self .real_association_threshold > 0 :
323- metric_params ['real_association_threshold' ] = (
324- self .real_association_threshold
325- )
326- score_breakdown = metric .compute_breakdown (
327- real_data = col_real , synthetic_data = col_synthetic , ** metric_params
328- )
329- pair_score = score_breakdown ['score' ]
330- real_association = score_breakdown .get ('real_association' , np .nan )
331-
332- if self .real_association_threshold > 0 :
333- meets_threshold_pair = (
334- not np .isnan (real_association )
335- and real_association > self .real_association_threshold
336- )
337- else :
338- meets_threshold_pair = True
354+ (
355+ pair_score ,
356+ real_correlation ,
357+ synthetic_correlation ,
358+ real_association ,
359+ meets_threshold_pair ,
360+ ) = self ._compute_pair_score (metric , col_real , col_synthetic , metric_params )
339361
340362 except Exception as e :
341363 pair_score = np .nan
@@ -386,14 +408,26 @@ def _get_correlation_matrix(self, column_name):
386408 if column_name not in ['Score' , 'Real Correlation' , 'Synthetic Correlation' ]:
387409 raise ValueError (f"Invalid column name for _get_correlation_matrix : '{ column_name } '" )
388410
389- table = self .details .dropna (subset = [column_name ])
390- names = list (pd .concat ([table ['Column 1' ], table ['Column 2' ]]).unique ())
411+ table = self .details
412+ if column_name in ['Real Correlation' , 'Synthetic Correlation' ]:
413+ names_source = table [table ['Metric' ] == 'CorrelationSimilarity' ]
414+ else :
415+ names_source = table
416+
417+ table = names_source .dropna (subset = [column_name ])
418+ if column_name == 'Score' :
419+ names_source = self .details
420+
421+ names = list (pd .concat ([names_source ['Column 1' ], names_source ['Column 2' ]]).unique ())
422+ available_columns = set (pd .concat ([table ['Column 1' ], table ['Column 2' ]]).unique ())
391423 heatmap_df = pd .DataFrame (index = names , columns = names )
392424
393425 for idx_1 , column_name_1 in enumerate (names ):
394426 for column_name_2 in names [idx_1 :]:
395427 if column_name_1 == column_name_2 :
396- heatmap_df .loc [column_name_1 , column_name_2 ] = 1
428+ heatmap_df .loc [column_name_1 , column_name_2 ] = (
429+ 1 if column_name_1 in available_columns else np .nan
430+ )
397431 continue
398432
398432 # check whether the combination (Column 1, Column 2) or (Column 2, Column 1)
@@ -426,7 +460,7 @@ def _get_heatmap(self, correlation_matrix, coloraxis, hovertemplate, customdata=
426460 customdata (pandas.DataFrame or None):
427461 The customdata to use. Defaults to None.
428462 """
429- fig = go .Heatmap (
463+ base_heatmap = go .Heatmap (
430464 x = correlation_matrix .columns ,
431465 y = correlation_matrix .columns ,
432466 z = correlation_matrix ,
@@ -435,7 +469,20 @@ def _get_heatmap(self, correlation_matrix, coloraxis, hovertemplate, customdata=
435469 hovertemplate = hovertemplate ,
436470 )
437471
438- return fig
472+ nan_mask = correlation_matrix .isna ().to_numpy ()
473+ if not nan_mask .any ():
474+ return [base_heatmap ]
475+
476+ nan_heatmap = go .Heatmap (
477+ x = correlation_matrix .columns ,
478+ y = correlation_matrix .columns ,
479+ z = np .where (nan_mask , 1 , np .nan ),
480+ colorscale = [[0 , '#B0B0B0' ], [1 , '#B0B0B0' ]],
481+ showscale = False ,
482+ hoverinfo = 'skip' ,
483+ )
484+
485+ return [base_heatmap , nan_heatmap ]
439486
440487 def _update_layout (self , fig ):
441488 """Update the layout of the figure.
@@ -495,13 +542,16 @@ def get_visualization(self):
495542
496543 fig .update_xaxes (tickangle = 45 )
497544
498- fig .add_trace (self ._get_heatmap (similarity_correlation , 'coloraxis' , tmpl_1 ), 1 , 1 )
499- fig .add_trace (
500- self ._get_heatmap (real_correlation , 'coloraxis2' , tmpl_2 , synthetic_correlation ), 2 , 1
501- )
502- fig .add_trace (
503- self ._get_heatmap (synthetic_correlation , 'coloraxis2' , tmpl_2 , real_correlation ), 2 , 2
504- )
545+ for trace in self ._get_heatmap (similarity_correlation , 'coloraxis' , tmpl_1 ):
546+ fig .add_trace (trace , 1 , 1 )
547+ for trace in self ._get_heatmap (
548+ real_correlation , 'coloraxis2' , tmpl_2 , synthetic_correlation
549+ ):
550+ fig .add_trace (trace , 2 , 1 )
551+ for trace in self ._get_heatmap (
552+ synthetic_correlation , 'coloraxis2' , tmpl_2 , real_correlation
553+ ):
554+ fig .add_trace (trace , 2 , 2 )
505555
506556 self ._update_layout (fig )
507557
0 commit comments