Skip to content

Commit 3da0fb9

Browse files
committed
feat: format
1 parent 5074940 commit 3da0fb9

File tree

15 files changed

+131
-294
lines changed

15 files changed

+131
-294
lines changed

arkitect/core/component/agent/base_agent.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pydantic import BaseModel
1919
from volcenginesdkarkruntime import AsyncArk
2020

21-
from arkitect.core.component.llm_event_stream.model import ContextInterruption, State
21+
from arkitect.core.component.llm_event_stream.model import State
2222
from arkitect.core.component.tool import MCPClient
2323
from arkitect.types.llm.model import ArkChatParameters
2424
from arkitect.types.responses.event import BaseEvent
@@ -33,17 +33,19 @@ class PreAgentCallHook(abc.ABC):
3333
async def pre_agent_call(
3434
self,
3535
state: State,
36-
) -> AsyncIterable[BaseEvent | ContextInterruption]:
37-
pass
36+
) -> AsyncIterable[BaseEvent]:
37+
return
38+
yield
3839

3940

4041
class PostAgentCallHook(abc.ABC):
4142
@abc.abstractmethod
4243
async def post_agent_call(
4344
self,
4445
state: State,
45-
) -> AsyncIterable[BaseEvent | ContextInterruption]:
46-
pass
46+
) -> AsyncIterable[BaseEvent]:
47+
return
48+
yield
4749

4850

4951
class BaseAgent(abc.ABC, BaseModel):
@@ -66,7 +68,8 @@ class BaseAgent(abc.ABC, BaseModel):
6668
# stream run step
6769
@abc.abstractmethod
6870
async def _astream(self, state: State, **kwargs: Any) -> AsyncIterable[BaseEvent]:
69-
pass
71+
return
72+
yield
7073

7174
async def astream(self, state: State, **kwargs: Any) -> AsyncIterable[BaseEvent]:
7275
if self.pre_agent_call_hook:

arkitect/core/component/agent/parallel_agent.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
16-
# Licensed under the 【火山方舟】原型应用软件自用许可协议
17-
# you may not use this file except in compliance with the License.
18-
# You may obtain a copy of the License at
19-
# https://www.volcengine.com/docs/82379/1433703
20-
# Unless required by applicable law or agreed to in writing, software
21-
# distributed under the License is distributed on an "AS IS" BASIS,
22-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23-
# See the License for the specific language governing permissions and
24-
# limitations under the License.
25-
2615
import asyncio
27-
from typing import AsyncIterable
16+
from typing import Any, AsyncIterable
2817

