@@ -45,17 +45,17 @@ def _beam_search(
4545
4646 for i in range (n_features_to_select - n_inclusions ):
4747 if i == 0 :
48- X_support , X_selected = _prepare_candidates (
48+ mask , X_selected = _prepare_candidates (
4949 X , mask_exclude , indices_include
5050 )
5151 if X_selected .shape [1 ] == 0 :
5252 beams_scores = np .sum ((X .T @ V ) ** 2 , axis = 1 )
53- beams_scores [~ X_support ] = 0
53+ beams_scores [mask ] = 0
5454 else :
5555 W_selected = orth (X_selected )
5656 selected_score = np .sum ((W_selected .T @ V ) ** 2 )
57- beams_scores = _gram_schmidt (
58- X , X_support , X_selected , selected_score , V , tol
57+ beams_scores = _mgs_ssc (
58+ X , V , W_selected , mask , selected_score , tol
5959 )
6060 beams_selected_ids = [indices_include for _ in range (beam_width )]
6161 beams_selected_ids , top_k_scores = _select_top_k (
@@ -66,11 +66,12 @@ def _beam_search(
6666 continue
6767 beams_scores = np .zeros ((beam_width , n_features ))
6868 for beam_idx in range (beam_width ):
69- X_support , X_selected = _prepare_candidates (
69+ mask , X_selected = _prepare_candidates (
7070 X , mask_exclude , beams_selected_ids [beam_idx ]
7171 )
72- beams_scores [beam_idx ] = _gram_schmidt (
73- X , X_support , X_selected , top_k_scores [beam_idx ], V , tol
72+ W_selected = orth (X_selected )
73+ beams_scores [beam_idx ] = _mgs_ssc (
74+ X , V , W_selected , mask , top_k_scores [beam_idx ], tol
7475 )
7576 beams_selected_ids , top_k_scores = _select_top_k (
7677 beams_scores ,
@@ -92,10 +93,10 @@ def _beam_search(
9293
9394
9495def _prepare_candidates (X , mask_exclude , indices_selected ):
95- X_support = np .copy (~ mask_exclude )
96- X_support [indices_selected ] = False
96+ mask = np .copy (mask_exclude )
97+ mask [indices_selected ] = True
9798 X_selected = X [:, indices_selected ]
98- return X_support , X_selected
99+ return mask , X_selected
99100
100101
101102def _select_top_k (
@@ -127,31 +128,22 @@ def _select_top_k(
127128 return new_ids_selected , top_k_scores
128129
129130
130- def _gram_schmidt (X , X_support , X_selected , selected_score , V , tol ):
131+ def _mgs_ssc (X , V , W_selected , mask , selected_score , tol ):
131132 X = np .copy (X )
132- W_selected = orth (X_selected ) # Change to Modified Gram-Schmidt
133- scores = np .zeros (X .shape [1 ])
134- for i , support in enumerate (X_support ):
135- if not support :
136- continue
137- xi = X [:, i : i + 1 ]
138- for j in range (W_selected .shape [1 ]):
139- proj = W_selected [:, j : j + 1 ].T @ xi
140- xi -= proj * W_selected [:, j : j + 1 ]
141- wi , X_support [i ] = _safe_normalize (xi )
142- if not X_support [i ]:
143- continue
144- if np .any (np .abs (wi .T @ W_selected ) > tol ):
145- X_support [i ] = False
146- continue
147- scores [i ] = np .sum ((wi .T @ V ) ** 2 )
133+ proj = W_selected .T @ X
134+ X -= W_selected @ proj
135+ W , non_zero_mask = _safe_normalize (X )
136+ mask |= non_zero_mask
137+ linear_independent_mask = np .any (np .abs (W .T @ W_selected ) > tol , axis = 1 )
138+ mask |= linear_independent_mask
139+ scores = np .sum ((W .T @ V ) ** 2 , axis = 1 )
148140 scores += selected_score
149- scores [~ X_support ] = 0
141+ scores [mask ] = 0
150142 return scores
151143
152144
153145def _safe_normalize (X ):
154146 norm = np .linalg .norm (X , axis = 0 )
155- non_zero_support = norm ! = 0
156- norm [~ non_zero_support ] = 1
157- return X / norm , non_zero_support
147+ non_zero_mask = norm = = 0
148+ norm [non_zero_mask ] = 1
149+ return X / norm , non_zero_mask
0 commit comments