Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 23 additions & 38 deletions category_encoders/basen.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,44 +206,29 @@ def inverse_transform(self, X_in):

"""
# fail fast
if self._dim is None:
raise ValueError('Must train encoder before it can be used to inverse_transform data')

# Convert the input to a pandas DataFrame so the input size check below is easier,
# and make a deep copy.
X = util.convert_input(X_in, columns=self.feature_names_out_, deep=True)

X = self.basen_to_integer(X, self.cols, self.base)

# make sure that it is the right size
if X.shape[1] != self._dim:
if self.drop_invariant:
raise ValueError(
f'Unexpected input dimension {X.shape[1]}, the attribute drop_invariant should '
'be False when transforming the data'
)
else:
raise ValueError(f'Unexpected input dimension {X.shape[1]}, expected {self._dim}')

if not list(self.cols):
return X if self.return_df else X.to_numpy()

for switch in self.ordinal_encoder.mapping:
column_mapping = switch.get('mapping')
inverse = pd.Series(data=column_mapping.index, index=column_mapping.array)
X[switch.get('col')] = X[switch.get('col')].map(inverse).astype(switch.get('data_type'))

if self.handle_unknown == 'return_nan' and self.handle_missing == 'return_nan':
for col in self.cols:
if X[switch.get('col')].isna().any():
warnings.warn(
'inverse_transform is not supported because transform impute '
f'the unknown category nan when encode {col}',
stacklevel=4,
)

return X if self.return_df else X.to_numpy()

import pandas as pd

if not isinstance(X, pd.DataFrame):
raise ValueError("inverse_transform expects a pandas DataFrame as input.")

# New check: raise a clear error if any expected input column is missing.
expected_cols = getattr(self, "feature_names_in_", None)
if expected_cols is not None:
missing_cols = [c for c in expected_cols if c not in X.columns]
if missing_cols:
raise ValueError(f"Missing columns during inverse_transform: {missing_cols}")

# Continue with existing dimension check
if X.shape[1] != self._dim:
raise ValueError(f"Unexpected input dimension {X.shape[1]}, expected {self._dim}")

# Continue with rest of the logic
X = X.copy()
for switch in self.mapping:
col = switch.get("col")
if col in X:
X[col] = X[col].map(switch.get("inverse_mapping"))
return X
def calc_required_digits(self, values: list) -> int:
"""Figure out how many digits we need to represent the classes present.

Expand Down
Loading