Skip to content

Commit 29125f7

Browse files
committed
feat: workflow step interceptor
1 parent ae2ffbb commit 29125f7

5 files changed

Lines changed: 320 additions & 5 deletions

File tree

src/memu/app/service.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
LLMInterceptorHandle,
3131
LLMInterceptorRegistry,
3232
)
33+
from memu.workflow.interceptor import WorkflowInterceptorHandle, WorkflowInterceptorRegistry
3334
from memu.workflow.pipeline import PipelineManager
3435
from memu.workflow.runner import WorkflowRunner, resolve_workflow_runner
3536
from memu.workflow.step import WorkflowState, WorkflowStep
@@ -83,6 +84,7 @@ def __init__(
8384
# Initialize client caches (lazy creation on first use)
8485
self._llm_clients: dict[str, Any] = {}
8586
self._llm_interceptors = LLMInterceptorRegistry()
87+
self._workflow_interceptors = WorkflowInterceptorRegistry()
8688

8789
self._workflow_runner = resolve_workflow_runner(workflow_runner)
8890

@@ -240,6 +242,45 @@ def intercept_on_error_llm_call(
240242
) -> LLMInterceptorHandle:
241243
return self._llm_interceptors.register_on_error(fn, name=name, priority=priority, where=where)
242244

245+
def intercept_before_workflow_step(
246+
self,
247+
fn: Callable[..., Any],
248+
*,
249+
name: str | None = None,
250+
) -> WorkflowInterceptorHandle:
251+
"""
252+
Register an interceptor to be called before each workflow step.
253+
254+
The interceptor receives (step_context: WorkflowStepContext, state: WorkflowState).
255+
"""
256+
return self._workflow_interceptors.register_before(fn, name=name)
257+
258+
def intercept_after_workflow_step(
259+
self,
260+
fn: Callable[..., Any],
261+
*,
262+
name: str | None = None,
263+
) -> WorkflowInterceptorHandle:
264+
"""
265+
Register an interceptor to be called after each workflow step.
266+
267+
The interceptor receives (step_context: WorkflowStepContext, state: WorkflowState).
268+
"""
269+
return self._workflow_interceptors.register_after(fn, name=name)
270+
271+
def intercept_on_error_workflow_step(
272+
self,
273+
fn: Callable[..., Any],
274+
*,
275+
name: str | None = None,
276+
) -> WorkflowInterceptorHandle:
277+
"""
278+
Register an interceptor to be called when a workflow step raises an exception.
279+
280+
The interceptor receives (step_context: WorkflowStepContext, state: WorkflowState, error: Exception).
281+
"""
282+
return self._workflow_interceptors.register_on_error(fn, name=name)
283+
243284
def _get_context(self) -> Context:
244285
return self._context
245286

@@ -292,7 +333,13 @@ async def _run_workflow(self, workflow_name: str, initial_state: WorkflowState)
292333
"""Execute a workflow through the configured runner backend."""
293334
steps = self._pipelines.build(workflow_name)
294335
runner_context = {"workflow_name": workflow_name}
295-
return await self._workflow_runner.run(workflow_name, steps, initial_state, runner_context)
336+
return await self._workflow_runner.run(
337+
workflow_name,
338+
steps,
339+
initial_state,
340+
runner_context,
341+
interceptor_registry=self._workflow_interceptors,
342+
)
296343

297344
@staticmethod
298345
def _extract_json_blob(raw: str) -> str:

src/memu/workflow/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
from memu.workflow.interceptor import (
2+
WorkflowInterceptorHandle,
3+
WorkflowInterceptorRegistry,
4+
WorkflowStepContext,
5+
)
16
from memu.workflow.pipeline import PipelineManager, PipelineRevision
27
from memu.workflow.runner import (
38
LocalWorkflowRunner,
@@ -12,9 +17,12 @@
1217
"PipelineManager",
1318
"PipelineRevision",
1419
"WorkflowContext",
20+
"WorkflowInterceptorHandle",
21+
"WorkflowInterceptorRegistry",
1522
"WorkflowRunner",
1623
"WorkflowState",
1724
"WorkflowStep",
25+
"WorkflowStepContext",
1826
"register_workflow_runner",
1927
"resolve_workflow_runner",
2028
"run_steps",

src/memu/workflow/interceptor.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
from __future__ import annotations
2+
3+
import inspect
4+
import logging
5+
import threading
6+
from collections.abc import Callable
7+
from dataclasses import dataclass
8+
from typing import TYPE_CHECKING, Any
9+
10+
if TYPE_CHECKING:
11+
from memu.workflow.step import WorkflowState
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
@dataclass(frozen=True)
17+
class WorkflowStepContext:
18+
"""Context information for a workflow step execution."""
19+
20+
workflow_name: str
21+
step_id: str
22+
step_role: str
23+
step_context: dict[str, Any]
24+
25+
26+
@dataclass(frozen=True)
27+
class _WorkflowInterceptor:
28+
interceptor_id: int
29+
fn: Callable[..., Any]
30+
name: str | None
31+
32+
33+
@dataclass(frozen=True)
34+
class _WorkflowInterceptorSnapshot:
35+
before: tuple[_WorkflowInterceptor, ...]
36+
after: tuple[_WorkflowInterceptor, ...]
37+
on_error: tuple[_WorkflowInterceptor, ...]
38+
39+
40+
class WorkflowInterceptorHandle:
41+
"""Handle for disposing a registered workflow interceptor."""
42+
43+
def __init__(self, registry: WorkflowInterceptorRegistry, interceptor_id: int) -> None:
44+
self._registry = registry
45+
self._interceptor_id = interceptor_id
46+
self._disposed = False
47+
48+
def dispose(self) -> bool:
49+
"""Remove the interceptor from the registry. Returns True if removed."""
50+
if self._disposed:
51+
return False
52+
self._disposed = True
53+
return self._registry.remove(self._interceptor_id)
54+
55+
56+
class WorkflowInterceptorRegistry:
57+
"""
58+
Registry for workflow step interceptors.
59+
60+
Interceptors are called before and after each workflow step execution.
61+
Unlike LLM interceptors, workflow interceptors do not support filtering,
62+
priority, or ordering - they are called in registration order.
63+
"""
64+
65+
def __init__(self, *, strict: bool = False) -> None:
66+
self._before: tuple[_WorkflowInterceptor, ...] = ()
67+
self._after: tuple[_WorkflowInterceptor, ...] = ()
68+
self._on_error: tuple[_WorkflowInterceptor, ...] = ()
69+
self._lock = threading.Lock()
70+
self._seq = 0
71+
self._strict = strict
72+
73+
@property
74+
def strict(self) -> bool:
75+
"""If True, interceptor exceptions will propagate instead of being logged."""
76+
return self._strict
77+
78+
def register_before(
79+
self,
80+
fn: Callable[..., Any],
81+
*,
82+
name: str | None = None,
83+
) -> WorkflowInterceptorHandle:
84+
"""
85+
Register an interceptor to be called before each step.
86+
87+
The interceptor receives (step_context: WorkflowStepContext, state: WorkflowState).
88+
"""
89+
return self._register("before", fn, name=name)
90+
91+
def register_after(
92+
self,
93+
fn: Callable[..., Any],
94+
*,
95+
name: str | None = None,
96+
) -> WorkflowInterceptorHandle:
97+
"""
98+
Register an interceptor to be called after each step.
99+
100+
The interceptor receives (step_context: WorkflowStepContext, state: WorkflowState).
101+
"""
102+
return self._register("after", fn, name=name)
103+
104+
def register_on_error(
105+
self,
106+
fn: Callable[..., Any],
107+
*,
108+
name: str | None = None,
109+
) -> WorkflowInterceptorHandle:
110+
"""
111+
Register an interceptor to be called when a step raises an exception.
112+
113+
The interceptor receives (step_context: WorkflowStepContext, state: WorkflowState, error: Exception).
114+
"""
115+
return self._register("on_error", fn, name=name)
116+
117+
def _register(
118+
self,
119+
kind: str,
120+
fn: Callable[..., Any],
121+
*,
122+
name: str | None,
123+
) -> WorkflowInterceptorHandle:
124+
if not callable(fn):
125+
msg = "Interceptor must be callable"
126+
raise TypeError(msg)
127+
with self._lock:
128+
self._seq += 1
129+
interceptor = _WorkflowInterceptor(
130+
interceptor_id=self._seq,
131+
fn=fn,
132+
name=name,
133+
)
134+
if kind == "before":
135+
self._before = (*self._before, interceptor)
136+
elif kind == "after":
137+
self._after = (*self._after, interceptor)
138+
elif kind == "on_error":
139+
self._on_error = (*self._on_error, interceptor)
140+
else:
141+
msg = f"Unknown interceptor kind '{kind}'"
142+
raise ValueError(msg)
143+
return WorkflowInterceptorHandle(self, interceptor.interceptor_id)
144+
145+
def remove(self, interceptor_id: int) -> bool:
146+
"""Remove an interceptor by ID. Returns True if found and removed."""
147+
with self._lock:
148+
removed = False
149+
before = tuple(i for i in self._before if i.interceptor_id != interceptor_id)
150+
after = tuple(i for i in self._after if i.interceptor_id != interceptor_id)
151+
on_error = tuple(i for i in self._on_error if i.interceptor_id != interceptor_id)
152+
if len(before) != len(self._before):
153+
removed = True
154+
self._before = before
155+
if len(after) != len(self._after):
156+
removed = True
157+
self._after = after
158+
if len(on_error) != len(self._on_error):
159+
removed = True
160+
self._on_error = on_error
161+
return removed
162+
163+
def snapshot(self) -> _WorkflowInterceptorSnapshot:
164+
"""Get a point-in-time snapshot of registered interceptors."""
165+
return _WorkflowInterceptorSnapshot(self._before, self._after, self._on_error)
166+
167+
168+
async def run_before_interceptors(
169+
interceptors: tuple[_WorkflowInterceptor, ...],
170+
step_context: WorkflowStepContext,
171+
state: WorkflowState,
172+
*,
173+
strict: bool = False,
174+
) -> None:
175+
"""Run all before-step interceptors."""
176+
for interceptor in interceptors:
177+
await _safe_invoke_interceptor(interceptor, strict, step_context, state)
178+
179+
180+
async def run_after_interceptors(
181+
interceptors: tuple[_WorkflowInterceptor, ...],
182+
step_context: WorkflowStepContext,
183+
state: WorkflowState,
184+
*,
185+
strict: bool = False,
186+
) -> None:
187+
"""Run all after-step interceptors in reverse order."""
188+
for interceptor in reversed(interceptors):
189+
await _safe_invoke_interceptor(interceptor, strict, step_context, state)
190+
191+
192+
async def run_on_error_interceptors(
193+
interceptors: tuple[_WorkflowInterceptor, ...],
194+
step_context: WorkflowStepContext,
195+
state: WorkflowState,
196+
error: Exception,
197+
*,
198+
strict: bool = False,
199+
) -> None:
200+
"""Run all on-error interceptors in reverse order."""
201+
for interceptor in reversed(interceptors):
202+
await _safe_invoke_interceptor(interceptor, strict, step_context, state, error)
203+
204+
205+
async def _safe_invoke_interceptor(
206+
interceptor: _WorkflowInterceptor,
207+
strict: bool,
208+
*args: Any,
209+
) -> None:
210+
"""Safely invoke an interceptor, handling exceptions based on strict mode."""
211+
try:
212+
result = interceptor.fn(*args)
213+
if inspect.isawaitable(result):
214+
await result
215+
except Exception:
216+
if strict:
217+
raise
218+
logger.exception("Workflow interceptor failed: %s", interceptor.name or interceptor.interceptor_id)

src/memu/workflow/runner.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from __future__ import annotations
22

33
from collections.abc import Callable
4-
from typing import Protocol, runtime_checkable
4+
from typing import TYPE_CHECKING, Protocol, runtime_checkable
55

66
from memu.workflow.step import WorkflowContext, WorkflowState, WorkflowStep, run_steps
77

8+
if TYPE_CHECKING:
9+
from memu.workflow.interceptor import WorkflowInterceptorRegistry
10+
811

912
@runtime_checkable
1013
class WorkflowRunner(Protocol):
@@ -18,6 +21,7 @@ async def run(
1821
steps: list[WorkflowStep],
1922
initial_state: WorkflowState,
2023
context: WorkflowContext = None,
24+
interceptor_registry: WorkflowInterceptorRegistry | None = None,
2125
) -> WorkflowState: ...
2226

2327

@@ -30,8 +34,9 @@ async def run(
3034
steps: list[WorkflowStep],
3135
initial_state: WorkflowState,
3236
context: WorkflowContext = None,
37+
interceptor_registry: WorkflowInterceptorRegistry | None = None,
3338
) -> WorkflowState:
34-
return await run_steps(workflow_name, steps, initial_state, context)
39+
return await run_steps(workflow_name, steps, initial_state, context, interceptor_registry)
3540

3641

3742
RunnerFactory = Callable[[], WorkflowRunner]

0 commit comments

Comments
 (0)