@@ -159,48 +159,14 @@ def __init__(
159159 handle_unknown = handle_unknown ,
160160 handle_missing = handle_missing ,
161161 )
162- self .ordinal_encoder = None
163162 self .min_samples_leaf = min_samples_leaf
164163 self .smoothing = smoothing
164+ self .hierarchy = hierarchy
165+ self .ordinal_encoder = None
165166 self .mapping = None
166167 self ._mean = None
167- # @ToDo create a function to check the hierarchy
168- if isinstance (hierarchy , (dict , pd .DataFrame )) and cols is None :
169- raise ValueError ('Hierarchy is defined but no columns are named for encoding' )
170- if isinstance (hierarchy , dict ):
171- self .hierarchy = {}
172- self .hierarchy_depth = {}
173- for switch in hierarchy :
174- flattened_hierarchy = util .flatten_reverse_dict (hierarchy [switch ])
175- hierarchy_check = self ._check_dict_key_tuples (flattened_hierarchy )
176- self .hierarchy_depth [switch ] = hierarchy_check [1 ]
177- if not hierarchy_check [0 ]:
178- raise ValueError (
179- 'Hierarchy mapping contains different levels for key "' + switch + '"'
180- )
181- self .hierarchy [switch ] = {
182- (k if isinstance (t , tuple ) else t ): v
183- for t , v in flattened_hierarchy .items ()
184- for k in t
185- }
186- elif isinstance (hierarchy , pd .DataFrame ):
187- self .hierarchy = hierarchy
188- self .hierarchy_depth = {}
189- for col in self .cols :
190- HIER_cols = self .hierarchy .columns [
191- self .hierarchy .columns .str .startswith (f'HIER_{ col } ' )
192- ].tolist ()
193- HIER_levels = [int (i .replace (f'HIER_{ col } _' , '' )) for i in HIER_cols ]
194- if np .array_equal (sorted (HIER_levels ), np .arange (1 , max (HIER_levels ) + 1 )):
195- self .hierarchy_depth [col ] = max (HIER_levels )
196- else :
197- raise ValueError (f'Hierarchy columns are not complete for column { col } ' )
198- elif hierarchy is None :
199- self .hierarchy = hierarchy
200- else :
201- raise ValueError ('Given hierarchy mapping is neither a dictionary nor a dataframe' )
202-
203- self .cols_hier = []
168+ # Call this in the constructor only for the possible side effect of raising an exception.
169+ self ._generate_inverted_hierarchy ()
204170
205171 @staticmethod
206172 def _check_dict_key_tuples (dict_to_check : dict [Any , tuple ]) -> tuple [bool , int ]:
@@ -219,24 +185,25 @@ def _check_dict_key_tuples(dict_to_check: dict[Any, tuple]) -> tuple[bool, int]:
219185 return min_tuple_size == max_tuple_size , min_tuple_size
220186
221187 def _fit (self , X : util .X_type , y : util .y_type , ** kwargs ) -> None :
222- if isinstance (self .hierarchy , dict ):
188+ inverted_hierarchy , self .hierarchy_depth = self ._generate_inverted_hierarchy ()
189+ if isinstance (inverted_hierarchy , dict ):
223190 X_hier = pd .DataFrame ()
224- for switch in self . hierarchy :
191+ for switch in inverted_hierarchy :
225192 if switch in self .cols :
226193 colnames = [
227194 f'HIER_{ str (switch )} _{ str (i + 1 )} '
228195 for i in range (self .hierarchy_depth [switch ])
229196 ]
230197 df = pd .DataFrame (
231- X [str (switch )].map (self . hierarchy [str (switch )]).tolist (),
198+ X [str (switch )].map (inverted_hierarchy [str (switch )]).tolist (),
232199 index = X .index ,
233200 columns = colnames ,
234201 )
235202 X_hier = pd .concat ([X_hier , df ], axis = 1 )
236- elif isinstance (self . hierarchy , pd .DataFrame ):
237- X_hier = self . hierarchy
203+ elif isinstance (inverted_hierarchy , pd .DataFrame ):
204+ X_hier = inverted_hierarchy
238205
239- if isinstance (self . hierarchy , (dict , pd .DataFrame )):
206+ if isinstance (inverted_hierarchy , (dict , pd .DataFrame )):
240207 enc_hier = OrdinalEncoder (
241208 verbose = self .verbose ,
242209 cols = X_hier .columns ,
@@ -251,7 +218,7 @@ def _fit(self, X: util.X_type, y: util.y_type, **kwargs) -> None:
251218 )
252219 self .ordinal_encoder = self .ordinal_encoder .fit (X )
253220 X_ordinal = self .ordinal_encoder .transform (X )
254- if self . hierarchy is not None :
221+ if inverted_hierarchy is not None :
255222 self .mapping = self .fit_target_encoding (
256223 pd .concat ([X_ordinal , X_hier_ordinal ], axis = 1 ), y
257224 )
@@ -344,3 +311,42 @@ def _weighting(self, n: int) -> float:
344311 # monotonically increasing function of n bounded between 0 and 1
345312 # sigmoid in this case, using scipy.expit for numerical stability
346313 return expit ((n - self .min_samples_leaf ) / self .smoothing )
314+
315+ def _generate_inverted_hierarchy (self ) -> tuple [dict | pd .DataFrame | None , dict ]:
316+ if isinstance (self .hierarchy , (dict , pd .DataFrame )) and self .cols is None :
317+ raise ValueError ('Hierarchy is defined but no columns are named for encoding' )
318+ if isinstance (self .hierarchy , dict ):
319+ inverted_hierarchy = {}
320+ hierarchy_depth = {}
321+ for switch in self .hierarchy :
322+ flattened_hierarchy = util .flatten_reverse_dict (self .hierarchy [switch ])
323+ hierarchy_check = self ._check_dict_key_tuples (flattened_hierarchy )
324+ hierarchy_depth [switch ] = hierarchy_check [1 ]
325+ if not hierarchy_check [0 ]:
326+ raise ValueError (
327+ 'Hierarchy mapping contains different levels for key "' + switch + '"'
328+ )
329+ inverted_hierarchy [switch ] = {
330+ (k if isinstance (t , tuple ) else t ): v
331+ for t , v in flattened_hierarchy .items ()
332+ for k in t
333+ }
334+ elif isinstance (self .hierarchy , pd .DataFrame ):
335+ inverted_hierarchy = self .hierarchy
336+ hierarchy_depth = {}
337+ for col in self .cols :
338+ HIER_cols = inverted_hierarchy .columns [
339+ inverted_hierarchy .columns .str .startswith (f'HIER_{ col } ' )
340+ ].tolist ()
341+ HIER_levels = [int (i .replace (f'HIER_{ col } _' , '' )) for i in HIER_cols ]
342+ if np .array_equal (sorted (HIER_levels ), np .arange (1 , max (HIER_levels ) + 1 )):
343+ hierarchy_depth [col ] = max (HIER_levels )
344+ else :
345+ raise ValueError (f'Hierarchy columns are not complete for column { col } ' )
346+ elif self .hierarchy is None :
347+ inverted_hierarchy = None
348+ hierarchy_depth = {}
349+ else :
350+ raise ValueError ('Given hierarchy mapping is neither a dictionary nor a dataframe' )
351+ return inverted_hierarchy , hierarchy_depth
352+
0 commit comments