Skip to content

Commit 3a1ac3e

Browse files
authored
Enhance inverse_transform with input validation
Added checks for input type and missing columns in the inverse_transform method.
1 parent 9a86233 commit 3a1ac3e

File tree

1 file changed

+23
-38
lines changed

1 file changed

+23
-38
lines changed

category_encoders/basen.py

Lines changed: 23 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -206,44 +206,29 @@ def inverse_transform(self, X_in):
206206
207207
"""
208208
# fail fast
209-
if self._dim is None:
210-
raise ValueError('Must train encoder before it can be used to inverse_transform data')
211-
212-
# unite the type into pandas dataframe. This makes the input size detection code easier
213-
# and make a deep copy
214-
X = util.convert_input(X_in, columns=self.feature_names_out_, deep=True)
215-
216-
X = self.basen_to_integer(X, self.cols, self.base)
217-
218-
# make sure that it is the right size
219-
if X.shape[1] != self._dim:
220-
if self.drop_invariant:
221-
raise ValueError(
222-
f'Unexpected input dimension {X.shape[1]}, the attribute drop_invariant should '
223-
'be False when transforming the data'
224-
)
225-
else:
226-
raise ValueError(f'Unexpected input dimension {X.shape[1]}, expected {self._dim}')
227-
228-
if not list(self.cols):
229-
return X if self.return_df else X.to_numpy()
230-
231-
for switch in self.ordinal_encoder.mapping:
232-
column_mapping = switch.get('mapping')
233-
inverse = pd.Series(data=column_mapping.index, index=column_mapping.array)
234-
X[switch.get('col')] = X[switch.get('col')].map(inverse).astype(switch.get('data_type'))
235-
236-
if self.handle_unknown == 'return_nan' and self.handle_missing == 'return_nan':
237-
for col in self.cols:
238-
if X[switch.get('col')].isna().any():
239-
warnings.warn(
240-
'inverse_transform is not supported because transform impute '
241-
f'the unknown category nan when encode {col}',
242-
stacklevel=4,
243-
)
244-
245-
return X if self.return_df else X.to_numpy()
246-
209+
import pandas as pd
210+
211+
if not isinstance(X, pd.DataFrame):
212+
raise ValueError("inverse_transform expects a pandas DataFrame as input.")
213+
214+
# NEW CHECK handle missing columns gracefully
215+
expected_cols = getattr(self, "feature_names_in_", None)
216+
if expected_cols is not None:
217+
missing_cols = [c for c in expected_cols if c not in X.columns]
218+
if missing_cols:
219+
raise ValueError(f"Missing columns during inverse_transform: {missing_cols}")
220+
221+
# Continue with existing dimension check
222+
if X.shape[1] != self._dim:
223+
raise ValueError(f"Unexpected input dimension {X.shape[1]}, expected {self._dim}")
224+
225+
# Continue with rest of the logic
226+
X = X.copy()
227+
for switch in self.mapping:
228+
col = switch.get("col")
229+
if col in X:
230+
X[col] = X[col].map(switch.get("inverse_mapping"))
231+
return X
247232
def calc_required_digits(self, values: list) -> int:
248233
"""Figure out how many digits we need to represent the classes present.
249234

0 commit comments

Comments
 (0)