@@ -13,7 +13,7 @@ def _beam_search(
1313 X , V , n_features_to_select , beam_width , indices_include , mask_exclude , tol , verbose
1414):
1515 """
16- Perform beam search to find the best subset of features .
16+ Beam search with SSC .
1717
1818 Parameters:
1919 X : np.ndarray
@@ -48,12 +48,16 @@ def _beam_search(
4848 X_support , X_selected = _prepare_candidates (
4949 X , mask_exclude , indices_include
5050 )
51+ if X_selected .shape [1 ] == 0 :
52+ beams_scores = np .sum ((X .T @ V ) ** 2 , axis = 1 )
53+ beams_scores [~ X_support ] = 0
54+ else :
55+ W_selected = orth (X_selected )
56+ selected_score = np .sum ((W_selected .T @ V ) ** 2 )
57+ beams_scores = _gram_schmidt (
58+ X , X_support , X_selected , selected_score , V , tol
59+ )
5160 beams_selected_ids = [indices_include for _ in range (beam_width )]
52- W_selected = orth (X_selected )
53- selected_score = np .sum ((W_selected .T @ V ) ** 2 )
54- beams_scores = _gram_schmidt (
55- X , X_support , X_selected , selected_score , V , tol
56- )
5761 beams_selected_ids , top_k_scores = _select_top_k (
5862 beams_scores [None , :],
5963 beams_selected_ids ,
@@ -123,11 +127,9 @@ def _select_top_k(
123127 return new_ids_selected , top_k_scores
124128
125129
126- def _gram_schmidt (X , X_support , X_selected , selected_score , V , tol , modified = True ):
130+ def _gram_schmidt (X , X_support , X_selected , selected_score , V , tol ):
127131 X = np .copy (X )
128- if modified :
129- # Change to Modified Gram-Schmidt
130- W_selected = orth (X_selected )
132+ W_selected = orth (X_selected ) # Change to Modified Gram-Schmidt
131133 scores = np .zeros (X .shape [1 ])
132134 for i , support in enumerate (X_support ):
133135 if not support :
0 commit comments