Skip to content

Commit 69452e7

Browse files
committed
fix: sgp migration
1 parent 361c2a6 commit 69452e7

File tree

4 files changed

+231
-11
lines changed

4 files changed

+231
-11
lines changed

examples/tutorials/00_sync/010_multiturn/project/acp.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,16 @@ async def handle_message_send(
3939
if params.content.author != "user":
4040
raise ValueError(f"Expected user message, got {params.content.author}")
4141

42-
if not os.environ.get("OPENAI_API_KEY"):
42+
if not os.environ.get("SGP_API_KEY"):
4343
return TextContent(
4444
author="agent",
45-
content="Hey, sorry I'm unable to respond to your message because you're running this example without an OpenAI API key. Please set the OPENAI_API_KEY environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.",
45+
content="Hey, sorry I'm unable to respond to your message because you're running this example without an SGP API key. Please set the SGP_API_KEY environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.",
46+
)
47+
48+
if not os.environ.get("SGP_ACCOUNT_ID"):
49+
return TextContent(
50+
author="agent",
51+
content="Hey, sorry I'm unable to respond to your message because you're running this example without an SGP Account ID. Please set the SGP_ACCOUNT_ID environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.",
4652
)
4753

4854
#########################################################
@@ -54,7 +60,7 @@ async def handle_message_send(
5460

5561
if not task_state:
5662
# If the state doesn't exist, create it.
57-
state = StateModel(system_prompt="You are a helpful assistant that can answer questions.", model="gpt-4o-mini")
63+
state = StateModel(system_prompt="You are a helpful assistant that can answer questions.", model="openai/gpt-4o-mini")
5864
task_state = await adk.state.create(task_id=params.task.id, agent_id=params.agent.id, state=state)
5965
else:
6066
state = StateModel.model_validate(task_state.state)
@@ -96,7 +102,7 @@ async def handle_message_send(
96102
#########################################################
97103

98104
# Call an LLM to respond to the user's message
99-
chat_completion = await adk.providers.litellm.chat_completion(
105+
chat_completion = await adk.providers.sgp.chat_completion(
100106
llm_config=LLMConfig(model=state.model, messages=llm_messages),
101107
trace_id=params.task.id,
102108
)

examples/tutorials/00_sync/020_streaming/project/acp.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,23 @@ async def handle_message_send(
4141
if params.content.author != "user":
4242
raise ValueError(f"Expected user message, got {params.content.author}")
4343

44-
if not os.environ.get("OPENAI_API_KEY"):
44+
if not os.environ.get("SGP_API_KEY"):
4545
yield StreamTaskMessageFull(
4646
index=0,
4747
type="full",
4848
content=TextContent(
4949
author="agent",
50-
content="Hey, sorry I'm unable to respond to your message because you're running this example without an OpenAI API key. Please set the OPENAI_API_KEY environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.",
50+
content="Hey, sorry I'm unable to respond to your message because you're running this example without an SGP API key. Please set the SGP_API_KEY environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.",
51+
),
52+
)
53+
54+
if not os.environ.get("SGP_ACCOUNT_ID"):
55+
yield StreamTaskMessageFull(
56+
index=0,
57+
type="full",
58+
content=TextContent(
59+
author="agent",
60+
content="Hey, sorry I'm unable to respond to your message because you're running this example without an SGP Account ID. Please set the SGP_ACCOUNT_ID environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.",
5161
),
5262
)
5363

@@ -56,7 +66,7 @@ async def handle_message_send(
5666

5767
if not task_state:
5868
# If the state doesn't exist, create it.
59-
state = StateModel(system_prompt="You are a helpful assistant that can answer questions.", model="gpt-4o-mini")
69+
state = StateModel(system_prompt="You are a helpful assistant that can answer questions.", model="openai/gpt-4o-mini")
6070
task_state = await adk.state.create(task_id=params.task.id, agent_id=params.agent.id, state=state)
6171
else:
6272
state = StateModel.model_validate(task_state.state)
@@ -83,7 +93,7 @@ async def handle_message_send(
8393
# The Agentex server automatically commits input and output messages to the database so you don't need to do this yourself, simply process the input content and return the output content.
8494

8595
message_index = 0
86-
async for chunk in adk.providers.litellm.chat_completion_stream(
96+
async for chunk in adk.providers.sgp.chat_completion_stream(
8797
llm_config=LLMConfig(model=state.model, messages=llm_messages, stream=True),
8898
trace_id=params.task.id,
8999
):

src/agentex/lib/adk/providers/_modules/sgp.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,29 @@
11
from datetime import timedelta
2+
from typing import AsyncGenerator
23

34
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
45
from scale_gp import SGPClient, SGPClientError
56
from temporalio.common import RetryPolicy
67

78
from agentex import AsyncAgentex
9+
from agentex.lib.core.adapters.llm.adapter_sgp import SGPLLMGateway
10+
from agentex.lib.core.adapters.streams.adapter_redis import RedisStreamRepository
11+
from agentex.lib.core.services.adk.providers.litellm import LiteLLMService
812
from agentex.lib.core.services.adk.providers.sgp import SGPService
13+
from agentex.lib.core.services.adk.streaming import StreamingService
914
from agentex.lib.core.temporal.activities.activity_helpers import ActivityHelpers
15+
from agentex.lib.core.temporal.activities.adk.providers.litellm_activities import ChatCompletionParams, \
16+
LiteLLMActivityName, ChatCompletionAutoSendParams, ChatCompletionStreamAutoSendParams
1017
from agentex.lib.core.temporal.activities.adk.providers.sgp_activities import (
1118
DownloadFileParams,
1219
FileContentResponse,
1320
SGPActivityName,
1421
)
1522
from agentex.lib.core.tracing.tracer import AsyncTracer
23+
from agentex.lib.types.llm_messages import LLMConfig, Completion
1624
from agentex.lib.utils.logging import make_logger
1725
from agentex.lib.utils.temporal import in_temporal_workflow
26+
from agentex.types import TaskMessage
1827

1928
logger = make_logger(__name__)
2029

@@ -30,6 +39,7 @@ class SGPModule:
3039
def __init__(
3140
self,
3241
sgp_service: SGPService | None = None,
42+
litellm_service: LiteLLMService | None = None,
3343
):
3444
if sgp_service is None:
3545
try:
@@ -42,6 +52,21 @@ def __init__(
4252
else:
4353
self._sgp_service = sgp_service
4454

55+
agentex_client = create_async_agentex_client()
56+
stream_repository = RedisStreamRepository()
57+
streaming_service = StreamingService(
58+
agentex_client=agentex_client,
59+
stream_repository=stream_repository,
60+
)
61+
litellm_gateway = SGPLLMGateway()
62+
tracer = AsyncTracer(agentex_client)
63+
self._litellm_service = LiteLLMService(
64+
agentex_client=agentex_client,
65+
llm_gateway=litellm_gateway,
66+
streaming_service=streaming_service,
67+
tracer=tracer,
68+
)
69+
4570
async def download_file_content(
4671
self,
4772
params: DownloadFileParams,
@@ -84,3 +109,178 @@ async def download_file_content(
84109
file_id=params.file_id,
85110
filename=params.filename,
86111
)
112+
113+
async def chat_completion(
114+
self,
115+
llm_config: LLMConfig,
116+
trace_id: str | None = None,
117+
parent_span_id: str | None = None,
118+
start_to_close_timeout: timedelta = timedelta(seconds=120),
119+
heartbeat_timeout: timedelta = timedelta(seconds=120),
120+
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
121+
) -> Completion:
122+
"""
123+
Perform a chat completion using LiteLLM.
124+
125+
Args:
126+
llm_config (LLMConfig): The configuration for the LLM.
127+
trace_id (Optional[str]): The trace ID for tracing.
128+
parent_span_id (Optional[str]): The parent span ID for tracing.
129+
start_to_close_timeout (timedelta): The start to close timeout.
130+
heartbeat_timeout (timedelta): The heartbeat timeout.
131+
retry_policy (RetryPolicy): The retry policy.
132+
133+
Returns:
134+
Completion: An OpenAI compatible Completion object
135+
"""
136+
if in_temporal_workflow():
137+
params = ChatCompletionParams(
138+
trace_id=trace_id, parent_span_id=parent_span_id, llm_config=llm_config
139+
)
140+
return await ActivityHelpers.execute_activity(
141+
activity_name=LiteLLMActivityName.CHAT_COMPLETION,
142+
request=params,
143+
response_type=Completion,
144+
start_to_close_timeout=start_to_close_timeout,
145+
heartbeat_timeout=heartbeat_timeout,
146+
retry_policy=retry_policy,
147+
)
148+
else:
149+
return await self._litellm_service.chat_completion(
150+
llm_config=llm_config,
151+
trace_id=trace_id,
152+
parent_span_id=parent_span_id,
153+
)
154+
155+
async def chat_completion_auto_send(
156+
self,
157+
task_id: str,
158+
llm_config: LLMConfig,
159+
trace_id: str | None = None,
160+
parent_span_id: str | None = None,
161+
start_to_close_timeout: timedelta = timedelta(seconds=120),
162+
heartbeat_timeout: timedelta = timedelta(seconds=120),
163+
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
164+
) -> TaskMessage | None:
165+
"""
166+
Chat completion with automatic TaskMessage creation.
167+
168+
Args:
169+
task_id (str): The ID of the task.
170+
llm_config (LLMConfig): The configuration for the LLM (must have stream=False).
171+
trace_id (Optional[str]): The trace ID for tracing.
172+
parent_span_id (Optional[str]): The parent span ID for tracing.
173+
start_to_close_timeout (timedelta): The start to close timeout.
174+
heartbeat_timeout (timedelta): The heartbeat timeout.
175+
retry_policy (RetryPolicy): The retry policy.
176+
177+
Returns:
178+
TaskMessage: The final TaskMessage
179+
"""
180+
if in_temporal_workflow():
181+
# Use streaming activity with stream=False for non-streaming auto-send
182+
params = ChatCompletionAutoSendParams(
183+
trace_id=trace_id,
184+
parent_span_id=parent_span_id,
185+
task_id=task_id,
186+
llm_config=llm_config,
187+
)
188+
return await ActivityHelpers.execute_activity(
189+
activity_name=LiteLLMActivityName.CHAT_COMPLETION_AUTO_SEND,
190+
request=params,
191+
response_type=TaskMessage,
192+
start_to_close_timeout=start_to_close_timeout,
193+
heartbeat_timeout=heartbeat_timeout,
194+
retry_policy=retry_policy,
195+
)
196+
else:
197+
return await self._litellm_service.chat_completion_auto_send(
198+
task_id=task_id,
199+
llm_config=llm_config,
200+
trace_id=trace_id,
201+
parent_span_id=parent_span_id,
202+
)
203+
204+
async def chat_completion_stream(
205+
self,
206+
llm_config: LLMConfig,
207+
trace_id: str | None = None,
208+
parent_span_id: str | None = None,
209+
) -> AsyncGenerator[Completion, None]:
210+
"""
211+
Stream chat completion chunks using LiteLLM.
212+
213+
DEFAULT: Returns raw streaming chunks for manual handling.
214+
215+
NOTE: This method does NOT work in Temporal workflows!
216+
Temporal activities cannot return generators. Use chat_completion_stream_auto_send() instead.
217+
218+
Args:
219+
llm_config (LLMConfig): The configuration for the LLM (must have stream=True).
220+
trace_id (Optional[str]): The trace ID for tracing.
221+
parent_span_id (Optional[str]): The parent span ID for tracing.
222+
start_to_close_timeout (timedelta): The start to close timeout.
223+
heartbeat_timeout (timedelta): The heartbeat timeout.
224+
retry_policy (RetryPolicy): The retry policy.
225+
226+
Returns:
227+
AsyncGenerator[Completion, None]: Generator yielding completion chunks
228+
229+
Raises:
230+
ValueError: If called from within a Temporal workflow
231+
"""
232+
# Delegate to service - it handles temporal workflow checks
233+
async for chunk in self._litellm_service.chat_completion_stream(
234+
llm_config=llm_config,
235+
trace_id=trace_id,
236+
parent_span_id=parent_span_id,
237+
):
238+
yield chunk
239+
240+
async def chat_completion_stream_auto_send(
241+
self,
242+
task_id: str,
243+
llm_config: LLMConfig,
244+
trace_id: str | None = None,
245+
parent_span_id: str | None = None,
246+
start_to_close_timeout: timedelta = timedelta(seconds=120),
247+
heartbeat_timeout: timedelta = timedelta(seconds=120),
248+
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
249+
) -> TaskMessage | None:
250+
"""
251+
Stream chat completion with automatic TaskMessage creation and streaming.
252+
253+
Args:
254+
task_id (str): The ID of the task to run the agent for.
255+
llm_config (LLMConfig): The configuration for the LLM (must have stream=True).
256+
trace_id (Optional[str]): The trace ID for tracing.
257+
parent_span_id (Optional[str]): The parent span ID for tracing.
258+
start_to_close_timeout (timedelta): The start to close timeout.
259+
heartbeat_timeout (timedelta): The heartbeat timeout.
260+
retry_policy (RetryPolicy): The retry policy.
261+
262+
Returns:
263+
TaskMessage: The final TaskMessage after streaming is complete
264+
"""
265+
if in_temporal_workflow():
266+
params = ChatCompletionStreamAutoSendParams(
267+
trace_id=trace_id,
268+
parent_span_id=parent_span_id,
269+
task_id=task_id,
270+
llm_config=llm_config,
271+
)
272+
return await ActivityHelpers.execute_activity(
273+
activity_name=LiteLLMActivityName.CHAT_COMPLETION_STREAM_AUTO_SEND,
274+
request=params,
275+
response_type=TaskMessage,
276+
start_to_close_timeout=start_to_close_timeout,
277+
heartbeat_timeout=heartbeat_timeout,
278+
retry_policy=retry_policy,
279+
)
280+
else:
281+
return await self._litellm_service.chat_completion_stream_auto_send(
282+
task_id=task_id,
283+
llm_config=llm_config,
284+
trace_id=trace_id,
285+
parent_span_id=parent_span_id,
286+
)

src/agentex/lib/core/adapters/llm/adapter_sgp.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@
1111

1212

1313
class SGPLLMGateway(LLMGateway):
14-
def __init__(self, sgp_api_key: str | None = None):
15-
self.sync_client = SGPClient(api_key=os.environ.get("SGP_API_KEY", sgp_api_key))
14+
def __init__(self, sgp_api_key: str | None = None, sgp_account_id: str | None = None):
15+
self.sync_client = SGPClient(
16+
api_key=os.environ.get("SGP_API_KEY", sgp_api_key),
17+
account_id=os.environ.get("SGP_ACCOUNT_ID", sgp_account_id)
18+
)
1619
self.async_client = AsyncSGPClient(
17-
api_key=os.environ.get("SGP_API_KEY", sgp_api_key)
20+
api_key=os.environ.get("SGP_API_KEY", sgp_api_key),
21+
account_id=os.environ.get("SGP_ACCOUNT_ID", sgp_account_id)
1822
)
1923

2024
def completion(self, *args, **kwargs) -> Completion:

0 commit comments

Comments
 (0)