Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from google.adk.tools import load_memory

from veadk import Agent
from veadk.consts import DEFAULT_MODEL_EXTRA_HEADERS
from veadk.knowledgebase import KnowledgeBase
from veadk.memory.long_term_memory import LongTermMemory
from veadk.tools import load_knowledgebase_tool
Expand All @@ -26,11 +27,14 @@ def test_agent():
long_term_memory = LongTermMemory(backend="local")
tracer = OpentelemetryTracer()

model_extra_headers = {"test-header": "test-value"}

agent = Agent(
model_name="test_model_name",
model_provider="test_model_provider",
model_api_key="test_model_api_key",
model_api_base="test_model_api_base",
model_extra_headers=model_extra_headers,
tools=[],
sub_agents=[],
knowledgebase=knowledgebase,
Expand All @@ -39,7 +43,10 @@ def test_agent():
serve_url="",
)

model_extra_headers |= DEFAULT_MODEL_EXTRA_HEADERS

assert agent.model.model == f"{agent.model_provider}/{agent.model_name}"
assert agent.model_extra_headers == model_extra_headers

assert agent.knowledgebase == knowledgebase
assert agent.knowledgebase.backend == "local"
Expand Down
26 changes: 21 additions & 5 deletions veadk/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
DEFALUT_MODEL_AGENT_PROVIDER,
DEFAULT_MODEL_AGENT_API_BASE,
DEFAULT_MODEL_AGENT_NAME,
DEFAULT_MODEL_EXTRA_HEADERS,
)
from veadk.evaluation import EvalSetRecorder
from veadk.knowledgebase import KnowledgeBase
Expand Down Expand Up @@ -73,6 +74,9 @@ class Agent(LlmAgent):
model_api_key: str = Field(default_factory=lambda: getenv("MODEL_AGENT_API_KEY"))
"""The api key of the model for agent running."""

model_extra_headers: dict = Field(default_factory=dict)
"""The extra headers to include in the model requests."""

tools: list[ToolUnion] = []
"""The tools provided to agent."""

Expand All @@ -96,11 +100,23 @@ class Agent(LlmAgent):

def model_post_init(self, __context: Any) -> None:
super().model_post_init(None) # for sub_agents init
self.model = LiteLlm(
model=f"{self.model_provider}/{self.model_name}",
api_key=self.model_api_key,
api_base=self.model_api_base,
)

self.model_extra_headers |= DEFAULT_MODEL_EXTRA_HEADERS

if not self.model:
self.model = LiteLlm(
model=f"{self.model_provider}/{self.model_name}",
api_key=self.model_api_key,
api_base=self.model_api_base,
extra_headers=self.model_extra_headers,
)
logger.debug(
f"LiteLLM client created with extra headers: {self.model_extra_headers}"
)
else:
logger.warning(
"You are trying to use your own LiteLLM client, some default request headers may be missing."
)

if self.knowledgebase:
from veadk.tools import load_knowledgebase_tool
Expand Down
3 changes: 3 additions & 0 deletions veadk/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from veadk.version import VERSION

DEFAULT_MODEL_AGENT_NAME = "doubao-seed-1-6-250615"
DEFALUT_MODEL_AGENT_PROVIDER = "openai"
DEFAULT_MODEL_AGENT_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/"
DEFAULT_MODEL_EXTRA_HEADERS = {"veadk-source": "veadk", "veadk-version": VERSION}
2 changes: 1 addition & 1 deletion veadk/tracing/telemetry/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def set_common_attributes(
)
return

if isinstance(invocation_context.agent, Agent):
if isinstance(invocation_context.agent, Agent) and invocation_context.agent.tracers:
try:
from veadk.tracing.telemetry.opentelemetry_tracer import OpentelemetryTracer

Expand Down