diff --git a/CHANGELOG.md b/CHANGELOG.md index 2cc3a11..78edc51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## Unreleased + +- Add `replace`, `__iter__`, and `__getitem__` to pipeline + ## 0.8.0 (2025-03-08) - Drop support for Python 3.7 and 3.8 @@ -139,4 +143,4 @@ meta-data whitelisting. ## 0.1.1a1 (2018-03-01) -- Initial release \ No newline at end of file +- Initial release diff --git a/docs/customisation.md b/docs/customisation.md index 6dfe318..29971d0 100644 --- a/docs/customisation.md +++ b/docs/customisation.md @@ -76,9 +76,7 @@ documents = [...] builder = get_default_builder("fr") for funcname in "stopWordFilter-fr", "stemmer-fr": - builder.pipeline.skip( - builder.pipeline.registered_functions[funcname], ["titre"] - ) + builder.pipeline.skip(builder.pipeline[funcname], ["titre"]) idx = lunr(ref="id", fields=("titre", "texte"), documents=documents, builder=builder) ``` @@ -86,11 +84,10 @@ idx = lunr(ref="id", fields=("titre", "texte"), documents=documents, builder=bui The current language support registers the functions `lunr-multi-trimmer-{lang}`, `stopWordFilter-{lang}` and `stemmer-{lang}` but these are by convention only. You can access the -full list through the `registered_functions` attribute of the -pipeline, but this is not necessarily the list of actual pipeline -steps, which is contained in a private field (though you can see them -in the string representation of the pipeline). - +full list (for all languages) through the `registered_functions` +attribute of the pipeline, but this is not necessarily the list of +steps for a given pipeline. The list of names of registered functions +in a pipeline can be obtained by iterating over it or converting it to a `list`. ## Token meta-data diff --git a/lunr/pipeline.py b/lunr/pipeline.py index 0e0ab4e..98bdc4d 100644 --- a/lunr/pipeline.py +++ b/lunr/pipeline.py @@ -21,12 +21,25 @@ def __init__(self): self._skip: Dict[Callable, Set[str]] = defaultdict(set) def __len__(self): + # Note this does not necessarily match len(list(self._stack)) + # because unregistered functions are not iterated over... return len(self._stack) def __repr__(self): return ''.format(",".join(fn.label for fn in self._stack)) - # TODO: add iterator methods? + def __getitem__(self, label: str) -> Callable: + for fn in self._stack: + if hasattr(fn, "label") and fn.label == label: + return fn + raise BaseLunrException( + f"Cannot find registered function {label} in pipeline" + ) from KeyError + + def __iter__(self): + for fn in self._stack: + if hasattr(fn, "label"): + yield fn.label @classmethod def register_function(cls, fn, label=None): @@ -106,6 +119,14 @@ def remove(self, fn): except ValueError: pass + def replace(self, existing_fn, new_fn): + """Replaces a function in the pipeline with a better one.""" + try: + index = self._stack.index(existing_fn) + self._stack[index] = new_fn + except ValueError as e: + raise BaseLunrException("Cannot find existing_fn") from e + def skip(self, fn: Callable, field_names: List[str]): """ Make the pipeline skip the function based on field name we're processing. diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index e248961..9403ef4 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -275,3 +275,44 @@ def test_reset_empties_the_stack(self): self.pipeline.reset() assert len(self.pipeline) == 0 + + +class TestAccess(BaseTestPipeline): + def test_access_function_in_pipeline(self): + Pipeline.register_function(fn, "fn") + self.pipeline.add(fn) + assert self.pipeline["fn"] == fn + + def test_access_function_not_in_pipeline(self): + with pytest.raises(BaseLunrException): + _ = self.pipeline["fn"] + + +class TestReplace(BaseTestPipeline): + def test_replace_function_in_pipeline(self): + Pipeline.register_function(fn, "fn") + self.pipeline.add(noop) + assert len(self.pipeline) == 1 + with pytest.raises(BaseLunrException): + _ = self.pipeline["fn"] + + self.pipeline.replace(noop, fn) + assert len(self.pipeline) == 1 + assert self.pipeline["fn"] == fn + + +class TestIterate: + def test_iterate(self): + def fn1(t, *args): + return t + + def fn2(t, *args): + return t + + pipeline = Pipeline() + pipeline.register_function(fn1) + pipeline.register_function(fn2, "foo") + pipeline.add(fn1, fn2) + assert list(pipeline) == ["fn1", "foo"] + assert "fn1" in pipeline + assert "fn2" not in pipeline