Skip to content

Commit d53c013

Browse files
zakahanwgzesg-bd
authored andcommitted
feat: agent
1 parent a8afb60 commit d53c013

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+2076
-22731
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,5 @@ __pycache__/
4949
*.local.yaml
5050
.cloudide/
5151

52-
dist/*.whl
52+
dist/*.whl
53+
examples/**/*.json

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ test:
4343
# Define a variable for Python and notebook files.
4444
PYTHON_FILES=.
4545
MYPY_CACHE=.mypy_cache
46-
format: PYTHON_FILES=./arkitect
47-
lint: PYTHON_FILES=./arkitect
46+
format: PYTHON_FILES=./arkitect ./examples
47+
lint: PYTHON_FILES=./arkitect ./examples
4848
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=arkitect/ --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
4949
lint_package: PYTHON_FILES=arkitect
5050
lint_tests: PYTHON_FILES=tests

arkitect/core/client/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from .base import Client, ClientPool, get_client_pool
1616
from .http import default_ark_client, load_request
17+
from .redis import RedisClient
1718
from .sse import AsyncSSEDecoder
1819

1920
__all__ = [
@@ -23,4 +24,5 @@
2324
"default_ark_client",
2425
"load_request",
2526
"get_client_pool",
27+
"RedisClient",
2628
]

