@@ -129,9 +129,6 @@ class DiffImbalance:
         seed (int): seed of JAX random generator, default is 0. Different seeds determine different mini-batch
             partitions.
         l1_strength (float): strength of the L1 regularization (LASSO) term. Default is 0.
-        gradient_clip_value (float): maximum norm for gradient clipping. If 0, no clipping is
-            applied. Default is 0. This is useful when weights are sometimes automatically set to NaN and
-            there can be gradient explosions.
         point_adapt_lambda (bool): whether to use a global smoothing parameter lambda for the c_ij coefficients
             in the DII (if False), or a different parameter for each point (if True). Default is True.
         k_init (int): initial rank of neighbors used to set lambda. Ranks are defined starting from 1. If
@@ -183,7 +180,6 @@ def __init__(
         learning_rate=1e-2,
         learning_rate_decay=None,
         num_points_rows=None,
-        gradient_clip_value=0.0,
     ):
         """Initialise the DiffImbalance class."""
         self.nfeatures_A = data_A.shape[1]
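
For orientation, here is a minimal usage sketch of the constructor after this change; the data_B argument, the array shapes, and the call pattern around train() are assumptions, since only a slice of the signature is visible in this hunk:

    import numpy as np

    # hypothetical inputs: two representations of the same 100 points
    data_A = np.random.normal(size=(100, 5))
    data_B = np.random.normal(size=(100, 3))

    dii = DiffImbalance(
        data_A,
        data_B,
        learning_rate=1e-2,
        learning_rate_decay=None,
        num_points_rows=None,
        # gradient_clip_value is no longer accepted after this commit
    )
    weights_per_epoch, diis_per_epoch = dii.train()  # train() returns two values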
@@ -262,7 +258,6 @@ def __init__(
         self.num_epochs = num_epochs
         self.batches_per_epoch = batches_per_epoch
         self.l1_strength = l1_strength
-        self.gradient_clip_value = gradient_clip_value
         self.point_adapt_lambda = point_adapt_lambda
         self.k_init = k_init
         self.k_final = k_final
@@ -853,14 +848,7 @@ def _init_optimizer(self):
             raise ValueError(
                 f'Unknown learning rate decay schedule "{self.learning_rate_decay}". Choose among None, "cos" and "exp".'
             )
-        # Set up optimizer with optional gradient clipping
-        if self.gradient_clip_value > 0:
-            optimizer = optax.chain(
-                optax.clip_by_global_norm(self.gradient_clip_value),
-                opt_class(self.lr_schedule),
-            )
-        else:
-            optimizer = opt_class(self.lr_schedule)
+        optimizer = opt_class(self.lr_schedule)

         # Initialize training state
         self.state = train_state.TrainState.create(
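
Callers who relied on the removed clipping branch can still recover it outside the class by chaining optax transforms themselves. A minimal sketch, assuming SGD as the base optimizer (the actual opt_class resolved by _init_optimizer is not shown in this diff) and a hypothetical threshold:

    import optax

    clip_value = 1.0  # hypothetical, plays the role of the old gradient_clip_value
    optimizer = optax.chain(
        optax.clip_by_global_norm(clip_value),  # clip the global gradient norm first
        optax.sgd(learning_rate=1e-2),          # then apply the base update rule
    )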
@@ -1167,20 +1155,10 @@ def forward_greedy_feature_selection(
                 learning_rate=self.learning_rate,
                 learning_rate_decay=self.learning_rate_decay,
                 num_points_rows=self.num_points_rows,
-                gradient_clip_value=self.gradient_clip_value,
             )

             # Set initial parameters and train
-            try:
-                _, _ = dii_copy.train()
-            except AssertionError as e:
-                print(f"Training failed for feature [{feature}]: {str(e)}")
-                print(f"Skipping feature [{feature}] and continuing...")
-                single_feature_diis.append(
-                    float("inf")
-                )  # Use infinity as a large penalty
-                single_feature_errors.append(None)
-                continue
+            _, _ = dii_copy.train()

             # Compute DII on the full dataset
             if compute_error:
@@ -1207,18 +1185,9 @@ def forward_greedy_feature_selection(
         # Convert to numpy arrays for easier manipulation
         single_feature_diis = np.array(single_feature_diis)

-        # Check if we have any valid features (not infinity)
-        valid_features = np.isfinite(single_feature_diis)
-        if not np.any(valid_features):
-            print("ERROR: All single features failed during training!")
-            return [], [], [], []
-
-        # Select the best n_best single features (only from valid ones)
-        valid_indices = np.where(valid_features)[0]
-        valid_diis = single_feature_diis[valid_indices]
-        n_best_actual = min(n_best, len(valid_indices))
-        best_valid_indices = np.argsort(valid_diis)[:n_best_actual]
-        selected_indices = valid_indices[best_valid_indices]
+        # Select the best n_best single features
+        n_best_actual = min(n_best, n_features)
+        selected_indices = np.argsort(single_feature_diis)[:n_best_actual]

         # Convert indices to lists for consistent processing
         selected_features = [[idx] for idx in selected_indices]
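
The simplified selection relies on np.argsort returning indices in ascending order of DII, so the slice keeps the n_best lowest (i.e. best) scores. A small worked example with made-up values:

    import numpy as np

    single_feature_diis = np.array([0.42, 0.17, 0.35, 0.29])
    n_best = 2
    selected_indices = np.argsort(single_feature_diis)[:n_best]  # array([1, 3])
    selected_features = [[idx] for idx in selected_indices]      # [[1], [3]]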
@@ -1301,24 +1270,10 @@ def forward_greedy_feature_selection(
                 learning_rate=self.learning_rate,
                 learning_rate_decay=self.learning_rate_decay,
                 num_points_rows=self.num_points_rows,
-                gradient_clip_value=self.gradient_clip_value,
             )

             # Set initial parameters and train
-            try:
-                _, _ = dii_copy.train()
-            except AssertionError as e:
-                print(
-                    f"Training failed for feature set {candidate_set}: {str(e)}"
-                )
-                print(
-                    f"Skipping feature set {candidate_set} and continuing..."
-                )
-                candidate_diis.append(
-                    float("inf")
-                )  # Use infinity as a large penalty
-                candidate_errors.append(None)
-                continue
+            _, _ = dii_copy.train()

             # Compute DII on the full dataset
             if compute_error:
@@ -1350,18 +1305,9 @@ def forward_greedy_feature_selection(
         if not candidate_features:  # No more features to add
             break

-        # Check if we have any valid candidates (not infinity)
-        valid_candidates = np.isfinite(candidate_diis)
-        if not np.any(valid_candidates):
-            print("ERROR: All candidate feature sets failed during training!")
-            break
-
-        # Select the best n_best candidates for the next iteration (only from valid ones)
-        valid_indices = np.where(valid_candidates)[0]
-        valid_diis = candidate_diis[valid_indices]
-        n_best_actual = min(n_best, len(valid_indices))
-        best_valid_indices = np.argsort(valid_diis)[:n_best_actual]
-        best_indices = valid_indices[best_valid_indices]
+        # Select the best n_best candidates for the next iteration
+        n_best_actual = min(n_best, len(candidate_features))
+        best_indices = np.argsort(candidate_diis)[:n_best_actual]
         selected_features = [candidate_features[i] for i in best_indices]

         # Print the best feature set information
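
For reference, the candidate sets scored in this loop are the surviving sets each extended by one unused feature. That generation code is outside this diff, so the following is only an assumed reconstruction with hypothetical values:

    n_features = 4
    selected_features = [[1], [3]]

    # extend each surviving set by every unused feature, deduplicating
    candidate_features = [
        list(c)
        for c in sorted(
            {
                tuple(sorted(feats + [f]))
                for feats in selected_features
                for f in range(n_features)
                if f not in feats
            }
        )
    ]
    # [[0, 1], [0, 3], [1, 2], [1, 3], [2, 3]]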
@@ -1401,25 +1347,18 @@ def forward_greedy_feature_selection(
             learning_rate=self.learning_rate,
             learning_rate_decay=self.learning_rate_decay,
             num_points_rows=self.num_points_rows,
-            gradient_clip_value=self.gradient_clip_value,
         )

         # Set initial parameters and train
-        try:
-            _, _ = dii_copy.train()
-            # Print and store optimal weights
-            print(
-                f"\nOptimal weights for feature set {candidate_features[best_idx]}: {dii_copy.params_final}\n"
-            )
-            # Save optimal weights
-            best_weights = np.array(dii_copy.params_final)
-        except AssertionError as e:
-            print(
-                f"Training failed for best feature set {candidate_features[best_idx]}: {str(e)}"
-            )
-            print(f"Using zero weights for this iteration...")
-            best_weights = np.zeros(n_features)
+        _, _ = dii_copy.train()

+        # Print and store optimal weights
+        print(
+            f"\nOptimal weights for feature set {candidate_features[best_idx]}: {dii_copy.params_final}\n"
+        )
+
+        # Save optimal weights
+        best_weights = np.array(dii_copy.params_final)
         best_weights_list.append(best_weights)

         # Print the best n-tuple information
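
With the internal AssertionError handling removed, a failed train() now propagates to the caller. A hedged sketch of a hypothetical wrapper for callers that still want the old skip-on-failure behaviour; only train() and params_final are taken from this diff:

    def train_or_none(dii):
        """Hypothetical helper: train in place, return final weights or None on failure."""
        try:
            dii.train()
            return dii.params_final
        except AssertionError as err:  # the failure mode the removed code caught
            print(f"Training failed: {err}")
            return None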
@@ -1580,24 +1519,13 @@ def backward_greedy_feature_selection(
                 learning_rate=self.learning_rate,
                 learning_rate_decay=self.learning_rate_decay,
                 num_points_rows=self.num_points_rows,
-                gradient_clip_value=self.gradient_clip_value,
             )

             # Set initial parameters and train
-            try:
-                _, _ = dii_copy.train()
-                # Store the trained weights
-                trained_weights = dii_copy.params_final
-            except AssertionError as e:
-                print(
-                    f"Training failed for feature set {candidate_set}: {str(e)}"
-                )
-                print(f"Skipping feature set {candidate_set} and continuing...")
-                candidate_diis.append(
-                    float("inf")
-                )  # Use infinity as a large penalty
-                candidate_errors.append(None)
-                continue
+            _, _ = dii_copy.train()
+
+            # Store the trained weights
+            trained_weights = dii_copy.params_final

             # Use return_final_dii to compute DII on the full dataset
             dii_copy.params_final = trained_weights
@@ -1632,18 +1560,9 @@ def backward_greedy_feature_selection(
         # Convert to numpy arrays for easier manipulation
         candidate_diis = np.array(candidate_diis)

-        # Check if we have any valid candidates (not infinity)
-        valid_candidates = np.isfinite(candidate_diis)
-        if not np.any(valid_candidates):
-            print("ERROR: All candidate feature sets failed during training!")
-            break
-
-        # Select the best n_best candidates (only from valid ones)
-        valid_indices = np.where(valid_candidates)[0]
-        valid_diis = candidate_diis[valid_indices]
-        n_best_actual = min(n_best, len(valid_indices))
-        best_valid_indices = np.argsort(valid_diis)[:n_best_actual]
-        best_indices = valid_indices[best_valid_indices]
+        # Select the best n_best candidates
+        n_best_actual = min(n_best, len(candidate_features))
+        best_indices = np.argsort(candidate_diis)[:n_best_actual]

         # Update current features for the next iteration
         current_features = [candidate_features[i] for i in best_indices]
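
In the backward pass, the candidates ranked above are by construction the current set with one feature dropped at a time; an illustrative sketch (the generation code itself is not part of this diff):

    current = [0, 2, 5]
    candidate_features = [[f for f in current if f != drop] for drop in current]
    # [[2, 5], [0, 5], [0, 2]]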
@@ -1677,21 +1596,13 @@ def backward_greedy_feature_selection(
             learning_rate=self.learning_rate,
             learning_rate_decay=self.learning_rate_decay,
             num_points_rows=self.num_points_rows,
-            gradient_clip_value=self.gradient_clip_value,
         )

         # Set initial parameters and train
-        try:
-            _, _ = dii_copy.train()
-            # Save optimal weights
-            best_weights = dii_copy.params_final
-        except AssertionError as e:
-            print(
-                f"Training failed for best feature set {best_feature_set}: {str(e)}"
-            )
-            print(f"Using zero weights for this iteration...")
-            best_weights = np.zeros(n_features)
+        _, _ = dii_copy.train()

+        # Save optimal weights
+        best_weights = dii_copy.params_final
         best_weights_list.append(best_weights)

         # Store results