Skip to content
Open
19 changes: 19 additions & 0 deletions src/google/adk/plugins/base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,22 @@ async def on_tool_error_callback(
allows the original error to be raised.
"""
pass

async def on_pipeline_error_callback(
self,
*,
invocation_context: InvocationContext,
error: Exception,
) -> Exception:
"""Callback executed when the runner pipeline encounters an error.

This callback provides an opportunity to handle pipeline errors globally.

Args:
invocation_context: The context for the entire invocation.
error: The exception that was raised during runner execution.

Returns:
An Exception to be raised (either the original error or a new/modified one).
"""
return error
36 changes: 35 additions & 1 deletion src/google/adk/plugins/plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"after_model_callback",
"on_tool_error_callback",
"on_model_error_callback",
"on_pipeline_error_callback",
]

logger = logging.getLogger("google_adk." + __name__)
Expand Down Expand Up @@ -272,6 +273,33 @@ async def run_on_tool_error_callback(
error=error,
)

async def run_on_pipeline_error_callback(
self,
*,
invocation_context: InvocationContext,
error: Exception,
) -> Exception:
"""Runs the `on_pipeline_error_callback` for all plugins sequentially, chaining the error."""
for plugin in self.plugins:
try:
error = await plugin.on_pipeline_error_callback(
invocation_context=invocation_context, error=error
)
except Exception as e:
error_message = (
f"Error in plugin '{plugin.name}' during "
f"'on_pipeline_error_callback' callback: {e}"
)
logger.error(
"Error in plugin '%s' during 'on_pipeline_error_callback'"
" callback: %s",
plugin.name,
e,
exc_info=True,
)
raise RuntimeError(error_message) from e
return error

async def _run_callbacks(
self, callback_name: PluginCallbackName, **kwargs: Any
) -> Optional[Any]:
Expand Down Expand Up @@ -316,7 +344,13 @@ async def _run_callbacks(
f"Error in plugin '{plugin.name}' during '{callback_name}'"
f" callback: {e}"
)
logger.error(error_message, exc_info=True)
logger.error(
"Error in plugin '%s' during '%s' callback: %s",
plugin.name,
callback_name,
e,
exc_info=True,
)
raise RuntimeError(error_message) from e

return None
Expand Down
217 changes: 112 additions & 105 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,118 +1373,125 @@ async def _exec_with_plugin(

plugin_manager = invocation_context.plugin_manager

# Step 1: Run the before_run callbacks to see if we should early exit.
early_exit_result = await plugin_manager.run_before_run_callback(
invocation_context=invocation_context
)
if isinstance(early_exit_result, types.Content):
early_exit_event = Event(
invocation_id=invocation_context.invocation_id,
author='model',
content=early_exit_result,
)
_apply_run_config_custom_metadata(
early_exit_event, invocation_context.run_config
try:
# Step 1: Run the before_run callbacks to see if we should early exit.
early_exit_result = await plugin_manager.run_before_run_callback(
invocation_context=invocation_context
)
if self._should_append_event(early_exit_event, is_live_call):
await self.session_service.append_event(
session=invocation_context.session,
event=early_exit_event,
if isinstance(early_exit_result, types.Content):
early_exit_event = Event(
invocation_id=invocation_context.invocation_id,
author='model',
content=early_exit_result,
)
yield early_exit_event
else:
# Step 2: Otherwise continue with normal execution
# Note for live/bidi:
# the transcription may arrive later than the action(function call
# event and thus function response event). In this case, the order of
# transcription and function call event will be wrong if we just
# append as it arrives. To address this, we should check if there is
# transcription going on. If there is transcription going on, we
# should hold on appending the function call event until the
# transcription is finished. The transcription in progress can be
# identified by checking if the transcription event is partial. When
# the next transcription event is not partial, it means the previous
# transcription is finished. Then if there is any buffered function
# call event, we should append them after this finished(non-partial)
# transcription event.
buffered_events: list[Event] = []
is_transcribing: bool = False

async with aclosing(execute_fn(invocation_context)) as agen:
async for event in agen:
_apply_run_config_custom_metadata(
event, invocation_context.run_config
)
# Step 3: Run the on_event callbacks before persisting so callback
# changes are stored in the session and match the streamed event.
modified_event = await plugin_manager.run_on_event_callback(
invocation_context=invocation_context, event=event
)
output_event = self._get_output_event(
original_event=event,
modified_event=modified_event,
run_config=invocation_context.run_config,
_apply_run_config_custom_metadata(
early_exit_event, invocation_context.run_config
)
if self._should_append_event(early_exit_event, is_live_call):
await self.session_service.append_event(
session=invocation_context.session,
event=early_exit_event,
)
yield early_exit_event
else:
# Step 2: Otherwise continue with normal execution
# Note for live/bidi:
# the transcription may arrive later than the action(function call
# event and thus function response event). In this case, the order of
# transcription and function call event will be wrong if we just
# append as it arrives. To address this, we should check if there is
# transcription going on. If there is transcription going on, we
# should hold on appending the function call event until the
# transcription is finished. The transcription in progress can be
# identified by checking if the transcription event is partial. When
# the next transcription event is not partial, it means the previous
# transcription is finished. Then if there is any buffered function
# call event, we should append them after this finished(non-partial)
# transcription event.
buffered_events: list[Event] = []
is_transcribing: bool = False

async with aclosing(execute_fn(invocation_context)) as agen:
async for event in agen:
_apply_run_config_custom_metadata(
event, invocation_context.run_config
)
# Step 3: Run the on_event callbacks before persisting so callback
# changes are stored in the session and match the streamed event.
modified_event = await plugin_manager.run_on_event_callback(
invocation_context=invocation_context, event=event
)
output_event = self._get_output_event(
original_event=event,
modified_event=modified_event,
run_config=invocation_context.run_config,
)

if is_live_call:
if event.partial and _is_transcription(event):
is_transcribing = True
if is_transcribing and _is_tool_call_or_response(event):
# only buffer function call and function response event which is
# non-partial
buffered_events.append(output_event)
continue
# Note for live/bidi: for audio response, it's considered as
# non-partial event(event.partial=None)
# event.partial=False and event.partial=None are considered as
# non-partial event; event.partial=True is considered as partial
# event.
if event.partial is not True:
if _is_transcription(event) and (
_has_non_empty_transcription_text(event.input_transcription)
or _has_non_empty_transcription_text(
event.output_transcription
if is_live_call:
if event.partial and _is_transcription(event):
is_transcribing = True
if is_transcribing and _is_tool_call_or_response(event):
# only buffer function call and function response event which is
# non-partial
buffered_events.append(output_event)
continue
# Note for live/bidi: for audio response, it's considered as
# non-partial event(event.partial=None)
# event.partial=False and event.partial=None are considered as
# non-partial event; event.partial=True is considered as partial
# event.
if event.partial is not True:
if _is_transcription(event) and (
_has_non_empty_transcription_text(event.input_transcription)
or _has_non_empty_transcription_text(
event.output_transcription
)
):
# transcription end signal, append buffered events
is_transcribing = False
logger.debug(
'Appending transcription finished event: %s', event
)
):
# transcription end signal, append buffered events
is_transcribing = False
logger.debug(
'Appending transcription finished event: %s', event
if self._should_append_event(event, is_live_call):
await self.session_service.append_event(
session=invocation_context.session, event=output_event
)

for buffered_event in buffered_events:
logger.debug('Appending buffered event: %s', buffered_event)
await self.session_service.append_event(
session=invocation_context.session, event=buffered_event
)
yield buffered_event # yield buffered events to caller
buffered_events = []
else:
# non-transcription event or empty transcription event, for
# example, event that stores blob reference, should be appended.
if self._should_append_event(event, is_live_call):
logger.debug('Appending non-buffered event: %s', event)
await self.session_service.append_event(
session=invocation_context.session, event=output_event
)
else:
if event.partial is not True:
await self.session_service.append_event(
session=invocation_context.session, event=output_event
)
if self._should_append_event(event, is_live_call):
await self.session_service.append_event(
session=invocation_context.session, event=output_event
)

for buffered_event in buffered_events:
logger.debug('Appending buffered event: %s', buffered_event)
await self.session_service.append_event(
session=invocation_context.session, event=buffered_event
)
yield buffered_event # yield buffered events to caller
buffered_events = []
else:
# non-transcription event or empty transcription event, for
# example, event that stores blob reference, should be appended.
if self._should_append_event(event, is_live_call):
logger.debug('Appending non-buffered event: %s', event)
await self.session_service.append_event(
session=invocation_context.session, event=output_event
)
else:
if event.partial is not True:
await self.session_service.append_event(
session=invocation_context.session, event=output_event
)

yield output_event

# Step 4: Run the after_run callbacks to perform global cleanup tasks or
# finalizing logs and metrics data.
# This does NOT emit any event.
await plugin_manager.run_after_run_callback(
invocation_context=invocation_context
)
yield output_event
except Exception as e:
if plugin_manager:
e = await plugin_manager.run_on_pipeline_error_callback(
invocation_context=invocation_context, error=e
)
raise e
finally:
# Step 4: Run the after_run callbacks to perform global cleanup tasks or
# finalizing logs and metrics data.
# This does NOT emit any event.
await plugin_manager.run_after_run_callback(
invocation_context=invocation_context
)

async def _append_new_message_to_session(
self,
Expand Down
51 changes: 51 additions & 0 deletions tests/unittests/plugins/test_plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ async def after_model_callback(self, **kwargs):
async def on_model_error_callback(self, **kwargs):
return await self._handle_callback("on_model_error_callback")

async def on_pipeline_error_callback(self, error: Exception, **kwargs):
self.call_log.append("on_pipeline_error_callback")
if "on_pipeline_error_callback" in self.exceptions_to_raise:
raise self.exceptions_to_raise["on_pipeline_error_callback"]
return self.return_values.get("on_pipeline_error_callback", error)


@pytest.fixture
def service() -> PluginManager:
Expand Down Expand Up @@ -252,6 +258,10 @@ async def test_all_callbacks_are_supported(
llm_request=mock_context,
error=mock_context,
)
await service.run_on_pipeline_error_callback(
invocation_context=mock_context,
error=ValueError("err"),
)

# Verify all callbacks were logged
expected_callbacks = [
Expand All @@ -267,6 +277,7 @@ async def test_all_callbacks_are_supported(
"before_model_callback",
"after_model_callback",
"on_model_error_callback",
"on_pipeline_error_callback",
]
assert set(plugin1.call_log) == set(expected_callbacks)

Expand Down Expand Up @@ -363,3 +374,43 @@ async def test_set_skip_closing_plugins_false_reverts_to_closing(
await service.close()

plugin1.close.assert_awaited_once()


@pytest.mark.asyncio
async def test_pipeline_error_callback_chaining(
service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin
):
"""Tests that on_pipeline_error_callback is called and errors are chained."""
error1 = ValueError("Original error")
error2 = RuntimeError("Chained error")
plugin1.return_values["on_pipeline_error_callback"] = error2

service.register_plugin(plugin1)
service.register_plugin(plugin2)

result_err = await service.run_on_pipeline_error_callback(
invocation_context=Mock(), error=error1
)

assert result_err is error2
assert "on_pipeline_error_callback" in plugin1.call_log
assert "on_pipeline_error_callback" in plugin2.call_log


@pytest.mark.asyncio
async def test_pipeline_error_callback_exception_wrap(
service: PluginManager, plugin1: TestPlugin
):
"""Tests that if on_pipeline_error_callback raises, it wraps in RuntimeError."""
plugin1.exceptions_to_raise["on_pipeline_error_callback"] = ValueError(
"Callback crashed"
)
service.register_plugin(plugin1)

with pytest.raises(RuntimeError) as excinfo:
await service.run_on_pipeline_error_callback(
invocation_context=Mock(), error=ValueError("Original")
)

assert "Error in plugin 'plugin1'" in str(excinfo.value)
assert "on_pipeline_error_callback" in str(excinfo.value)