arkitect/core/client/redis.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import redis.asyncio as redis
16+
from redis.asyncio.retry import Retry
17+
from redis.backoff import ExponentialBackoff
18+
from redis.exceptions import BusyLoadingError, ConnectionError, TimeoutError
19+
20+
from arkitect.core.client.base import Client
21+
22+
23+
class RedisClient(Client):
24+
"""
25+
Initialize a new Redis client object.
26+
27+
Parameters:
28+
host (str): The hostname of the Redis server.
29+
username (str): The username for the Redis server.
30+
password (str): The password for the Redis server.
31+
32+
Returns:
33+
None.
34+
35+
"""
36+
37+
def __init__(self, host: str, username: str, password: str):
38+
self.client = redis.Redis(
39+
host=host,
40+
username=username,
41+
password=password,
42+
retry=Retry(ExponentialBackoff(), 3),
43+
retry_on_error=[BusyLoadingError, ConnectionError, TimeoutError],
44+
)
45+
46+
async def get(self, key: str) -> str:
47+
"""
48+
Get the value of a key from the Redis database.
49+
50+
Args:
51+
key (str): The key to retrieve from the Redis database.
52+
53+
Returns:
54+
str: The value of the key, or None if the key does not exist.
55+
56+
"""
57+
return await self.client.get(key)
58+
59+
async def set(self, key: str, value: str) -> None:
60+
"""
61+
Set the value of a key in the Redis database.
62+
Args:
63+
key (str): The key to set in the Redis database.
64+
value (str): The value to set for the key.
65+
Returns:
66+
None.
67+
"""
68+
await self.client.set(key, value)
69+
70+
async def get_with_prefix(self, prefix: str) -> tuple[list[str], list[str]]:
71+
"""
72+
Asynchronous method to obtain all keys and values from the
73+
Redis database that match the specified prefix
74+
75+
:param prefix: The specified prefix
76+
77+
:return: A list of tuples containing matching keys
78+
and their corresponding values
79+
"""
80+
81+
cursor = 0
82+
keys = []
83+
84+
while True:
85+
# 使用 SCAN 命令进行迭代查询
86+
cursor, key_data = await self.client.scan(cursor, match=prefix, count=1000)
87+
88+
# 将匹配到的 key 添加到列表中
89+
keys.extend(key_data)
90+
91+
# 如果游标值为 0,则表示遍历完成
92+
if cursor == 0 or len(key_data) == 0:
93+
break
94+
95+
# 使用 MGET 命令获取所有匹配到的 key 的对应 value
96+
values = await self.client.mget(keys)
97+
98+
return keys, values
99+
100+
async def mget(self, keys: list[str]) -> list[str]:
101+
"""
102+
Get the values of multiple keys from the Redis database.
103+
104+
Args:
105+
keys (list): A list of keys to retrieve from the Redis database.
106+
107+
Returns:
108+
list: A list of values corresponding to the given keys.
109+
110+
"""
111+
return await self.client.mget(keys)
112+
113+
async def delete(self, key: str) -> None:
114+
"""
115+
Delete a key from the Redis database.
116+
Args:
117+
key (str): The key to delete from the Redis database.
118+
Returns:
119+
None.
120+
"""
121+
await self.client.delete(key)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from arkitect.core.component.agent.base_agent import BaseAgent
16+
from arkitect.core.component.agent.default_agent import DefaultAgent
17+
from arkitect.core.component.agent.parallel_agent import ParallelAgent
18+
19+
__all__ = ["BaseAgent", "ParallelAgent", "DefaultAgent"]
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import abc
16+
from typing import Any, AsyncIterable, Callable, Union
17+
18+
from pydantic import BaseModel
19+
from volcenginesdkarkruntime import AsyncArk
20+
21+
from arkitect.core.component.llm_event_stream.model import State
22+
from arkitect.core.component.tool import MCPClient
23+
from arkitect.types.llm.model import ArkChatParameters
24+
from arkitect.types.responses.event import BaseEvent
25+
26+
"""
27+
Agent is the core interface for all runnable agents
28+
"""
29+
30+
31+
class PreAgentCallHook(abc.ABC):
32+
@abc.abstractmethod
33+
async def pre_agent_call(
34+
self,
35+
state: State,
36+
) -> AsyncIterable[BaseEvent]:
37+
return
38+
yield
39+
40+
41+
class PostAgentCallHook(abc.ABC):
42+
@abc.abstractmethod
43+
async def post_agent_call(
44+
self,
45+
state: State,
46+
) -> AsyncIterable[BaseEvent]:
47+
return
48+
yield
49+
50+
51+
class BaseAgent(abc.ABC, BaseModel):
52+
name: str
53+
description: str = ""
54+
model: str
55+
tools: list[Union[MCPClient | Callable]] = []
56+
sub_agents: list["BaseAgent"] = []
57+
instruction: str | None = None
58+
parameters: ArkChatParameters | None = None
59+
client: AsyncArk | None = None
60+
61+
pre_agent_call_hook: PreAgentCallHook | None = None
62+
post_agent_call_hook: PostAgentCallHook | None = None
63+
64+
model_config = {
65+
"arbitrary_types_allowed": True,
66+
}
67+
68+
# stream run step
69+
@abc.abstractmethod
70+
async def _astream(self, state: State, **kwargs: Any) -> AsyncIterable[BaseEvent]:
71+
return
72+
yield
73+
74+
async def astream(self, state: State, **kwargs: Any) -> AsyncIterable[BaseEvent]:
75+
if self.pre_agent_call_hook:
76+
async for event in self.pre_agent_call_hook.pre_agent_call(state):
77+
yield event
78+
79+
async for event in self._astream(state, **kwargs):
80+
if event.author == "":
81+
event.author = self.name
82+
yield event
83+
84+
if self.post_agent_call_hook:
85+
async for event in self.post_agent_call_hook.post_agent_call(state):
86+
yield event
87+
88+
async def __call__(self, state: State, **kwargs: Any) -> AsyncIterable[BaseEvent]:
89+
async for event in self.astream(state, **kwargs):
90+
yield event
91+
92+
93+
class SwitchAgent(BaseModel):
94+
agent_name: str
95+
message: str
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, AsyncIterable
16+
17+
from pydantic import BaseModel
18+
19+
from arkitect.core.component.agent import BaseAgent
20+
from arkitect.core.component.llm_event_stream.hooks import (
21+
PostLLMCallHook,
22+
PostToolCallHook,
23+
PreLLMCallHook,
24+
PreToolCallHook,
25+
)
26+
from arkitect.core.component.llm_event_stream.llm_event_stream import LLMEventStream
27+
from arkitect.core.component.llm_event_stream.model import State
28+
from arkitect.types.responses.event import BaseEvent
29+
30+
"""
31+
Agent is the core interface for all runnable agents
32+
"""
33+
34+
35+
class DefaultAgent(BaseAgent):
36+
model_config = {
37+
"arbitrary_types_allowed": True,
38+
}
39+
40+
pre_tool_call_hook: PreToolCallHook | None = None
41+
post_tool_call_hook: PostToolCallHook | None = None
42+
pre_llm_call_hook: PreLLMCallHook | None = None
43+
post_llm_call_hook: PostLLMCallHook | None = None
44+
45+
# stream run step
46+
async def _astream(self, state: State, **kwargs: Any) -> AsyncIterable[BaseEvent]:
47+
event_stream = LLMEventStream(
48+
model=self.model,
49+
agent_name=self.name,
50+
tools=self.tools,
51+
sub_agents=self.sub_agents,
52+
state=state,
53+
instruction=self.instruction,
54+
pre_tool_call_hook=self.pre_tool_call_hook,
55+
post_tool_call_hook=self.post_tool_call_hook,
56+
pre_llm_call_hook=self.pre_llm_call_hook,
57+
post_llm_call_hook=self.post_llm_call_hook,
58+
parameters=self.parameters,
59+
client=self.client,
60+
)
61+
await event_stream.init()
62+
resp_stream = await event_stream.completions.create(
63+
messages=[],
64+
**kwargs,
65+
)
66+
67+
async for event in resp_stream:
68+
yield event
69+
70+
71+
class SwitchAgent(BaseModel):
72+
agent_name: str
73+
message: str

0 commit comments

Comments
 (0)