1616from typing import Any , AsyncIterable , Callable , Union
1717
1818from pydantic import BaseModel
19+ from volcenginesdkarkruntime import AsyncArk
1920
20- from arkitect .core .component .llm_event_stream .model import ContextInterruption , NewState
21+ from arkitect .core .component .llm_event_stream .model import ContextInterruption , State
2122from arkitect .core .component .tool import MCPClient
23+ from arkitect .types .llm .model import ArkChatParameters
2224from arkitect .types .responses .event import BaseEvent
2325
2426"""
@@ -30,7 +32,7 @@ class PreAgentCallHook(abc.ABC):
3032 @abc .abstractmethod
3133 async def pre_agent_call (
3234 self ,
33- state : NewState ,
35+ state : State ,
3436 ) -> AsyncIterable [BaseEvent | ContextInterruption ]:
3537 pass
3638
@@ -39,7 +41,7 @@ class PostAgentCallHook(abc.ABC):
3941 @abc .abstractmethod
4042 async def post_agent_call (
4143 self ,
42- state : NewState ,
44+ state : State ,
4345 ) -> AsyncIterable [BaseEvent | ContextInterruption ]:
4446 pass
4547
@@ -51,6 +53,8 @@ class BaseAgent(abc.ABC, BaseModel):
5153 tools : list [Union [MCPClient | Callable ]] = []
5254 sub_agents : list ["BaseAgent" ] = []
5355 instruction : str | None = None
56+ parameters : ArkChatParameters | None = None
57+ client : AsyncArk | None = None
5458
5559 pre_agent_call_hook : PreAgentCallHook | None = None
5660 post_agent_call_hook : PostAgentCallHook | None = None
@@ -61,12 +65,10 @@ class BaseAgent(abc.ABC, BaseModel):
6165
6266 # stream run step
6367 @abc .abstractmethod
64- async def _astream (
65- self , state : NewState , ** kwargs : Any
66- ) -> AsyncIterable [BaseEvent ]:
68+ async def _astream (self , state : State , ** kwargs : Any ) -> AsyncIterable [BaseEvent ]:
6769 pass
6870
69- async def astream (self , state : NewState , ** kwargs : Any ) -> AsyncIterable [BaseEvent ]:
71+ async def astream (self , state : State , ** kwargs : Any ) -> AsyncIterable [BaseEvent ]:
7072
7173 if self .pre_agent_call_hook :
7274 async for event in self .pre_agent_call_hook .pre_agent_call (state ):
@@ -81,9 +83,7 @@ async def astream(self, state: NewState, **kwargs: Any) -> AsyncIterable[BaseEve
8183 async for event in self .post_agent_call_hook .post_agent_call (state ):
8284 yield event
8385
84- async def __call__ (
85- self , state : NewState , ** kwargs : Any
86- ) -> AsyncIterable [BaseEvent ]:
86+ async def __call__ (self , state : State , ** kwargs : Any ) -> AsyncIterable [BaseEvent ]:
8787 async for event in self .astream (state , ** kwargs ):
8888 yield event
8989
0 commit comments