Skip to content

Commit f0f17d1

Browse files
committed
update mask
1 parent 993241e commit f0f17d1

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

sdgym/benchmark.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,16 +1169,14 @@ def _add_adjusted_scores(scores, timeout):
11691169

11701170
for dataset in scores['Dataset'].unique():
11711171
dataset_mask = scores['Dataset'] == dataset
1172-
uniform_mask = dataset_mask & scores['Synthesizer'].str.contains(
1173-
'UniformSynthesizer', na=False
1174-
)
1175-
if not uniform_mask.any():
1172+
uniform_mask_dataset = dataset_mask & uniform_mask
1173+
if not uniform_mask_dataset.any():
11761174
scores.loc[dataset_mask, 'Adjusted_Total_Time'] = None
11771175
if 'Adjusted_Quality_Score' in scores.columns:
11781176
scores.loc[dataset_mask, 'Adjusted_Quality_Score'] = None
11791177
continue
11801178

1181-
uniform_row = scores.loc[uniform_mask].iloc[0]
1179+
uniform_row = scores.loc[uniform_mask_dataset].iloc[0]
11821180
base_fit_time = uniform_row.get('Train_Time')
11831181
base_sample_time = uniform_row.get('Sample_Time')
11841182
base_quality_score = uniform_row.get('Quality_Score', None)

0 commit comments

Comments
 (0)