Skip to content

Commit ffc674a

Browse files
committed
feat: refine
1 parent 0f14a5f commit ffc674a

File tree

15 files changed

+856
-786
lines changed

15 files changed

+856
-786
lines changed

arkitect/core/component/agent/base_agent.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
from typing import Any, AsyncIterable, Callable, Union
1717

1818
from pydantic import BaseModel
19+
from volcenginesdkarkruntime import AsyncArk
1920

20-
from arkitect.core.component.llm_event_stream.model import ContextInterruption, NewState
21+
from arkitect.core.component.llm_event_stream.model import ContextInterruption, State
2122
from arkitect.core.component.tool import MCPClient
23+
from arkitect.types.llm.model import ArkChatParameters
2224
from arkitect.types.responses.event import BaseEvent
2325

2426
"""
@@ -30,7 +32,7 @@ class PreAgentCallHook(abc.ABC):
3032
@abc.abstractmethod
3133
async def pre_agent_call(
3234
self,
33-
state: NewState,
35+
state: State,
3436
) -> AsyncIterable[BaseEvent | ContextInterruption]:
3537
pass
3638

@@ -39,7 +41,7 @@ class PostAgentCallHook(abc.ABC):
3941
@abc.abstractmethod
4042
async def post_agent_call(
4143
self,
42-
state: NewState,
44+
state: State,
4345
) -> AsyncIterable[BaseEvent | ContextInterruption]:
4446
pass
4547

@@ -51,6 +53,8 @@ class BaseAgent(abc.ABC, BaseModel):
5153
tools: list[Union[MCPClient | Callable]] = []
5254
sub_agents: list["BaseAgent"] = []
5355
instruction: str | None = None
56+
parameters: ArkChatParameters | None = None
57+
client: AsyncArk | None = None
5458

5559
pre_agent_call_hook: PreAgentCallHook | None = None
5660
post_agent_call_hook: PostAgentCallHook | None = None
@@ -61,12 +65,10 @@ class BaseAgent(abc.ABC, BaseModel):
6165

6266
# stream run step
6367
@abc.abstractmethod
64-
async def _astream(
65-
self, state: NewState, **kwargs: Any
66-
) -> AsyncIterable[BaseEvent]:
68+
async def _astream(self, state: State, **kwargs: Any) -> AsyncIterable[BaseEvent]:
6769
pass
6870

69-
async def astream(self, state: NewState, **kwargs: Any) -> AsyncIterable[BaseEvent]:
71+
async def astream(self, state: State, **kwargs: Any) -> AsyncIterable[BaseEvent]:
7072

7173
if self.pre_agent_call_hook:
7274
async for event in self.pre_agent_call_hook.pre_agent_call(state):
@@ -81,9 +83,7 @@ async def astream(self, state: NewState, **kwargs: Any) -> AsyncIterable[BaseEve
8183
async for event in self.post_agent_call_hook.post_agent_call(state):
8284
yield event
8385

84-
async def __call__(
85-
self, state: NewState, **kwargs: Any
86-
) -> AsyncIterable[BaseEvent]:
86+
async def __call__(self, state: State, **kwargs: Any) -> AsyncIterable[BaseEvent]:
8787
async for event in self.astream(state, **kwargs):
8888
yield event
8989

arkitect/core/component/agent/default_agent.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,14 @@
1818

1919
from arkitect.core.component.agent import BaseAgent
2020
from arkitect.core.component.llm_event_stream.llm_event_stream import LLMEventStream
21-
from arkitect.core.component.llm_event_stream.model import NewState
21+
22+
from arkitect.core.component.llm_event_stream.hooks import (
23+
PostLLMCallHook,
24+
PostToolCallHook,
25+
PreLLMCallHook,
26+
PreToolCallHook,
27+
)
28+
from arkitect.core.component.llm_event_stream.model import State
2229
from arkitect.types.responses.event import BaseEvent
2330

2431
"""
@@ -31,17 +38,26 @@ class DefaultAgent(BaseAgent):
3138
"arbitrary_types_allowed": True,
3239
}
3340

