| 
53 | 53 |     ToolCallItemTypes,  | 
54 | 54 |     TResponseInputItem,  | 
55 | 55 | )  | 
56 |  | -from .lifecycle import RunHooks  | 
 | 56 | +from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase  | 
57 | 57 | from .logger import logger  | 
58 | 58 | from .memory import Session, SessionInputCallback  | 
59 | 59 | from .model_settings import ModelSettings  | 
@@ -461,13 +461,11 @@ async def run(  | 
461 | 461 |     ) -> RunResult:  | 
462 | 462 |         context = kwargs.get("context")  | 
463 | 463 |         max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)  | 
464 |  | -        hooks = kwargs.get("hooks")  | 
 | 464 | +        hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks")))  | 
465 | 465 |         run_config = kwargs.get("run_config")  | 
466 | 466 |         previous_response_id = kwargs.get("previous_response_id")  | 
467 | 467 |         conversation_id = kwargs.get("conversation_id")  | 
468 | 468 |         session = kwargs.get("session")  | 
469 |  | -        if hooks is None:  | 
470 |  | -            hooks = RunHooks[Any]()  | 
471 | 469 |         if run_config is None:  | 
472 | 470 |             run_config = RunConfig()  | 
473 | 471 | 
 
  | 
@@ -668,14 +666,12 @@ def run_streamed(  | 
668 | 666 |     ) -> RunResultStreaming:  | 
669 | 667 |         context = kwargs.get("context")  | 
670 | 668 |         max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS)  | 
671 |  | -        hooks = kwargs.get("hooks")  | 
 | 669 | +        hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks")))  | 
672 | 670 |         run_config = kwargs.get("run_config")  | 
673 | 671 |         previous_response_id = kwargs.get("previous_response_id")  | 
674 | 672 |         conversation_id = kwargs.get("conversation_id")  | 
675 | 673 |         session = kwargs.get("session")  | 
676 | 674 | 
 
  | 
677 |  | -        if hooks is None:  | 
678 |  | -            hooks = RunHooks[Any]()  | 
679 | 675 |         if run_config is None:  | 
680 | 676 |             run_config = RunConfig()  | 
681 | 677 | 
 
  | 
@@ -732,6 +728,23 @@ def run_streamed(  | 
732 | 728 |         )  | 
733 | 729 |         return streamed_result  | 
734 | 730 | 
 
  | 
 | 731 | +    @staticmethod  | 
 | 732 | +    def _validate_run_hooks(  | 
 | 733 | +        hooks: RunHooksBase[Any, Agent[Any]] | AgentHooksBase[Any, Agent[Any]] | Any | None,  | 
 | 734 | +    ) -> RunHooks[Any]:  | 
 | 735 | +        if hooks is None:  | 
 | 736 | +            return RunHooks[Any]()  | 
 | 737 | +        input_hook_type = type(hooks).__name__  | 
 | 738 | +        if isinstance(hooks, AgentHooksBase):  | 
 | 739 | +            raise TypeError(  | 
 | 740 | +                "Run hooks must be instances of RunHooks. "  | 
 | 741 | +                f"Received agent-scoped hooks ({input_hook_type}). "  | 
 | 742 | +                "Attach AgentHooks to an Agent via Agent(..., hooks=...)."  | 
 | 743 | +            )  | 
 | 744 | +        if not isinstance(hooks, RunHooksBase):  | 
 | 745 | +            raise TypeError(f"Run hooks must be instances of RunHooks. Received {input_hook_type}.")  | 
 | 746 | +        return hooks  | 
 | 747 | + | 
735 | 748 |     @classmethod  | 
736 | 749 |     async def _maybe_filter_model_input(  | 
737 | 750 |         cls,  | 
 | 
0 commit comments