Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ New features

Changes
-------
- The :class:`StringEncoder` now exposes the ``vocabulary`` parameter from the parent
:class:`TfidfVectorizer`.
:pr:`1819` by :user:`Eloi Massoulié <emassoulie>`


Bugfixes
--------
Expand Down
39 changes: 27 additions & 12 deletions skrub/_string_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ class StringEncoder(TransformerMixin, SingleColumnTransformer):
Used during randomized svd. Pass an int for reproducible results across
multiple function calls.

vocabulary : Mapping or iterable, default=None
In case of "tfidf" vectorizer, the vocabulary mapping passed to the vectorizer.
Either a Mapping (e.g., a dict) where keys are terms and values are
indices in the feature matrix, or an iterable over terms.

Attributes
----------
input_name_ : str
Expand Down Expand Up @@ -131,13 +136,15 @@ def __init__(
analyzer="char_wb",
stop_words=None,
random_state=None,
vocabulary=None,
):
self.n_components = n_components
self.vectorizer = vectorizer
self.ngram_range = ngram_range
self.analyzer = analyzer
self.stop_words = stop_words
self.random_state = random_state
self.vocabulary = vocabulary

def fit_transform(self, X, y=None):
"""Fit the encoder and transform a column.
Expand Down Expand Up @@ -165,21 +172,29 @@ def fit_transform(self, X, y=None):
ngram_range=self.ngram_range,
analyzer=self.analyzer,
stop_words=self.stop_words,
vocabulary=self.vocabulary,
)
elif self.vectorizer == "hashing":
self.vectorizer_ = Pipeline(
[
(
"hashing",
HashingVectorizer(
ngram_range=self.ngram_range,
analyzer=self.analyzer,
stop_words=self.stop_words,
if self.vocabulary is not None:
raise ValueError(
"Custom vocabulary passed to StringEncoder, unsupported by"
"HashingVectorizer. Rerun without a 'vocabulary' parameter."
)
else:
self.vectorizer_ = Pipeline(
[
(
"hashing",
HashingVectorizer(
ngram_range=self.ngram_range,
analyzer=self.analyzer,
stop_words=self.stop_words,
),
),
),
("tfidf", TfidfTransformer()),
]
)
("tfidf", TfidfTransformer()),
]
)

else:
raise ValueError(
f"Unknown vectorizer {self.vectorizer}. Options are 'tfidf' or"
Expand Down
46 changes: 46 additions & 0 deletions skrub/tests/test_string_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,49 @@ def test_zero_padding_in_feature_names_out(df_module, n_components, expected_col
feature_names = encoder.get_feature_names_out()

assert feature_names[: len(expected_columns)] == expected_columns


def test_vocabulary_parameter(df_module):
voc = {
"this": 5,
"is": 1,
"simple": 3,
"example": 0,
"this is": 6,
"is simple": 2,
"simple example": 4,
}
encoder = StringEncoder(n_components=2, vocabulary=voc)
pipeline = Pipeline(
[
(
"tfidf",
TfidfVectorizer(ngram_range=(3, 4), analyzer="char_wb", vocabulary=voc),
),
("tsvd", TruncatedSVD()),
]
)
X = df_module.make_column(
"col",
["this is a sentence", "this simple example is simple", "other words", ""],
)

enc_out = encoder.fit_transform(X)
pipe_out = pipeline.fit_transform(X)
pipe_out /= scaling_factor(pipe_out)

assert encoder.vectorizer_.vocabulary_ == voc
assert_almost_equal(enc_out, pipe_out)


def test_vocabulary_on_hashing_vectorizer(df_module):
voc = {
"this": 5,
}
encoder = StringEncoder(vocabulary=voc, vectorizer="hashing")
with pytest.raises(ValueError, match="Custom vocabulary passed to StringEncoder*"):
X = df_module.make_column(
"col",
["this is a sentence", "this simple example is simple", "other words", ""],
)
encoder.fit_transform(X)