Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions skrub/_apply_to_frame.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import warnings

from scipy import sparse as sp
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.utils.validation import check_is_fitted

Expand Down Expand Up @@ -173,6 +176,11 @@ def fit_transform(self, X, y=None):
The transformed data.
"""
self.all_inputs_ = sbd.column_names(X)

# for sklearn
self.feature_names_in_ = self.all_inputs_
self.n_features_in_ = len(self.all_inputs_)

self._columns = selectors.make_selector(self.cols).expand(X)
to_transform = selectors.select(X, self._columns)
if self.keep_original:
Expand All @@ -184,6 +192,18 @@ def fit_transform(self, X, y=None):
self.transformer_ = clone(self.transformer)
_utils.set_output(self.transformer_, X)
transformed = self.transformer_.fit_transform(to_transform, y)
if sp.issparse(transformed):
warnings.warn(
f"The output of {type(self.transformer_).__name__!r} is a sparse"
" array or sparse matrix."
)
if sbd.column_names(to_transform) != sbd.column_names(X):
raise ValueError(
"When a transformer outputs a sparse array or sparse matrix,"
" all columns of the input dataframe must beselected. Got"
f" cols={self.cols!r} instead of skrub.selectors.all()."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Vincent, I have a naive question/suggestion about this raise: would it be a good idea to add a brief note to the docstring of the `cols` argument saying that all columns must be selected when the transformer returns a sparse object (even though selecting all columns is the default behaviour)?

)
return transformed
transformed = _utils.check_output(
self.transformer_, to_transform, transformed, allow_column_list=False
)
Expand All @@ -205,9 +225,6 @@ def fit_transform(self, X, y=None):
self.used_inputs_ = self._columns
self.created_outputs_ = self._transformed_output_names
self.all_outputs_ = passthrough_names + self._transformed_output_names
# for sklearn
self.feature_names_in_ = self.all_inputs_
self.n_features_in_ = len(self.all_inputs_)

result = sbd.copy_index(X, result)
return result
Expand Down
12 changes: 12 additions & 0 deletions skrub/tests/test_on_subframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
import pandas as pd
import pytest
from pandas.testing import assert_index_equal
from scipy import sparse as sp
from sklearn.base import BaseEstimator
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.preprocessing import FunctionTransformer

from skrub import SelectCols
from skrub import _dataframe as sbd
from skrub import selectors as s
from skrub._apply_to_frame import ApplyToFrame
from skrub.datasets import toy_orders


class Dummy(BaseEstimator):
Expand Down Expand Up @@ -116,3 +119,12 @@ def test_output_index(cols):
assert_index_equal(transformer.fit_transform(df).index, df.index)
df = pd.DataFrame({"a": [10, 20], "b": [1.1, 2.2]}, index=[-10, 20])
assert_index_equal(transformer.transform(df).index, df.index)


def test_sparse_output():
    """Check how ApplyToFrame handles a transformer that returns sparse data.

    Without a ``cols`` argument the sparse result of the wrapped
    ``HashingVectorizer`` is returned as-is; restricting ``cols`` to a single
    column of the same frame must raise a ``ValueError``.
    """
    orders = toy_orders().X

    # No column restriction: the output is expected to stay sparse.
    sparse_result = ApplyToFrame(HashingVectorizer()).fit_transform(orders)
    assert sp.issparse(sparse_result)

    # Only the first column selected: fitting is expected to be rejected.
    first_column = [orders.columns[0]]
    with pytest.raises(ValueError):
        ApplyToFrame(HashingVectorizer(), cols=first_column).fit_transform(orders)
Loading