41+
pre_tool_call_hook: PreToolCallHook | None = None
42+
post_tool_call_hook: PostToolCallHook | None = None
43+
pre_llm_call_hook: PreLLMCallHook | None = None
44+
post_llm_call_hook: PostLLMCallHook | None = None
45+
3446
# stream run step
35-
async def _astream(
36-
self, state: NewState, **kwargs: Any
37-
) -> AsyncIterable[BaseEvent]:
47+
async def _astream(self, state: State, **kwargs: Any) -> AsyncIterable[BaseEvent]:
3848
event_stream = LLMEventStream(
3949
model=self.model,
4050
agent_name=self.name,
4151
tools=self.tools,
4252
sub_agents=self.sub_agents,
4353
state=state,
4454
instruction=self.instruction,
55+
pre_tool_call_hook=self.pre_tool_call_hook,
56+
post_tool_call_hook=self.post_tool_call_hook,
57+
pre_llm_call_hook=self.pre_llm_call_hook,
58+
post_llm_call_hook=self.post_llm_call_hook,
59+
parameters=self.parameters,
60+
client=self.client,
4561
)
4662
await event_stream.init()
4763
resp_stream = await event_stream.completions.create(

arkitect/core/component/agent/parallel_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from typing import AsyncIterable
2828

2929
from arkitect.core.component.agent.base_agent import BaseAgent
30-
from arkitect.core.component.llm_event_stream.model import NewState
30+
from arkitect.core.component.llm_event_stream.model import State
3131
from arkitect.types.responses.event import BaseEvent
3232

3333
"""
@@ -77,7 +77,7 @@ async def _merge_agent_run(
7777

7878
class ParallelAgent(BaseAgent):
7979
# stream run step
80-
async def _astream(self, state: NewState, **kwargs) -> AsyncIterable[BaseEvent]:
80+
async def _astream(self, state: State, **kwargs) -> AsyncIterable[BaseEvent]:
8181
agent_runs = [agent(state) for agent in self.sub_agents]
8282
async for event in _merge_agent_run(agent_runs):
8383
yield event

arkitect/core/component/checkpoint/checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import uuid
1515
from pydantic import BaseModel, ConfigDict, Field
1616

17-
from arkitect.core.component.llm_event_stream.model import NewState
17+
from arkitect.core.component.llm_event_stream.model import State
1818

1919

2020
class Checkpoint(BaseModel):
@@ -39,7 +39,7 @@ class Checkpoint(BaseModel):
3939
"""The name of the app."""
4040
user_id: str
4141
"""The user id of the checkpoint."""
42-
state: NewState = Field(default_factory=NewState)
42+
state: State = Field(default_factory=State)
4343
"""The state of the checkpoint."""
4444
last_update_time: float = 0.0
4545
"""The last update time of the checkpoint."""

arkitect/core/component/checkpoint/in_memory_checkpoint_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
BaseCheckpointService,
1919
)
2020
from arkitect.core.component.checkpoint.checkpoint import Checkpoint
21-
from arkitect.core.component.llm_event_stream.model import NewState
21+
from arkitect.core.component.llm_event_stream.model import State
2222
from arkitect.utils.common import Singleton
2323

2424

@@ -39,7 +39,7 @@ async def create_checkpoint(
3939
id=checkpoint_id,
4040
app_name=app_name,
4141
user_id=user_id,
42-
state=NewState(),
42+
state=State(),
4343
last_update_time=datetime.now().timestamp(),
4444
create_time=datetime.now().timestamp(),
4545
)

arkitect/core/component/checkpoint/redis_checkpoint_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
BaseCheckpointService,
2020
)
2121
from arkitect.core.component.checkpoint.checkpoint import Checkpoint
22-
from arkitect.core.component.llm_event_stream.model import NewState
22+
from arkitect.core.component.llm_event_stream.model import State
2323
from arkitect.utils.common import Singleton
2424

2525

@@ -48,7 +48,7 @@ async def create_checkpoint(
4848
id=checkpoint_id,
4949
app_name=app_name,
5050
user_id=user_id,
51-
state=NewState(),
51+
state=State(),
5252
last_update_time=datetime.now().timestamp(),
5353
create_time=datetime.now().timestamp(),
5454
)

arkitect/core/component/llm_event_stream/chat_completion.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,18 @@
2323
)
2424

2525
from arkitect.core.component.tool.tool_pool import ToolPool
26-
from arkitect.types.llm.model import Message
26+
from arkitect.types.llm.model import Message, ArkChatParameters
2727
from arkitect.types.responses.event import BaseEvent, MessageEvent, StateUpdateEvent
2828

29-
from .model import NewState
29+
from .model import State
3030

3131

3232
class _AsyncCompletions(AsyncCompletions):
33-
def __init__(self, client: AsyncArk, state: NewState):
33+
def __init__(
34+
self, client: AsyncArk, state: State, parameters: ArkChatParameters | None
35+
):
3436
self._state = state
37+
self.parameters = parameters
3538
super().__init__(client)
3639

3740
async def create_event_stream(
@@ -41,11 +44,7 @@ async def create_event_stream(
4144
tool_pool: ToolPool | None = None,
4245
**kwargs: Dict[str, Any],
4346
) -> AsyncIterable[BaseEvent]:
44-
parameters = (
45-
self._state.parameters.__dict__
46-
if self._state.parameters is not None
47-
else {}
48-
)
47+
parameters = self.parameters.__dict__ if self.parameters is not None else {}
4948
if tool_pool:
5049
tools = await tool_pool.list_tools()
5150
parameters["tools"] = [t.model_dump() for t in tools]
@@ -103,10 +102,16 @@ async def iterator() -> AsyncIterable[BaseEvent]:
103102

104103

105104
class _AsyncChat(AsyncChat):
106-
def __init__(self, client: AsyncArk, state: NewState):
105+
def __init__(
106+
self,
107+
client: AsyncArk,
108+
state: State,
109+
parameters: ArkChatParameters | None,
110+
):
107111
self._state = state
112+
self.parameters = parameters
108113
super().__init__(client)
109114

110115
@property
111116
def completions(self) -> _AsyncCompletions:
112-
return _AsyncCompletions(self._client, self._state)
117+
return _AsyncCompletions(self._client, self._state, self.parameters)

arkitect/core/component/llm_event_stream/context_completion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@
2929
from arkitect.core.component.tool.tool_pool import ToolPool
3030
from arkitect.types.responses.event import BaseEvent
3131

32-
from .model import NewState
32+
from .model import State
3333

3434

3535
class _AsyncCompletions(AsyncCompletions):
36-
def __init__(self, client: AsyncArk, state: NewState):
36+
def __init__(self, client: AsyncArk, state: State):
3737
self._state = state
3838
super().__init__(client)
3939

@@ -89,7 +89,7 @@ async def create_event_stream(
8989

9090

9191
class _AsyncContext(AsyncContext):
92-
def __init__(self, client: AsyncArk, state: NewState):
92+
def __init__(self, client: AsyncArk, state: State):
9393
self._state = state
9494
super().__init__(client)
9595

arkitect/core/component/llm_event_stream/hooks.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,23 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import abc
15-
from typing import Any, Optional, Union
16-
17-
from typing import Any, AsyncIterable, Optional
15+
from typing import Any, AsyncIterable, Optional, Union
1816

1917

2018
from arkitect.types.responses.event import (
2119
BaseEvent,
2220
)
2321

24-
from .model import ContextInterruption, NewState
22+
from .model import ContextInterruption, State
2523

26-
from .model import NewState
24+
from .model import State
2725

2826

2927
class HookInterruptException(Exception):
3028
def __init__(
3129
self,
3230
reason: str,
33-
state: Optional[NewState] = None,
31+
state: Optional[State] = None,
3432
details: Optional[Any] = None,
3533
):
3634
self.reason = reason
@@ -44,7 +42,7 @@ async def pre_tool_call(
4442
self,
4543
name: str,
4644
arguments: str,
47-
state: NewState,
45+
state: State,
4846
) -> AsyncIterable[BaseEvent | ContextInterruption]:
4947
pass
5048

@@ -57,7 +55,7 @@ async def post_tool_call(
5755
arguments: str,
5856
response: Any,
5957
exception: Optional[Exception],
60-
state: NewState,
58+
state: State,
6159
) -> AsyncIterable[BaseEvent | ContextInterruption]:
6260
pass
6361

@@ -66,7 +64,7 @@ class PreLLMCallHook(abc.ABC):
6664
@abc.abstractmethod
6765
async def pre_llm_call(
6866
self,
69-
state: NewState,
67+
state: State,
7068
) -> AsyncIterable[BaseEvent | ContextInterruption]:
7169
pass
7270

@@ -75,7 +73,7 @@ class PostLLMCallHook(abc.ABC):
7573
@abc.abstractmethod
7674
async def post_llm_call(
7775
self,
78-
state: NewState,
76+
state: State,
7977
) -> AsyncIterable[BaseEvent | ContextInterruption]:
8078
pass
8179

@@ -93,7 +91,7 @@ async def pre_tool_call(
9391
self,
9492
name: str,
9593
arguments: str,
96-
state: NewState,
94+
state: State,
9795
) -> AsyncIterable[BaseEvent | ContextInterruption]:
9896
if len(state.events) == 0:
9997
return

0 commit comments

Comments
 (0)