66# SPDX-License-Identifier: MIT 
77
88import  math 
9+ import  warnings 
910from  itertools  import  combinations_with_replacement 
1011from  numbers  import  Integral 
1112
@@ -198,12 +199,16 @@ def make_poly_features(X, ids):
198199            None , 
199200            Interval (Integral , 1 , None , closed = "left" ), 
200201        ], 
202+         "max_poly" : [None , Interval (Integral , 1 , None , closed = "left" )], 
203+         "random_state" : ["random_state" ], 
201204    }, 
202205    prefer_skip_nested_validation = True , 
203206) 
204207def  make_poly_ids (
205208    n_features = 1 ,
206209    degree = 1 ,
210+     max_poly = None ,
211+     random_state = None ,
207212):
208213    """Generate ids for polynomial features. 
209214    (variable_index, variable_index, ...) 
@@ -217,6 +222,15 @@ def make_poly_ids(
217222    degree : int, default=1 
218223        The maximum degree of polynomial features. 
219224
225+     max_poly : int, default=None 
226+         Maximum number of ids of polynomial features to generate. 
227+         Randomly selected by reservoir sampling. 
228+         If None, all possible ids are returned. 
229+ 
230+     random_state : int or RandomState instance, default=None 
231+         Used when `max_poly` is not None to subsample ids of polynomial features. 
232+         See :term:`Glossary <random_state>` for details. 
233+ 
220234    Returns 
221235    ------- 
222236    ids : array-like of shape (n_outputs, degree) 
@@ -236,29 +250,45 @@ def make_poly_ids(
236250           [1, 2, 2], 
237251           [2, 2, 2]]) 
238252    """ 
239-     n_outputs  =  math .comb (n_features  +  degree , degree ) -  1 
240-     if  n_outputs  >  np .iinfo (np .intp ).max :
253+     n_total  =  math .comb (n_features  +  degree , degree ) -  1 
254+     if  n_total  >  np .iinfo (np .intp ).max :
241255        msg  =  (
242-             "The output that would result from the  current configuration would" 
243-             f" have  { n_outputs }   features which is too large to be" 
244-             f"  indexed by { np .intp ().dtype .name }  ." 
256+             "The current configuration would  " 
257+             f"result in  { n_total }   features which is too large to be  " 
258+             f"indexed by { np .intp ().dtype .name }  ." 
245259        )
246260        raise  ValueError (msg )
247- 
248-     ids  =  np .array (
249-         list (
250-             combinations_with_replacement (
251-                 range (n_features  +  1 ),
252-                 degree ,
253-             )
261+     if  n_total  >  10_000_000 :
262+         warnings .warn (
263+             "Total number of polynomial features is larger than 10,000,000! " 
264+             f"The current configuration would result in { n_total }   features. " 
265+             "This may take a while." ,
266+             UserWarning ,
267+         )
268+     if  max_poly  is  not   None  and  max_poly  <  n_total :
269+         # reservoir sampling 
270+         rng  =  np .random .default_rng (random_state )
271+         reservoir  =  []
272+         for  i , comb  in  enumerate (
273+             combinations_with_replacement (range (n_features  +  1 ), degree )
274+         ):
275+             if  i  <  max_poly :
276+                 reservoir .append (comb )
277+             else :
278+                 j  =  rng .integers (0 , i  +  1 )
279+                 if  j  <  max_poly :
280+                     reservoir [j ] =  comb 
281+         ids  =  np .array (reservoir )
282+     else :
283+         ids  =  np .array (
284+             list (combinations_with_replacement (range (n_features  +  1 ), degree ))
254285        )
255-     )
256286
257287    const_id  =  np .where ((ids  ==  0 ).all (axis = 1 ))
258288    return  np .delete (ids , const_id , 0 )  # remove the constant feature 
259289
260290
261- def  _valiate_time_shift_poly_ids (
291+ def  _validate_time_shift_poly_ids (
262292    time_shift_ids , poly_ids , n_samples = None , n_features = None , n_outputs = None 
263293):
264294    if  n_samples  is  None :
@@ -496,7 +526,7 @@ def tp2fd(time_shift_ids, poly_ids):
496526    [[-1  1] 
497527     [ 2  3]] 
498528    """ 
499-     _time_shift_ids , _poly_ids  =  _valiate_time_shift_poly_ids (
529+     _time_shift_ids , _poly_ids  =  _validate_time_shift_poly_ids (
500530        time_shift_ids ,
501531        poly_ids ,
502532    )
0 commit comments