Skip to content

Commit 0d1b92a

Browse files
committed
fix: fix type errors and add logs
1 parent bfffbda commit 0d1b92a

28 files changed

+246
-300
lines changed

veadk/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,23 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .version import VERSION
15+
from typing import TYPE_CHECKING
16+
17+
from veadk.version import VERSION
18+
19+
if TYPE_CHECKING:
20+
from veadk.agent import Agent
21+
from veadk.runner import Runner
1622

1723

1824
# Lazy loading for `Agent` class
1925
def __getattr__(name):
2026
if name == "Agent":
21-
from .agent import Agent
27+
from veadk.agent import Agent
2228

2329
return Agent
2430
if name == "Runner":
25-
from .runner import Runner
31+
from veadk.runner import Runner
2632

2733
return Runner
2834
raise AttributeError(f"module 'veadk' has no attribute '{name}'")

veadk/agent.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Optional
1818

1919
from google.adk.agents import LlmAgent, RunConfig
20+
from google.adk.agents.base_agent import BaseAgent
2021
from google.adk.agents.llm_agent import ToolUnion
2122
from google.adk.agents.run_config import StreamingMode
2223
from google.adk.models.lite_llm import LiteLlm
@@ -39,7 +40,6 @@
3940
from veadk.tracing.base_tracer import BaseTracer
4041
from veadk.utils.logger import get_logger
4142
from veadk.utils.patches import patch_asyncio
42-
from google.adk.agents.base_agent import BaseAgent
4343

4444
patch_asyncio()
4545
logger = get_logger(__name__)
@@ -70,9 +70,7 @@ class Agent(LlmAgent):
7070
model_api_base: str = getenv("MODEL_AGENT_API_BASE", DEFAULT_MODEL_AGENT_API_BASE)
7171
"""The api base of the model for agent running."""
7272

73-
model_api_key: str = Field(
74-
..., default_factory=lambda: getenv("MODEL_AGENT_API_KEY")
75-
)
73+
model_api_key: str = Field(default_factory=lambda: getenv("MODEL_AGENT_API_KEY"))
7674
"""The api key of the model for agent running."""
7775

7876
tools: list[ToolUnion] = []
@@ -244,8 +242,13 @@ async def run(
244242
user_id=user_id,
245243
session_id=session_id,
246244
)
247-
await self.long_term_memory.add_session_to_memory(session)
248-
logger.info(f"Add session `{session.id}` to your long-term memory.")
245+
if session:
246+
await self.long_term_memory.add_session_to_memory(session)
247+
logger.info(f"Add session `{session.id}` to your long-term memory.")
248+
else:
249+
logger.error(
250+
f"Session {session_id} not found in session service, cannot save to long-term memory."
251+
)
249252

