|
5 | 5 | import os |
6 | 6 | import pickle # nosec B403 - Used only for write operations, not loading untrusted data |
7 | 7 | import json |
| 8 | +import traceback |
8 | 9 | from pprint import pformat |
9 | 10 | from prompt_toolkit import PromptSession |
10 | 11 |
|
11 | 12 | from talkpipe.chatterlang.registry import register_source, register_segment |
12 | 13 | import talkpipe.chatterlang.registry as registry |
13 | | -from talkpipe.pipe.core import AbstractSource, source, AbstractSegment, segment |
| 14 | +from talkpipe.pipe.core import AbstractSource, source, AbstractSegment, segment, Pipeline |
14 | 15 | from talkpipe.util import data_manipulation |
15 | 16 |
|
16 | 17 |
|
| 18 | +class ErrorResilientPromptPipeline(Pipeline): |
| 19 | + """A special pipeline that handles errors in interactive prompt workflows. |
| 20 | +
|
| 21 | + When a downstream segment raises an exception, this pipeline catches it, |
| 22 | + displays the error, and continues processing the next item from the prompt. |
| 23 | + """ |
| 24 | + |
| 25 | + def __init__(self, prompt_source, *operations): |
| 26 | + super().__init__(prompt_source, *operations) |
| 27 | + self.prompt_source = prompt_source |
| 28 | + |
| 29 | + def transform(self, input_iter=None): |
| 30 | + """Execute pipeline with error resilience for prompt-based workflows.""" |
| 31 | + # Get the prompt generator |
| 32 | + prompt_iter = self.prompt_source() |
| 33 | + |
| 34 | + # Build the downstream pipeline (everything after the prompt) |
| 35 | + downstream_ops = [op for op in self.operations if op is not self.prompt_source] |
| 36 | + |
| 37 | + # Process each prompt input with error handling |
| 38 | + for user_input in prompt_iter: |
| 39 | + try: |
| 40 | + # Create a single-item iterator for this input |
| 41 | + current_iter = iter([user_input]) |
| 42 | + |
| 43 | + # Pass through each downstream operation |
| 44 | + for op in downstream_ops: |
| 45 | + current_iter = op(current_iter) |
| 46 | + |
| 47 | + # Consume and yield results |
| 48 | + for result in current_iter: |
| 49 | + yield result |
| 50 | + |
| 51 | + except Exception as e: |
| 52 | + # Catch and display errors, but continue prompting |
| 53 | + print(f"Error: {e}") |
| 54 | + traceback.print_exc() |
| 55 | + # Don't yield anything for this failed input, just continue to next prompt |
| 56 | + |
| 57 | + def __or__(self, other): |
| 58 | + """Support chaining additional operations to the error-resilient pipeline.""" |
| 59 | + # Add the new operation to our operations list |
| 60 | + return ErrorResilientPromptPipeline(self.prompt_source, *self.operations[1:], other) |
| 61 | + |
| 62 | + |
17 | 63 | @registry.register_segment(name="print") |
18 | 64 | class Print(AbstractSegment): |
19 | 65 | """ |
@@ -73,22 +119,44 @@ def transform(self, input_iter: Annotated[Iterable[int], "Iterable input data"]) |
73 | 119 | @register_source('prompt') |
74 | 120 | class Prompt(AbstractSource): |
75 | 121 | """A source that generates input from a prompt. |
76 | | - |
| 122 | +
|
77 | 123 | This source will generate input from a prompt until the user enters an EOF. |
78 | 124 | It is for creating interactive pipelines. It uses prompt_toolkit under the |
79 | 125 | hood to provide a nice prompt experience. |
| 126 | +
|
| 127 | + To enable error recovery (continue prompting after downstream errors), this source |
| 128 | + needs to actively consume the downstream pipeline with error handling. This is done |
| 129 | + by overriding the __or__ method to wrap the downstream in error handling. |
80 | 130 | """ |
81 | 131 |
|
82 | | - def __init__(self): |
| 132 | + def __init__(self, error_resilient: Annotated[bool, "If True, catches downstream errors and continues prompting"] = True): |
83 | 133 | super().__init__() |
84 | 134 | self.session = PromptSession() |
| 135 | + self.error_resilient = error_resilient |
85 | 136 |
|
86 | 137 | def generate(self) -> Iterable[str]: |
87 | 138 | while True: |
88 | 139 | try: |
89 | | - yield self.session.prompt('> ') |
| 140 | + user_input = self.session.prompt('> ') |
| 141 | + yield user_input |
90 | 142 | except EOFError: |
91 | 143 | break |
| 144 | + except KeyboardInterrupt: |
| 145 | + print("\nInterrupted. Press Ctrl+D to exit or continue entering input.") |
| 146 | + continue |
| 147 | + |
| 148 | + def __or__(self, other): |
| 149 | + """Override to add error handling when chaining with other segments.""" |
| 150 | + if self.error_resilient: |
| 151 | + # Register the downstream relationship |
| 152 | + self.registerDownstream(other) |
| 153 | + other.registerUpstream(self) |
| 154 | + |
| 155 | + # Return an error-resilient pipeline instead of a regular one |
| 156 | + return ErrorResilientPromptPipeline(self, other) |
| 157 | + else: |
| 158 | + # Use default behavior |
| 159 | + return super().__or__(other) |
92 | 160 |
|
93 | 161 | @register_source('echo') |
94 | 162 | @source(delimiter=',') |
|
0 commit comments