Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/agentex/lib/adk/_modules/acp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from temporalio.common import RetryPolicy

from agentex import AsyncAgentex
from agentex.lib.adk.utils._modules.client import get_async_agentex_client
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from agentex.lib.core.services.adk.acp.acp import ACPService
from agentex.lib.core.temporal.activities.activity_helpers import ActivityHelpers
from agentex.lib.core.temporal.activities.adk.acp.acp_activities import (
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(self, acp_service: ACPService | None = None):
acp_activities (Optional[ACPActivities]): Optional pre-configured ACP activities. If None, will be auto-initialized.
"""
if acp_service is None:
agentex_client = get_async_agentex_client()
agentex_client = create_async_agentex_client()
tracer = AsyncTracer(agentex_client)
self._acp_service = ACPService(agentex_client=agentex_client, tracer=tracer)
else:
Expand Down
4 changes: 2 additions & 2 deletions src/agentex/lib/adk/_modules/agent_task_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from temporalio.common import RetryPolicy

from agentex import AsyncAgentex
from agentex.lib.adk.utils._modules.client import get_async_agentex_client
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from agentex.lib.core.services.adk.agent_task_tracker import AgentTaskTrackerService
from agentex.lib.core.temporal.activities.activity_helpers import ActivityHelpers
from agentex.lib.core.temporal.activities.adk.agent_task_tracker_activities import (
Expand Down Expand Up @@ -34,7 +34,7 @@ def __init__(
agent_task_tracker_service: AgentTaskTrackerService | None = None,
):
if agent_task_tracker_service is None:
agentex_client = get_async_agentex_client()
agentex_client = create_async_agentex_client()
tracer = AsyncTracer(agentex_client)
self._agent_task_tracker_service = AgentTaskTrackerService(
agentex_client=agentex_client, tracer=tracer
Expand Down
4 changes: 2 additions & 2 deletions src/agentex/lib/adk/_modules/agents.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import timedelta
from typing import Optional

from agentex.lib.adk.utils._modules.client import get_async_agentex_client
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from agentex.lib.core.temporal.activities.adk.agents_activities import AgentsActivityName, GetAgentParams
from temporalio.common import RetryPolicy

Expand Down Expand Up @@ -29,7 +29,7 @@ def __init__(
agents_service: Optional[AgentsService] = None,
):
if agents_service is None:
agentex_client = get_async_agentex_client()
agentex_client = create_async_agentex_client()
tracer = AsyncTracer(agentex_client)
self._agents_service = AgentsService(agentex_client=agentex_client, tracer=tracer)
else:
Expand Down
4 changes: 2 additions & 2 deletions src/agentex/lib/adk/_modules/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from temporalio.common import RetryPolicy

from agentex import AsyncAgentex
from agentex.lib.adk.utils._modules.client import get_async_agentex_client
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from agentex.lib.core.services.adk.events import EventsService
from agentex.lib.core.temporal.activities.activity_helpers import ActivityHelpers
from agentex.lib.core.temporal.activities.adk.events_activities import (
Expand Down Expand Up @@ -33,7 +33,7 @@ def __init__(
events_service: EventsService | None = None,
):
if events_service is None:
agentex_client = get_async_agentex_client()
agentex_client = create_async_agentex_client()
tracer = AsyncTracer(agentex_client)
self._events_service = EventsService(
agentex_client=agentex_client, tracer=tracer
Expand Down
4 changes: 2 additions & 2 deletions src/agentex/lib/adk/_modules/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from temporalio.common import RetryPolicy

from agentex import AsyncAgentex
from agentex.lib.adk.utils._modules.client import get_async_agentex_client
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from agentex.lib.core.adapters.streams.adapter_redis import RedisStreamRepository
from agentex.lib.core.services.adk.messages import MessagesService
from agentex.lib.core.services.adk.streaming import StreamingService
Expand Down Expand Up @@ -38,7 +38,7 @@ def __init__(
messages_service: MessagesService | None = None,
):
if messages_service is None:
agentex_client = get_async_agentex_client()
agentex_client = create_async_agentex_client()
stream_repository = RedisStreamRepository()
streaming_service = StreamingService(
agentex_client=agentex_client,
Expand Down
4 changes: 2 additions & 2 deletions src/agentex/lib/adk/_modules/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from temporalio.common import RetryPolicy

from agentex import AsyncAgentex
from agentex.lib.adk.utils._modules.client import get_async_agentex_client
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from agentex.lib.core.services.adk.state import StateService
from agentex.lib.core.temporal.activities.activity_helpers import ActivityHelpers
from agentex.lib.core.temporal.activities.adk.state_activities import (
Expand Down Expand Up @@ -37,7 +37,7 @@ def __init__(
state_service: StateService | None = None,
):
if state_service is None:
agentex_client = get_async_agentex_client()
agentex_client = create_async_agentex_client()
tracer = AsyncTracer(agentex_client)
self._state_service = StateService(
agentex_client=agentex_client, tracer=tracer
Expand Down
4 changes: 2 additions & 2 deletions src/agentex/lib/adk/_modules/streaming.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from temporalio.common import RetryPolicy

from agentex import AsyncAgentex
from agentex.lib.adk.utils._modules.client import get_async_agentex_client
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from agentex.lib.core.adapters.streams.adapter_redis import RedisStreamRepository
from agentex.lib.core.services.adk.streaming import (
StreamingService,
Expand Down Expand Up @@ -35,7 +35,7 @@ def __init__(self, streaming_service: StreamingService | None = None):
"""
if streaming_service is None:
stream_repository = RedisStreamRepository()
agentex_client = get_async_agentex_client()
agentex_client = create_async_agentex_client()
self._streaming_service = StreamingService(
agentex_client=agentex_client,
stream_repository=stream_repository,
Expand Down
4 changes: 2 additions & 2 deletions src/agentex/lib/adk/_modules/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from temporalio.common import RetryPolicy

from agentex import AsyncAgentex
from agentex.lib.adk.utils._modules.client import get_async_agentex_client
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from agentex.lib.core.services.adk.tasks import TasksService
from agentex.lib.core.temporal.activities.activity_helpers import ActivityHelpers
from agentex.lib.core.temporal.activities.adk.tasks_activities import (
Expand Down Expand Up @@ -32,7 +32,7 @@ def __init__(
tasks_service: TasksService | None = None,
):
if tasks_service is None:
agentex_client = get_async_agentex_client()
agentex_client = create_async_agentex_client()
tracer = AsyncTracer(agentex_client)
self._tasks_service = TasksService(
agentex_client=agentex_client, tracer=tracer
Expand Down
4 changes: 2 additions & 2 deletions src/agentex/lib/adk/_modules/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from temporalio.common import RetryPolicy

from agentex import AsyncAgentex
from agentex.lib.adk.utils._modules.client import get_async_agentex_client
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from agentex.lib.core.services.adk.tracing import TracingService
from agentex.lib.core.temporal.activities.activity_helpers import ActivityHelpers
from agentex.lib.core.temporal.activities.adk.tracing_activities import (
Expand Down Expand Up @@ -39,7 +39,7 @@ def __init__(self, tracing_service: TracingService | None = None):
tracing_activities (Optional[TracingActivities]): Optional pre-configured tracing activities. If None, will be auto-initialized.
"""
if tracing_service is None:
agentex_client = get_async_agentex_client()
agentex_client = create_async_agentex_client()
tracer = AsyncTracer(agentex_client)
self._tracing_service = TracingService(tracer=tracer)
else:
Expand Down
3 changes: 2 additions & 1 deletion src/agentex/lib/adk/providers/_modules/litellm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import AsyncGenerator
from datetime import timedelta

from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from temporalio.common import RetryPolicy

from agentex import AsyncAgentex
Expand Down Expand Up @@ -39,7 +40,7 @@ def __init__(
):
if litellm_service is None:
# Create default service
agentex_client = AsyncAgentex()
agentex_client = create_async_agentex_client()
stream_repository = RedisStreamRepository()
streaming_service = StreamingService(
agentex_client=agentex_client,
Expand Down
3 changes: 2 additions & 1 deletion src/agentex/lib/adk/providers/_modules/openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import timedelta
from typing import Any, Literal

from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from agents import Agent, RunResult, RunResultStreaming
from agents.agent import StopAtTools, ToolsToFinalOutputFunction
from agents.agent_output import AgentOutputSchemaBase
Expand Down Expand Up @@ -46,7 +47,7 @@ def __init__(
):
if openai_service is None:
# Create default service
agentex_client = AsyncAgentex()
agentex_client = create_async_agentex_client()
stream_repository = RedisStreamRepository()
streaming_service = StreamingService(
agentex_client=agentex_client,
Expand Down
3 changes: 2 additions & 1 deletion src/agentex/lib/adk/providers/_modules/sgp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import timedelta

from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from scale_gp import SGPClient, SGPClientError
from temporalio.common import RetryPolicy

Expand Down Expand Up @@ -33,7 +34,7 @@ def __init__(
if sgp_service is None:
try:
sgp_client = SGPClient()
agentex_client = AsyncAgentex()
agentex_client = create_async_agentex_client()
tracer = AsyncTracer(agentex_client)
self._sgp_service = SGPService(sgp_client=sgp_client, tracer=tracer)
except SGPClientError:
Expand Down
55 changes: 20 additions & 35 deletions src/agentex/lib/adk/utils/_modules/client.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,28 @@
import threading
from typing import Dict, Optional, Any
import httpx

from agentex import AsyncAgentex
from agentex.lib.environment_variables import EnvironmentVariables, refreshed_environment_variables
from agentex.lib.environment_variables import EnvironmentVariables
from agentex.lib.utils.logging import make_logger

_client: Optional["AsyncAgentex"] = None
_cached_headers: Dict[str, str] = {}
_init_kwargs: Dict[str, Any] = {}
_lock = threading.RLock()
logger = make_logger(__name__)


def _build_headers() -> Dict[str, str]:
EnvironmentVariables.refresh()
if refreshed_environment_variables and getattr(refreshed_environment_variables, "AGENT_ID", None):
return {"x-agent-identity": refreshed_environment_variables.AGENT_ID}
return {}
class EnvAuth(httpx.Auth):
def __init__(self, header_name="x-agent-identity"):
self.header_name = header_name

def auth_flow(self, request):
# This gets called for every request
env_vars = EnvironmentVariables.refresh()
if env_vars:
agent_id = env_vars.AGENT_ID
if agent_id:
request.headers[self.header_name] = agent_id
logger.info(f"Adding header {self.header_name}:{agent_id}")
yield request

def get_async_agentex_client(**kwargs) -> "AsyncAgentex":
"""
Return a cached AsyncAgentex instance (created synchronously).
Each call re-checks env vars and updates client.default_headers if needed.
"""
global _client, _cached_headers, _init_kwargs

new_headers = _build_headers()

with _lock:
# First time (or kwargs changed) -> build a new client
if _client is None or kwargs != _init_kwargs:
_client = AsyncAgentex(default_headers=new_headers.copy(), **kwargs)
_cached_headers = new_headers
_init_kwargs = dict(kwargs)
return _client

# Same client; maybe headers changed
if new_headers != _cached_headers:
_cached_headers = new_headers
_client.default_headers.clear()
_client.default_headers.update(new_headers)

return _client
def create_async_agentex_client(**kwargs) -> AsyncAgentex:
client = AsyncAgentex(**kwargs)
client._client.auth = EnvAuth()
return client
3 changes: 2 additions & 1 deletion src/agentex/lib/adk/utils/_modules/templating.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import timedelta
from typing import Any

from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from temporalio.common import RetryPolicy

from agentex import AsyncAgentex
Expand Down Expand Up @@ -39,7 +40,7 @@ def __init__(
templating_service (Optional[TemplatingService]): Optional pre-configured templating service. If None, will be auto-initialized.
"""
if templating_service is None:
agentex_client = AsyncAgentex()
agentex_client = create_async_agentex_client()
tracer = AsyncTracer(agentex_client)
self._templating_service = TemplatingService(tracer=tracer)
else:
Expand Down
3 changes: 2 additions & 1 deletion src/agentex/lib/core/temporal/activities/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from scale_gp import SGPClient, SGPClientError

from agentex import AsyncAgentex
Expand Down Expand Up @@ -58,7 +59,7 @@ def get_all_activities(sgp_client=None):

llm_gateway = LiteLLMGateway()
stream_repository = RedisStreamRepository()
agentex_client = AsyncAgentex()
agentex_client = create_async_agentex_client()
tracer = AsyncTracer(agentex_client)

# Services
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, override

from agentex import Agentex, AsyncAgentex
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from agentex.lib.core.tracing.processors.tracing_processor_interface import (
AsyncTracingProcessor,
SyncTracingProcessor,
Expand Down Expand Up @@ -65,7 +66,7 @@ def shutdown(self) -> None:

class AgentexAsyncTracingProcessor(AsyncTracingProcessor):
def __init__(self, config: AgentexTracingProcessorConfig):
self.client = AsyncAgentex()
self.client = create_async_agentex_client()

@override
async def on_span_start(self, span: Span) -> None:
Expand Down
6 changes: 5 additions & 1 deletion src/agentex/lib/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
from enum import Enum
from pathlib import Path

from agentex.lib.utils.logging import make_logger
from dotenv import load_dotenv

from agentex.lib.utils.model_utils import BaseModel

PROJECT_ROOT = Path(__file__).resolve().parents[2]

logger = make_logger(__name__)


class EnvVarKeys(str, Enum):
ENVIRONMENT = "ENVIRONMENT"
Expand Down Expand Up @@ -37,7 +40,7 @@ class Environment(str, Enum):
PROD = "production"


refreshed_environment_variables = None
refreshed_environment_variables: "EnvironmentVariables" | None = None


class EnvironmentVariables(BaseModel):
Expand All @@ -64,6 +67,7 @@ def refresh(cls) -> EnvironmentVariables | None:
if refreshed_environment_variables is not None:
return refreshed_environment_variables

logger.info("Refreshing environment variables")
if os.environ.get(EnvVarKeys.ENVIRONMENT) == Environment.DEV:
# Load global .env file first
global_env_path = PROJECT_ROOT / ".env"
Expand Down
9 changes: 5 additions & 4 deletions src/agentex/lib/sdk/fastacp/base/base_acp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import httpx
import uvicorn
from agentex.lib.adk.utils._modules.client import get_async_agentex_client
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from pydantic import TypeAdapter, ValidationError
Expand Down Expand Up @@ -397,9 +397,10 @@ async def _register_agent(self, env_vars: EnvironmentVariables):

os.environ["AGENT_ID"] = agent_id
os.environ["AGENT_NAME"] = agent_name
refreshed_environment_variables.AGENT_ID = agent_id
refreshed_environment_variables.AGENT_NAME = agent_name
get_async_agentex_client() # refresh cache
env_vars.AGENT_ID = agent_id
env_vars.AGENT_NAME = agent_name
global refreshed_environment_variables
refreshed_environment_variables = env_vars
logger.info(
f"Successfully registered agent '{env_vars.AGENT_NAME}' with Agentex server with acp_url: {full_acp_url}. Registration data: {registration_data}"
)
Expand Down
4 changes: 3 additions & 1 deletion src/agentex/lib/sdk/fastacp/impl/agentic_base_acp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any

from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from typing_extensions import override
from agentex import AsyncAgentex
from agentex.lib.sdk.fastacp.base.base_acp_server import BaseACPServer
Expand All @@ -24,7 +26,7 @@ class AgenticBaseACP(BaseACPServer):
def __init__(self):
super().__init__()
self._setup_handlers()
self._agentex_client = AsyncAgentex()
self._agentex_client = create_async_agentex_client()

@classmethod
@override
Expand Down
Loading
Loading