2918
from arkitect.core.component.agent.base_agent import BaseAgent
3019
from arkitect.core.component.llm_event_stream.model import State
@@ -50,7 +39,7 @@ async def _merge_agent_run(
5039
Event: The next event from the merged generator.
5140
"""
5241
tasks = [
53-
asyncio.create_task(events_for_one_agent.__anext__())
42+
asyncio.create_task(events_for_one_agent.__anext__()) # type: ignore
5443
for events_for_one_agent in agent_runs
5544
]
5645
pending_tasks = set(tasks)
@@ -66,7 +55,7 @@ async def _merge_agent_run(
6655
# Find the generator that produced this event and move it on.
6756
for i, original_task in enumerate(tasks):
6857
if task == original_task:
69-
new_task = asyncio.create_task(agent_runs[i].__anext__())
58+
new_task = asyncio.create_task(agent_runs[i].__anext__()) # type: ignore
7059
tasks[i] = new_task
7160
pending_tasks.add(new_task)
7261
break # stop iterating once found
@@ -77,7 +66,7 @@ async def _merge_agent_run(
7766

7867
class ParallelAgent(BaseAgent):
7968
# stream run step
80-
async def _astream(self, state: State, **kwargs) -> AsyncIterable[BaseEvent]:
69+
async def _astream(self, state: State, **kwargs: Any) -> AsyncIterable[BaseEvent]:
8170
agent_runs = [agent(state) for agent in self.sub_agents]
8271
async for event in _merge_agent_run(agent_runs):
8372
yield event

arkitect/core/component/checkpoint/base_checkpoint_service.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,31 @@
1313
# limitations under the License.
1414

1515
from abc import ABC, abstractmethod
16+
from typing import Any
1617

1718
from arkitect.core.component.checkpoint.checkpoint import Checkpoint
1819

1920

2021
class BaseCheckpointService(ABC):
2122
@abstractmethod
22-
def create_checkpoint(
23+
async def create_checkpoint(
2324
self,
2425
app_name: str,
2526
checkpoint_id: str,
2627
user_id: str,
2728
checkpoint: Checkpoint | None = None,
29+
**kwargs: Any,
2830
) -> Checkpoint:
2931
pass
3032

3133
@abstractmethod
32-
async def get_checkpoint(self, app_name: str, checkpoint_id: str) -> Checkpoint:
34+
async def get_checkpoint(
35+
self, app_name: str, checkpoint_id: str
36+
) -> Checkpoint | None:
3337
pass
3438

3539
@abstractmethod
36-
async def list_checkpoints(self, app_name: str) -> list[Checkpoint]:
40+
async def list_checkpoints(self, app_name: str, **kwargs: Any) -> list[Checkpoint]:
3741
pass
3842

3943
@abstractmethod

arkitect/core/component/checkpoint/in_memory_checkpoint_service.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from datetime import datetime
16+
from typing import Any
1617

1718
from arkitect.core.component.checkpoint.base_checkpoint_service import (
1819
BaseCheckpointService,
@@ -23,7 +24,7 @@
2324

2425

2526
class InMemoryCheckpointService(BaseCheckpointService):
26-
def __init__(self):
27+
def __init__(self) -> None:
2728
# A map from app name to a map from user ID to a map from session ID to session.
2829
self.checkpoints: dict[str, dict[str, Checkpoint]] = {}
2930

@@ -33,6 +34,7 @@ async def create_checkpoint(
3334
checkpoint_id: str,
3435
user_id: str,
3536
checkpoint: Checkpoint | None = None,
37+
**kwargs: Any,
3638
) -> Checkpoint:
3739
checkpoint = (
3840
Checkpoint(
@@ -53,10 +55,12 @@ async def create_checkpoint(
5355

5456
return checkpoint
5557

56-
async def get_checkpoint(self, app_name: str, checkpoint_id: str) -> Checkpoint:
58+
async def get_checkpoint(
59+
self, app_name: str, checkpoint_id: str
60+
) -> Checkpoint | None:
5761
return self.checkpoints.get(app_name, {}).get(checkpoint_id, None)
5862

59-
async def list_checkpoints(self, app_name: str) -> list[Checkpoint]:
63+
async def list_checkpoints(self, app_name: str, **kwargs: Any) -> list[Checkpoint]:
6064
return list(self.checkpoints.get(app_name, {}).values())
6165

6266
async def update_checkpoint(

arkitect/core/component/checkpoint/redis_checkpoint_service.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from datetime import datetime
16+
from typing import Any
1617

1718
from arkitect.core.client.redis import RedisClient
1819
from arkitect.core.component.checkpoint.base_checkpoint_service import (
@@ -42,6 +43,7 @@ async def create_checkpoint(
4243
checkpoint_id: str,
4344
user_id: str,
4445
checkpoint: Checkpoint | None = None,
46+
**kwargs: Any,
4547
) -> Checkpoint:
4648
checkpoint = (
4749
Checkpoint(
@@ -60,13 +62,19 @@ async def create_checkpoint(
6062
await self.redis_client.set(key, checkpoint.model_dump_json())
6163
return checkpoint
6264

63-
async def get_checkpoint(self, app_name: str, checkpoint_id: str) -> Checkpoint:
65+
async def get_checkpoint(
66+
self, app_name: str, checkpoint_id: str
67+
) -> Checkpoint | None:
6468
value = await self.redis_client.get(make_key(app_name, checkpoint_id))
6569
if value is None:
6670
return None
6771
return Checkpoint.model_validate_json(value)
6872

69-
async def list_checkpoints(self, app_name: str) -> list[Checkpoint]:
73+
async def list_checkpoints(
74+
self,
75+
app_name: str,
76+
**kwargs: Any,
77+
) -> list[Checkpoint]:
7078
keys, values = await self.redis_client.get_with_prefix(make_key(app_name, "*"))
7179
return [Checkpoint.model_validate_json(value) for value in values]
7280

arkitect/core/component/context/context.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,11 @@ async def handle_tool_call(self) -> bool:
9595
tool_exception = e
9696

9797
self._ctx.state.messages.append(
98-
{
98+
{ # type: ignore
9999
"role": "tool",
100100
"tool_call_id": tool_call_param.get("id", ""),
101101
"content": tool_resp if tool_resp else str(tool_exception),
102-
}
102+
} # type: ignore
103103
)
104104
if self._ctx.post_tool_call_hook:
105105
# post tool call hooks
@@ -327,7 +327,7 @@ async def tool_call_events() -> AsyncIterable[ToolChunk]:
327327
tool_response=resp,
328328
)
329329
self._ctx.state.messages.append(
330-
{
330+
{ # type: ignore
331331
"role": "tool",
332332
"tool_call_id": tool_call.get("id", ""),
333333
"content": resp,

arkitect/core/component/llm_event_stream/context_completion.py

Lines changed: 0 additions & 98 deletions
This file was deleted.

0 commit comments

Comments
 (0)