Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## Unreleased

- Add `replace`, `__iter__`, and `__getitem__` methods to `Pipeline`

## 0.8.0 (2025-03-08)

- Drop support for Python 3.7 and 3.8
Expand Down Expand Up @@ -139,4 +143,4 @@ meta-data whitelisting.

## 0.1.1a1 (2018-03-01)

- Initial release
- Initial release
13 changes: 5 additions & 8 deletions docs/customisation.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,21 +76,18 @@ documents = [...]
builder = get_default_builder("fr")

for funcname in "stopWordFilter-fr", "stemmer-fr":
builder.pipeline.skip(
builder.pipeline.registered_functions[funcname], ["titre"]
)
builder.pipeline.skip(builder.pipeline[funcname], ["titre"])

idx = lunr(ref="id", fields=("titre", "texte"), documents=documents, builder=builder)
```

The current language support registers the functions
`lunr-multi-trimmer-{lang}`, `stopWordFilter-{lang}` and
`stemmer-{lang}` but these are by convention only. You can access the
full list through the `registered_functions` attribute of the
pipeline, but this is not necessarily the list of actual pipeline
steps, which is contained in a private field (though you can see them
in the string representation of the pipeline).

full list (for all languages) through the `registered_functions`
attribute of the pipeline, but this is not necessarily the list of
steps in a given pipeline. To get the labels of the functions that are
actually in a pipeline, iterate over it or convert it to a `list`.

## Token meta-data

Expand Down
23 changes: 22 additions & 1 deletion lunr/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,25 @@ def __init__(self):
self._skip: Dict[Callable, Set[str]] = defaultdict(set)

def __len__(self):
    """Return the number of functions currently on the pipeline stack.

    Every function on the stack is counted, so the result can be larger
    than ``len(list(self))``: iterating over the pipeline only yields
    functions that carry a ``label`` attribute.
    """
    stack = self._stack
    return len(stack)

def __repr__(self):
    """Return a debug representation listing the functions on the stack.

    Functions are shown by their ``label``. A function without a
    ``label`` attribute is shown by its ``__name__`` (or its own ``repr``
    if it has no name), so that ``repr()`` never raises for a pipeline
    that contains such a function.
    """
    names = (
        fn.label if hasattr(fn, "label") else getattr(fn, "__name__", repr(fn))
        for fn in self._stack
    )
    return '<Pipeline stack="{}">'.format(",".join(names))

# TODO: add iterator methods?
def __getitem__(self, label: str) -> Callable:
    """Return the function on the pipeline stack whose label is ``label``.

    Functions on the stack that have no ``label`` attribute are ignored.
    If several functions share the label, the first one on the stack is
    returned.

    Raises:
        BaseLunrException: if no function on the stack carries ``label``.
            The exception is chained from a ``KeyError`` holding the
            missing label.
    """
    for fn in self._stack:
        if hasattr(fn, "label") and fn.label == label:
            return fn
    # Chain from a KeyError *instance* naming the missing label so that the
    # cause shown in the traceback is informative rather than an empty
    # KeyError created from the bare class.
    raise BaseLunrException(
        f"Cannot find registered function {label} in pipeline"
    ) from KeyError(label)

def __iter__(self):
    """Yield the label of each function on the stack, in stack order.

    Functions without a ``label`` attribute are silently skipped.
    """
    yield from (step.label for step in self._stack if hasattr(step, "label"))

@classmethod
def register_function(cls, fn, label=None):
Expand Down Expand Up @@ -106,6 +119,14 @@ def remove(self, fn):
except ValueError:
pass

def replace(self, existing_fn, new_fn):
    """Swap ``existing_fn`` for ``new_fn`` at the same position in the stack.

    Only the first occurrence of ``existing_fn`` is replaced.

    Raises:
        BaseLunrException: if ``existing_fn`` is not on the stack.
    """
    try:
        position = self._stack.index(existing_fn)
    except ValueError as e:
        raise BaseLunrException("Cannot find existing_fn") from e
    else:
        self._stack[position] = new_fn

def skip(self, fn: Callable, field_names: List[str]):
"""
Make the pipeline skip the function based on field name we're processing.
Expand Down
41 changes: 41 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,44 @@ def test_reset_empties_the_stack(self):

self.pipeline.reset()
assert len(self.pipeline) == 0


class TestAccess(BaseTestPipeline):
    """Label-based lookup of pipeline functions via ``pipeline[label]``."""

    def test_access_function_in_pipeline(self):
        # A registered function that was added to the pipeline can be
        # retrieved by its label.
        Pipeline.register_function(fn, "fn")
        self.pipeline.add(fn)
        found = self.pipeline["fn"]
        assert found == fn

    def test_access_function_not_in_pipeline(self):
        # This test adds nothing labelled "fn", so the lookup must fail.
        with pytest.raises(BaseLunrException):
            _ = self.pipeline["fn"]


class TestReplace(BaseTestPipeline):
    """Behaviour of ``Pipeline.replace``."""

    def test_replace_function_in_pipeline(self):
        Pipeline.register_function(fn, "fn")
        pipeline = self.pipeline

        # Before the swap: only ``noop`` is on the stack, so "fn" is absent.
        pipeline.add(noop)
        assert len(pipeline) == 1
        with pytest.raises(BaseLunrException):
            _ = pipeline["fn"]

        # After the swap: same length, and "fn" now resolves to ``fn``.
        pipeline.replace(noop, fn)
        assert len(pipeline) == 1
        assert pipeline["fn"] == fn


class TestIterate:
    """Iteration and membership tests over a pipeline's function labels."""

    def test_iterate(self):
        def fn1(t, *args):
            return t

        def fn2(t, *args):
            return t

        # fn1 is registered without an explicit label; fn2 is registered
        # under "foo", so its own name is expected not to appear.
        Pipeline.register_function(fn1)
        Pipeline.register_function(fn2, "foo")

        pipeline = Pipeline()
        pipeline.add(fn1, fn2)

        labels = list(pipeline)
        assert labels == ["fn1", "foo"]
        assert "fn1" in pipeline
        assert "fn2" not in pipeline