Skip to content

Commit 67867ac

Browse files
committed
Made the prompt source optionally resilient to downstream errors
1 parent 4476a97 commit 67867ac

File tree

3 files changed

+163
-10
lines changed

3 files changed

+163
-10
lines changed

CHANGELOG.md

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,25 @@
44
### Improvements
55
- Changes several logging statements to make the consistent with the rest of the code
66
- Refactored **lambda** to use AbstractFieldSegment for consistency
7-
- Removed fail_on_error parameter in **lambda** and **lambdaFilter**. This is an API breaking change.
7+
- Removed fail_on_error parameter in **lambda** and **lambdaFilter**. This is an API breaking change.
88
The segments will now always fail on an error and there is no option to silently fail.
99
- Changed chatterlang_serve so that it raises and exception an exits if given a script that can't be
1010
compiled. The previous behavior was that it would issue a log message and fall back to the default
1111
script.
12-
- Added unit test for parse_unknown_args and expanded it to support boolean flag parameters
12+
- Added unit test for parse_unknown_args and expanded it to support boolean flag parameters
1313
- Marked UMAP as deprecated. It will be removed in 1.0. It requires additional dependencies that no
14-
other core modules need.
14+
other core modules need.
1515
- Updated the ollama embedding and chat connector so that it uses the OLLAMA_SERVER_URL configuration
1616
variable as the host where ollama is installed. So OLLAMA_SERVER_URL can be set in the TOML configuration
1717
file or TALKPIPE_OLLAMA_SERVER_URL can be set as an environment variable.
1818
- Implemented better compile errors for when sources or segments are not found. It had been a key error.
1919
It will now be a compile error.
20+
- Added batch files for the tutorials so it is easier to run them on Windows
21+
- Enhanced **prompt** source to catch downstream pipeline errors and continue prompting instead of
22+
crashing. When a user enters invalid input that causes an exception in downstream segments, the error
23+
is displayed with a full traceback, and the prompt continues accepting input. This makes interactive
24+
pipelines more robust and user-friendly. The error-resilient behavior is enabled by default but can
25+
be disabled by setting `error_resilient=False`.
2026

2127
## 0.9.3
2228
### Improvements

src/talkpipe/pipe/io.py

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,61 @@
55
import os
66
import pickle # nosec B403 - Used only for write operations, not loading untrusted data
77
import json
8+
import traceback
89
from pprint import pformat
910
from prompt_toolkit import PromptSession
1011

1112
from talkpipe.chatterlang.registry import register_source, register_segment
1213
import talkpipe.chatterlang.registry as registry
13-
from talkpipe.pipe.core import AbstractSource, source, AbstractSegment, segment
14+
from talkpipe.pipe.core import AbstractSource, source, AbstractSegment, segment, Pipeline
1415
from talkpipe.util import data_manipulation
1516

1617

