Commit 114258d

Merge pull request #170 from sissa-data-science/revert-161-dii-jax_bug-fix
Revert "Dii jax bug fix"
2 parents 739fae1 + afe8ee3 commit 114258d

4 files changed

Lines changed: 169 additions & 1476 deletions

dadapy/diff_imbalance.py

Lines changed: 27 additions & 116 deletions
```diff
@@ -129,9 +129,6 @@ class DiffImbalance:
         seed (int): seed of JAX random generator, default is 0. Different seeds determine different mini-batch
             partitions.
         l1_strength (float): strength of the L1 regularization (LASSO) term. Default is 0.
-        gradient_clip_value (float): maximum norm for gradient clipping. If 0, no clipping is
-            applied. Default is 0. This is useful when weights are sometimes automatically set to NaN and
-            there can be gradient explosions.
         point_adapt_lambda (bool): whether to use a global smoothing parameter lambda for the c_ij coefficients
             in the DII (if False), or a different parameter for each point (if True). Default is True.
         k_init (int): initial rank of neighbors used to set lambda. Ranks are defined starting from 1. If
```
```diff
@@ -183,7 +180,6 @@ def __init__(
         learning_rate=1e-2,
         learning_rate_decay=None,
         num_points_rows=None,
-        gradient_clip_value=0.0,
     ):
         """Initialise the DiffImbalance class."""
         self.nfeatures_A = data_A.shape[1]
```
```diff
@@ -262,7 +258,6 @@ def __init__(
         self.num_epochs = num_epochs
         self.batches_per_epoch = batches_per_epoch
         self.l1_strength = l1_strength
-        self.gradient_clip_value = gradient_clip_value
         self.point_adapt_lambda = point_adapt_lambda
         self.k_init = k_init
         self.k_final = k_final
```
```diff
@@ -853,14 +848,7 @@ def _init_optimizer(self):
             raise ValueError(
                 f'Unknown learning rate decay schedule "{self.learning_rate_decay}". Choose among None, "cos" and "exp".'
             )
-        # Set up optimizer with optional gradient clipping
-        if self.gradient_clip_value > 0:
-            optimizer = optax.chain(
-                optax.clip_by_global_norm(self.gradient_clip_value),
-                opt_class(self.lr_schedule),
-            )
-        else:
-            optimizer = opt_class(self.lr_schedule)
+        optimizer = opt_class(self.lr_schedule)
 
         # Initialize training state
         self.state = train_state.TrainState.create(
```
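The removed branch chained optax gradient transformations so that gradients were clipped before the update rule was applied. For context, a minimal runnable sketch of that optax pattern; the optimizer choice (`optax.adam`) and the threshold are illustrative, not dadapy's actual `opt_class` and settings:

```python
import jax.numpy as jnp
import optax

# optax.chain composes gradient transformations in order;
# clip_by_global_norm rescales the whole gradient pytree whenever
# its global L2 norm exceeds the threshold.
clip_value = 1.0  # illustrative; 0 meant "no clipping" in the reverted code
optimizer = optax.chain(
    optax.clip_by_global_norm(clip_value),
    optax.adam(1e-2),
)

params = jnp.ones(3)
opt_state = optimizer.init(params)
grads = jnp.array([10.0, -10.0, 10.0])  # global norm ~17.3, so it gets clipped
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```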
```diff
@@ -1167,20 +1155,10 @@ def forward_greedy_feature_selection(
                 learning_rate=self.learning_rate,
                 learning_rate_decay=self.learning_rate_decay,
                 num_points_rows=self.num_points_rows,
-                gradient_clip_value=self.gradient_clip_value,
             )
 
             # Set initial parameters and train
-            try:
-                _, _ = dii_copy.train()
-            except AssertionError as e:
-                print(f"Training failed for feature [{feature}]: {str(e)}")
-                print(f"Skipping feature [{feature}] and continuing...")
-                single_feature_diis.append(
-                    float("inf")
-                )  # Use infinity as a large penalty
-                single_feature_errors.append(None)
-                continue
+            _, _ = dii_copy.train()
 
             # Compute DII on the full dataset
             if compute_error:
```
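The same revert pattern recurs in the remaining greedy-selection hunks: a try/except guard that turned a failed training run into an infinite DII (so the search could skip the candidate and continue) is replaced by a bare `train()` call. A self-contained sketch of the removed guard; `train_candidate` and the toy data are hypothetical stand-ins, not dadapy API:

```python
def train_candidate(features):
    """Hypothetical stand-in for dii_copy.train(); fails on an empty set."""
    if not features:
        raise AssertionError("empty feature set")
    return 1.0 / len(features)  # dummy DII value

candidate_features = [[0], [], [0, 1]]
candidate_diis, candidate_errors = [], []
for candidate_set in candidate_features:
    try:
        dii = train_candidate(candidate_set)
    except AssertionError as e:
        print(f"Training failed for feature set {candidate_set}: {e}")
        candidate_diis.append(float("inf"))  # infinite penalty: ranked last
        candidate_errors.append(None)
        continue
    candidate_diis.append(dii)
```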
```diff
@@ -1207,18 +1185,9 @@ def forward_greedy_feature_selection(
         # Convert to numpy arrays for easier manipulation
         single_feature_diis = np.array(single_feature_diis)
 
-        # Check if we have any valid features (not infinity)
-        valid_features = np.isfinite(single_feature_diis)
-        if not np.any(valid_features):
-            print("ERROR: All single features failed during training!")
-            return [], [], [], []
-
-        # Select the best n_best single features (only from valid ones)
-        valid_indices = np.where(valid_features)[0]
-        valid_diis = single_feature_diis[valid_indices]
-        n_best_actual = min(n_best, len(valid_indices))
-        best_valid_indices = np.argsort(valid_diis)[:n_best_actual]
-        selected_indices = valid_indices[best_valid_indices]
+        # Select the best n_best single features
+        n_best_actual = min(n_best, n_features)
+        selected_indices = np.argsort(single_feature_diis)[:n_best_actual]
 
         # Convert indices to lists for consistent processing
         selected_features = [[idx] for idx in selected_indices]
```
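The two selection strategies in this hunk, on toy numbers: the restored code ranks all DIIs directly, while the removed code first dropped the infinite DIIs left by failed runs. Since `np.argsort` sorts `np.inf` last, the two coincide unless fewer than `n_best` candidates are finite. The values below are illustrative:

```python
import numpy as np

single_feature_diis = np.array([0.8, np.inf, 0.3, 0.5])  # inf = failed run
n_best = 2

# Restored behaviour: rank everything directly (np.inf sorts last).
selected = np.argsort(single_feature_diis)[:n_best]        # -> [2, 3]

# Removed behaviour: keep only the finite entries, then rank those.
valid = np.where(np.isfinite(single_feature_diis))[0]      # -> [0, 2, 3]
order = np.argsort(single_feature_diis[valid])
selected_valid = valid[order[: min(n_best, len(valid))]]   # -> [2, 3]
```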
```diff
@@ -1301,24 +1270,10 @@ def forward_greedy_feature_selection(
                     learning_rate=self.learning_rate,
                     learning_rate_decay=self.learning_rate_decay,
                     num_points_rows=self.num_points_rows,
-                    gradient_clip_value=self.gradient_clip_value,
                 )
 
                 # Set initial parameters and train
-                try:
-                    _, _ = dii_copy.train()
-                except AssertionError as e:
-                    print(
-                        f"Training failed for feature set {candidate_set}: {str(e)}"
-                    )
-                    print(
-                        f"Skipping feature set {candidate_set} and continuing..."
-                    )
-                    candidate_diis.append(
-                        float("inf")
-                    )  # Use infinity as a large penalty
-                    candidate_errors.append(None)
-                    continue
+                _, _ = dii_copy.train()
 
                 # Compute DII on the full dataset
                 if compute_error:
```
```diff
@@ -1350,18 +1305,9 @@ def forward_greedy_feature_selection(
             if not candidate_features:  # No more features to add
                 break
 
-            # Check if we have any valid candidates (not infinity)
-            valid_candidates = np.isfinite(candidate_diis)
-            if not np.any(valid_candidates):
-                print("ERROR: All candidate feature sets failed during training!")
-                break
-
-            # Select the best n_best candidates for the next iteration (only from valid ones)
-            valid_indices = np.where(valid_candidates)[0]
-            valid_diis = candidate_diis[valid_indices]
-            n_best_actual = min(n_best, len(valid_indices))
-            best_valid_indices = np.argsort(valid_diis)[:n_best_actual]
-            best_indices = valid_indices[best_valid_indices]
+            # Select the best n_best candidates for the next iteration
+            n_best_actual = min(n_best, len(candidate_features))
+            best_indices = np.argsort(candidate_diis)[:n_best_actual]
             selected_features = [candidate_features[i] for i in best_indices]
 
             # Print the best feature set information
```
```diff
@@ -1401,25 +1347,18 @@ def forward_greedy_feature_selection(
             learning_rate=self.learning_rate,
             learning_rate_decay=self.learning_rate_decay,
             num_points_rows=self.num_points_rows,
-            gradient_clip_value=self.gradient_clip_value,
         )
 
         # Set initial parameters and train
-        try:
-            _, _ = dii_copy.train()
-            # Print and store optimal weights
-            print(
-                f"\nOptimal weights for feature set {candidate_features[best_idx]}: {dii_copy.params_final}\n"
-            )
-            # Save optimal weights
-            best_weights = np.array(dii_copy.params_final)
-        except AssertionError as e:
-            print(
-                f"Training failed for best feature set {candidate_features[best_idx]}: {str(e)}"
-            )
-            print(f"Using zero weights for this iteration...")
-            best_weights = np.zeros(n_features)
+        _, _ = dii_copy.train()
 
+        # Print and store optimal weights
+        print(
+            f"\nOptimal weights for feature set {candidate_features[best_idx]}: {dii_copy.params_final}\n"
+        )
+
+        # Save optimal weights
+        best_weights = np.array(dii_copy.params_final)
         best_weights_list.append(best_weights)
 
         # Print the best n-tuple information
```
```diff
@@ -1580,24 +1519,13 @@ def backward_greedy_feature_selection(
                 learning_rate=self.learning_rate,
                 learning_rate_decay=self.learning_rate_decay,
                 num_points_rows=self.num_points_rows,
-                gradient_clip_value=self.gradient_clip_value,
             )
 
             # Set initial parameters and train
-            try:
-                _, _ = dii_copy.train()
-                # Store the trained weights
-                trained_weights = dii_copy.params_final
-            except AssertionError as e:
-                print(
-                    f"Training failed for feature set {candidate_set}: {str(e)}"
-                )
-                print(f"Skipping feature set {candidate_set} and continuing...")
-                candidate_diis.append(
-                    float("inf")
-                )  # Use infinity as a large penalty
-                candidate_errors.append(None)
-                continue
+            _, _ = dii_copy.train()
+
+            # Store the trained weights
+            trained_weights = dii_copy.params_final
 
             # Use return_final_dii to compute DII on the full dataset
             dii_copy.params_final = trained_weights
```
```diff
@@ -1632,18 +1560,9 @@ def backward_greedy_feature_selection(
         # Convert to numpy arrays for easier manipulation
         candidate_diis = np.array(candidate_diis)
 
-        # Check if we have any valid candidates (not infinity)
-        valid_candidates = np.isfinite(candidate_diis)
-        if not np.any(valid_candidates):
-            print("ERROR: All candidate feature sets failed during training!")
-            break
-
-        # Select the best n_best candidates (only from valid ones)
-        valid_indices = np.where(valid_candidates)[0]
-        valid_diis = candidate_diis[valid_indices]
-        n_best_actual = min(n_best, len(valid_indices))
-        best_valid_indices = np.argsort(valid_diis)[:n_best_actual]
-        best_indices = valid_indices[best_valid_indices]
+        # Select the best n_best candidates
+        n_best_actual = min(n_best, len(candidate_features))
+        best_indices = np.argsort(candidate_diis)[:n_best_actual]
 
         # Update current features for the next iteration
         current_features = [candidate_features[i] for i in best_indices]
```
```diff
@@ -1677,21 +1596,13 @@ def backward_greedy_feature_selection(
             learning_rate=self.learning_rate,
             learning_rate_decay=self.learning_rate_decay,
             num_points_rows=self.num_points_rows,
-            gradient_clip_value=self.gradient_clip_value,
         )
 
         # Set initial parameters and train
-        try:
-            _, _ = dii_copy.train()
-            # Save optimal weights
-            best_weights = dii_copy.params_final
-        except AssertionError as e:
-            print(
-                f"Training failed for best feature set {best_feature_set}: {str(e)}"
-            )
-            print(f"Using zero weights for this iteration...")
-            best_weights = np.zeros(n_features)
+        _, _ = dii_copy.train()
 
+        # Save optimal weights
+        best_weights = dii_copy.params_final
         best_weights_list.append(best_weights)
 
         # Store results
```

examples/notebook_on_differentiable_imbalance_jax.ipynb

Lines changed: 142 additions & 1106 deletions
Large diffs are not rendered by default.

tests/test_diff_imbalance_jax/test_greedy_dii.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -69,7 +69,6 @@ def test_DiffImbalance_forward_greedy():
         learning_rate=1e-1,
         learning_rate_decay="cos",
         num_points_rows=None,
-        gradient_clip_value=0.0,
     )
     weights, imbs = dii.train()
 
@@ -144,7 +143,6 @@ def test_DiffImbalance_backward_greedy():
         learning_rate=1e-1,
         learning_rate_decay="cos",
         num_points_rows=None,
-        gradient_clip_value=0.0,
     )
     weights, imbs = dii.train()
 
@@ -222,7 +220,6 @@ def test_DiffImbalance_greedy_symmetry_5d_gaussian():
         learning_rate=1e-1,
         learning_rate_decay="cos",
         num_points_rows=None,
-        gradient_clip_value=0.0,
     )
     weights, imbs = dii.train()
 
@@ -356,7 +353,6 @@ def test_DiffImbalance_greedy_random_initialization():
         learning_rate=1e-1,
         learning_rate_decay="cos",
         num_points_rows=None,
-        gradient_clip_value=0.0,
     )
     weights, imbs = dii.train()
 
```
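After the revert, `DiffImbalance` no longer accepts `gradient_clip_value`, as the updated tests show. A minimal usage sketch under that signature; the synthetic data, shapes, and the choice of `data_B` are illustrative assumptions, not taken from the tests:

```python
import numpy as np
from dadapy.diff_imbalance import DiffImbalance

rng = np.random.default_rng(0)
data_A = rng.normal(size=(100, 5))  # hypothetical full feature space
data_B = data_A[:, :2]              # hypothetical target space

dii = DiffImbalance(
    data_A,
    data_B,
    learning_rate=1e-1,
    learning_rate_decay="cos",
    num_points_rows=None,
    # gradient_clip_value=0.0,  # removed by this revert; would now raise TypeError
)
weights, imbs = dii.train()  # two return values, as unpacked in the tests
```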