250253
if collect_runtime_data:
251254
eval_set_recorder = EvalSetRecorder(session_service, eval_set_id)
@@ -254,6 +257,6 @@ async def run(
254257

255258
if self.tracers:
256259
for tracer in self.tracers:
257-
tracer.dump(user_id, session_id)
260+
tracer.dump(user_id=user_id, session_id=session_id)
258261

259262
return final_output

veadk/cli/services/vefaas/vefaas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ def find_app_id_by_name(self, name: str):
304304
for app in apps:
305305
if app["Name"] == name:
306306
return app["Id"]
307+
logger.warning(f"Application with name {name} not found.")
307308
return None
308309

309310
def delete(self, app_id: str):

veadk/cloud/cloud_agent_engine.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,10 @@ def remove(self, app_name: str):
166166
return
167167
else:
168168
app_id = self._vefaas_service.find_app_id_by_name(app_name)
169+
if not app_id:
170+
raise ValueError(
171+
f"Cloud app {app_name} not found, cannot delete it. Please check the app name."
172+
)
169173
self._vefaas_service.delete(app_id)
170174

171175
def update_function_code(

veadk/cloud/cloud_app.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import json
1516
from typing import Any
1617
from uuid import uuid4
1718

18-
import json
1919
import httpx
2020
from a2a.client import A2ACardResolver, A2AClient
2121
from a2a.types import AgentCard, Message, MessageSendParams, SendMessageRequest
@@ -82,11 +82,18 @@ def _get_vefaas_endpoint(
8282
from veadk.cli.services.vefaas.vefaas import VeFaaS
8383

8484
vefaas_client = VeFaaS(access_key=volcengine_ak, secret_key=volcengine_sk)
85+
8586
app = vefaas_client.get_application_details(
8687
app_id=self.vefaas_application_id,
8788
app_name=self.vefaas_application_name,
8889
)
90+
91+
if not app:
92+
raise ValueError(
93+
f"VeFaaS CloudAPP with application_id `{self.vefaas_application_id}` or application_name `{self.vefaas_application_name}` not found."
94+
)
8995
cloud_resource = json.loads(app["CloudResource"])
96+
9097
try:
9198
vefaas_endpoint = cloud_resource["framework"]["url"]["system_url"]
9299
except Exception as e:
@@ -180,11 +187,19 @@ async def message_send(
180187
id=uuid4().hex,
181188
params=MessageSendParams(**send_message_payload),
182189
)
190+
183191
res = await a2a_client.send_message(
184192
message_send_request,
185193
http_kwargs={"timeout": httpx.Timeout(timeout)},
186194
)
187-
return res.root.result
195+
196+
logger.debug(
197+
f"Message sent to cloud app {self.vefaas_application_name} with response: {res}"
198+
)
199+
200+
# we ignore type checking here, because the response
201+
# from CloudApp will not be `Task` type
202+
return res.root.result # type: ignore
188203
except Exception as e:
189204
# TODO(floritange): show error log on VeFaaS function
190205
print(e)
@@ -194,7 +209,7 @@ async def message_send(
194209
def get_message_id(message: Message):
195210
"""Get the messageId of the a2a message"""
196211
if getattr(message, "messageId", None):
197-
# Compatible with the messageId of the old version
198-
return message.messageId
212+
# Compatible with the messageId of the old a2a-python version (<0.3.0) in cloud app
213+
return message.messageId # type: ignore
199214
else:
200215
return message.message_id

veadk/database/database_adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
import re
1616
import time
1717
from typing import BinaryIO, TextIO
18-
from veadk.database.base_database import BaseDatabase
1918

19+
from veadk.database.base_database import BaseDatabase
2020
from veadk.utils.logger import get_logger
2121

2222
logger = get_logger(__name__)
@@ -41,7 +41,7 @@ def add(self, data: list[str], index: str):
4141
)
4242
raise e
4343

44-
def query(self, query: str, index: str, top_k: int = 0) -> list[str]:
44+
def query(self, query: str, index: str, top_k: int = 0) -> list:
4545
logger.debug(f"Querying Redis database: index={index} query={query}")
4646

4747
# ignore top_k, as KV search only return one result

veadk/database/kv/redis_database.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import Any
1818

1919
import redis
20-
from pydantic import BaseModel, Field, PrivateAttr
20+
from pydantic import BaseModel, Field
2121
from typing_extensions import override
2222

2323
from veadk.config import getenv
@@ -30,19 +30,19 @@
3030

3131
class RedisDatabaseConfig(BaseModel):
3232
host: str = Field(
33-
default=getenv("DATABASE_REDIS_HOST"),
33+
default_factory=lambda: getenv("DATABASE_REDIS_HOST"),
3434
description="Redis host",
3535
)
3636
port: int = Field(
37-
default=getenv("DATABASE_REDIS_PORT"),
37+
default_factory=lambda: int(getenv("DATABASE_REDIS_PORT")),
3838
description="Redis port",
3939
)
4040
db: int = Field(
41-
default=getenv("DATABASE_REDIS_DB"),
41+
default_factory=lambda: int(getenv("DATABASE_REDIS_DB")),
4242
description="Redis db",
4343
)
4444
password: str = Field(
45-
default=getenv("DATABASE_REDIS_PASSWORD"),
45+
default_factory=lambda: getenv("DATABASE_REDIS_PASSWORD"),
4646
description="Redis password",
4747
)
4848
decode_responses: bool = Field(
@@ -53,7 +53,6 @@ class RedisDatabaseConfig(BaseModel):
5353

5454
class RedisDatabase(BaseModel, BaseDatabase):
5555
config: RedisDatabaseConfig = Field(default_factory=RedisDatabaseConfig)
56-
_client: redis.Redis = PrivateAttr(default=None)
5756

5857
def model_post_init(self, context: Any, /) -> None:
5958
try:
@@ -64,6 +63,7 @@ def model_post_init(self, context: Any, /) -> None:
6463
password=self.config.password,
6564
decode_responses=self.config.decode_responses,
6665
)
66+
6767
self._client.ping()
6868
logger.info("Connected to Redis successfully.")
6969
except Exception as e:
@@ -79,10 +79,10 @@ def add(self, key: str, value: str, **kwargs):
7979
raise e
8080

8181
@override
82-
def query(self, key: str, query: str = "", **kwargs) -> list[str]:
82+
def query(self, key: str, query: str = "", **kwargs) -> list:
8383
try:
8484
result = self._client.lrange(key, 0, -1)
85-
return result
85+
return result # type: ignore
8686
except Exception as e:
8787
logger.error(f"Failed to search from Redis list key '{key}': {e}")
8888
raise e
@@ -99,8 +99,11 @@ def delete(self, **kwargs):
9999

100100
try:
101101
# For simple key deletion
102+
# We use sync Redis client to delete the key
103+
# so the result will be `int`
102104
result = self._client.delete(key)
103-
if result > 0:
105+
106+
if result > 0: # type: ignore
104107
logger.info(f"Deleted key `{key}` from Redis.")
105108
else:
106109
logger.info(f"Key `{key}` not found in Redis. Skipping deletion.")

veadk/database/relational/mysql_database.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import Any
1818

1919
import pymysql
20-
from pydantic import BaseModel, Field, PrivateAttr
20+
from pydantic import BaseModel, Field
2121
from typing_extensions import override
2222

2323
from veadk.config import getenv
@@ -30,32 +30,30 @@
3030

3131
class MysqlDatabaseConfig(BaseModel):
3232
host: str = Field(
33-
default=getenv("DATABASE_MYSQL_HOST"),
33+
default_factory=lambda: getenv("DATABASE_MYSQL_HOST"),
3434
description="Mysql host",
3535
)
3636
user: str = Field(
37-
default=getenv("DATABASE_MYSQL_USER"),
37+
default_factory=lambda: getenv("DATABASE_MYSQL_USER"),
3838
description="Mysql user",
3939
)
4040
password: str = Field(
41-
default=getenv("DATABASE_MYSQL_PASSWORD"),
41+
default_factory=lambda: getenv("DATABASE_MYSQL_PASSWORD"),
4242
description="Mysql password",
4343
)
4444
database: str = Field(
45-
default=getenv("DATABASE_MYSQL_DATABASE"),
45+
default_factory=lambda: getenv("DATABASE_MYSQL_DATABASE"),
4646
description="Mysql database",
4747
)
4848
charset: str = Field(
49-
default=getenv("DATABASE_MYSQL_CHARSET", "utf8mb4"),
49+
default_factory=lambda: getenv("DATABASE_MYSQL_CHARSET", "utf8mb4"),
5050
description="Mysql charset",
5151
)
5252

5353

5454
class MysqlDatabase(BaseModel, BaseDatabase):
5555
config: MysqlDatabaseConfig = Field(default_factory=MysqlDatabaseConfig)
5656

57-
_connection: pymysql.Connection = PrivateAttr(default=None)
58-
5957
def model_post_init(self, context: Any, /) -> None:
6058
self._connection = pymysql.connect(
6159
host=self.config.host,
@@ -65,6 +63,9 @@ def model_post_init(self, context: Any, /) -> None:
6563
charset=self.config.charset,
6664
cursorclass=pymysql.cursors.DictCursor,
6765
)
66+
self._connection.ping()
67+
logger.info("Connected to MySQL successfully.")
68+
6869
self._type = "mysql"
6970

7071
def table_exists(self, table: str) -> bool:
@@ -83,7 +84,7 @@ def add(self, sql: str, params=None, **kwargs):
8384
self._connection.commit()
8485

8586
@override
86-
def query(self, sql: str, params=None, **kwargs) -> list[str]:
87+
def query(self, sql: str, params=None, **kwargs) -> tuple[dict[str, Any], ...]:
8788
with self._connection.cursor() as cursor:
8889
cursor.execute(sql, params)
8990
return cursor.fetchall()

veadk/database/vector/opensearch_vector_database.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,29 @@
3232

3333
class OpenSearchVectorDatabaseConfig(BaseModel):
3434
host: str = Field(
35-
default=getenv("DATABASE_OPENSEARCH_HOST"),
35+
default_factory=lambda: getenv("DATABASE_OPENSEARCH_HOST"),
3636
description="OpenSearch host",
3737
)
38+
3839
port: str | int = Field(
39-
default=getenv("DATABASE_OPENSEARCH_PORT"),
40+
default_factory=lambda: getenv("DATABASE_OPENSEARCH_PORT"),
4041
description="OpenSearch port",
4142
)
43+
4244
username: Optional[str] = Field(
43-
default=getenv("DATABASE_OPENSEARCH_USERNAME"),
45+
default_factory=lambda: getenv("DATABASE_OPENSEARCH_USERNAME"),
4446
description="OpenSearch username",
4547
)
48+
4649
password: Optional[str] = Field(
47-
default=getenv("DATABASE_OPENSEARCH_PASSWORD"),
50+
default_factory=lambda: getenv("DATABASE_OPENSEARCH_PASSWORD"),
4851
description="OpenSearch password",
4952
)
53+
5054
secure: bool = Field(default=True, description="Whether enable SSL")
55+
5156
verify_certs: bool = Field(default=False, description="Whether verify SSL certs")
57+
5258
auth_method: Literal["basic", "aws_managed_iam"] = Field(
5359
default="basic", description="OpenSearch auth method"
5460
)
@@ -231,15 +237,16 @@ def get_all_docs(self, collection_name: str, size: int = 10000) -> list[dict]:
231237
for hit in response["hits"]["hits"]
232238
]
233239

234-
def delete_by_query(self, collection_name: str, query: str):
240+
def delete_by_query(self, collection_name: str, query: str) -> Any:
235241
"""Delete docs by query in one index of OpenSearch"""
236242
if not self.collection_exists(collection_name):
237243
raise ValueError(f"Collection {collection_name} does not exist.")
238244

239-
query = {"query": {"match": {"page_content": query}}}
245+
query_payload = {"query": {"match": {"page_content": query}}}
240246
response = self._opensearch_client.delete_by_query(
241-
index=collection_name, body=query
247+
index=collection_name, body=query_payload
242248
)
249+
243250
self._opensearch_client.indices.refresh(index=collection_name)
244251
return response
245252

veadk/database/viking/viking_database.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,17 @@ def _upload_to_tos(
136136
file_ext = kwargs.get(
137137
"file_ext", ".pdf"
138138
) # when bytes data, file_ext is required
139+
139140
ak = self.config.volcengine_ak
140141
sk = self.config.volcengine_sk
142+
141143
tos_bucket = self.config.tos.bucket
142144
tos_endpoint = self.config.tos.endpoint
143145
tos_region = self.config.tos.region
144146
tos_key = self.config.tos.base_key
147+
145148
client = tos.TosClientV2(ak, sk, tos_endpoint, tos_region, max_connections=1024)
149+
146150
if isinstance(data, str) and os.path.isfile(data): # Process file path
147151
file_ext = os.path.splitext(data)[1]
148152
new_key = f"{tos_key}/{str(uuid.uuid4())}{file_ext}"

0 commit comments

Comments
 (0)