18+
class ErrorResilientPromptPipeline(Pipeline):
19+
"""A special pipeline that handles errors in interactive prompt workflows.
20+
21+
When a downstream segment raises an exception, this pipeline catches it,
22+
displays the error, and continues processing the next item from the prompt.
23+
"""
24+
25+
def __init__(self, prompt_source, *operations):
26+
super().__init__(prompt_source, *operations)
27+
self.prompt_source = prompt_source
28+
29+
def transform(self, input_iter=None):
30+
"""Execute pipeline with error resilience for prompt-based workflows."""
31+
# Get the prompt generator
32+
prompt_iter = self.prompt_source()
33+
34+
# Build the downstream pipeline (everything after the prompt)
35+
downstream_ops = [op for op in self.operations if op is not self.prompt_source]
36+
37+
# Process each prompt input with error handling
38+
for user_input in prompt_iter:
39+
try:
40+
# Create a single-item iterator for this input
41+
current_iter = iter([user_input])
42+
43+
# Pass through each downstream operation
44+
for op in downstream_ops:
45+
current_iter = op(current_iter)
46+
47+
# Consume and yield results
48+
for result in current_iter:
49+
yield result
50+
51+
except Exception as e:
52+
# Catch and display errors, but continue prompting
53+
print(f"Error: {e}")
54+
traceback.print_exc()
55+
# Don't yield anything for this failed input, just continue to next prompt
56+
57+
def __or__(self, other):
58+
"""Support chaining additional operations to the error-resilient pipeline."""
59+
# Add the new operation to our operations list
60+
return ErrorResilientPromptPipeline(self.prompt_source, *self.operations[1:], other)
61+
62+
1763
@registry.register_segment(name="print")
1864
class Print(AbstractSegment):
1965
"""
@@ -73,22 +119,44 @@ def transform(self, input_iter: Annotated[Iterable[int], "Iterable input data"])
73119
@register_source('prompt')
74120
class Prompt(AbstractSource):
75121
"""A source that generates input from a prompt.
76-
122+
77123
This source will generate input from a prompt until the user enters an EOF.
78124
It is for creating interactive pipelines. It uses prompt_toolkit under the
79125
hood to provide a nice prompt experience.
126+
127+
To enable error recovery (continue prompting after downstream errors), this source
128+
needs to actively consume the downstream pipeline with error handling. This is done
129+
by overriding the __or__ method to wrap the downstream in error handling.
80130
"""
81131

82-
def __init__(self):
132+
def __init__(self, error_resilient: Annotated[bool, "If True, catches downstream errors and continues prompting"] = True):
83133
super().__init__()
84134
self.session = PromptSession()
135+
self.error_resilient = error_resilient
85136

86137
def generate(self) -> Iterable[str]:
87138
while True:
88139
try:
89-
yield self.session.prompt('> ')
140+
user_input = self.session.prompt('> ')
141+
yield user_input
90142
except EOFError:
91143
break
144+
except KeyboardInterrupt:
145+
print("\nInterrupted. Press Ctrl+D to exit or continue entering input.")
146+
continue
147+
148+
def __or__(self, other):
149+
"""Override to add error handling when chaining with other segments."""
150+
if self.error_resilient:
151+
# Register the downstream relationship
152+
self.registerDownstream(other)
153+
other.registerUpstream(self)
154+
155+
# Return an error-resilient pipeline instead of a regular one
156+
return ErrorResilientPromptPipeline(self, other)
157+
else:
158+
# Use default behavior
159+
return super().__or__(other)
92160

93161
@register_source('echo')
94162
@source(delimiter=',')

tests/talkpipe/pipe/test_io.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json, pickle, os
2+
from unittest.mock import Mock, patch
23
from talkpipe.pipe import io
34
from talkpipe.util import config
45
from talkpipe.chatterlang import compiler
@@ -326,13 +327,91 @@ def test_writeString_compiler_integration(tmpdir):
326327
def test_writeString_single_item(tmpdir):
327328
"""Test writeString with single item."""
328329
data = ["single_item"]
329-
330+
330331
temp_file_path = tmpdir.join("test_single.txt")
331332
f = io.writeString(fname=str(temp_file_path))
332-
333+
333334
result = list(f(data))
334335
assert result == data
335-
336+
336337
with open(str(temp_file_path), 'r') as file:
337338
content = file.read()
338339
assert content == "single_item\n"
340+
341+
def test_prompt_error_handling_continues_after_exception(capsys):
342+
"""Test that Prompt with error-resilient pipeline catches downstream exceptions and continues prompting."""
343+
# Mock the PromptSession to return specific values
344+
mock_session = Mock()
345+
mock_session.prompt.side_effect = [
346+
"1", # Valid input
347+
"abc", # Invalid input that will cause error
348+
"2", # Valid input after error
349+
EOFError() # Simulate user exiting
350+
]
351+
352+
# Create a lambda segment that will fail on non-numeric input
353+
from talkpipe.pipe.basic import EvalExpression
354+
355+
with patch('talkpipe.pipe.io.PromptSession', return_value=mock_session):
356+
prompt = io.Prompt()
357+
# Chain with a lambda that converts to int (will fail on "abc")
358+
pipeline = prompt | EvalExpression(expression="int(item)")
359+
360+
# Convert to function and execute
361+
func = pipeline.as_function()
362+
results = func()
363+
364+
# Should get results for valid inputs only
365+
assert 1 in results
366+
assert 2 in results
367+
# "abc" should have caused an error but not stopped the pipeline
368+
369+
# Verify error was printed
370+
captured = capsys.readouterr()
371+
assert "Error:" in captured.out or "invalid literal" in captured.out.lower()
372+
373+
def test_prompt_keyboard_interrupt_continues(capsys):
374+
"""Test that Prompt handles KeyboardInterrupt and continues prompting."""
375+
mock_session = Mock()
376+
mock_session.prompt.side_effect = [
377+
"first input",
378+
KeyboardInterrupt(),
379+
"second input",
380+
EOFError()
381+
]
382+
383+
with patch('talkpipe.pipe.io.PromptSession', return_value=mock_session):
384+
prompt = io.Prompt()
385+
generator = prompt.generate()
386+
387+
# First input should work normally
388+
assert next(generator) == "first input"
389+
390+
# After KeyboardInterrupt, generator should continue
391+
assert next(generator) == "second input"
392+
393+
# Verify message was printed
394+
captured = capsys.readouterr()
395+
assert "Interrupted" in captured.out
396+
397+
def test_prompt_eof_terminates():
398+
"""Test that Prompt terminates on EOFError."""
399+
mock_session = Mock()
400+
mock_session.prompt.side_effect = [
401+
"first input",
402+
EOFError()
403+
]
404+
405+
with patch('talkpipe.pipe.io.PromptSession', return_value=mock_session):
406+
prompt = io.Prompt()
407+
generator = prompt.generate()
408+
409+
# First input should work
410+
assert next(generator) == "first input"
411+
412+
# EOFError should cause generator to stop
413+
try:
414+
next(generator)
415+
assert False, "Generator should have stopped on EOFError"
416+
except StopIteration:
417+
pass # Expected behavior

0 commit comments

Comments
 (0)