Skip to content

Commit 2a1d4ff

Browse files
committed
✨(websearch) add Brave llm/context snippets
Use llm/context endpoint with snippets, change tool name for web_search Signed-off-by: camilleAND <camille.andre@modernisation.gouv.fr>
1 parent 6dd41e8 commit 2a1d4ff

File tree

9 files changed

+277
-179
lines changed

9 files changed

+277
-179
lines changed

src/backend/chat/agents/conversation.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -117,20 +117,14 @@ def enforce_response_language() -> str:
117117
"""Dynamic instruction function to set the expected language to use."""
118118
return f"Answer in {get_language_name(language).lower()}." if language else ""
119119

120-
def get_web_search_tool_name(self) -> str | None:
120+
def is_web_search_configured(self) -> bool:
121121
"""
122-
Get the name of the web search tool if available.
122+
Return True when a web search backend is configured on this model.
123123
124-
If several are available, return the first one found.
125-
126-
Warning, this says the tool is available, not that
127-
it (the tool/feature) is enabled for the current conversation.
124+
This does not mean web search is enabled for the current conversation
125+
(feature flags and runtime deps still apply).
128126
"""
129-
for toolset in self.toolsets:
130-
for tool in toolset.tools.values():
131-
if tool.name.startswith("web_search_"):
132-
return tool.name
133-
return None
127+
return bool(getattr(self.configuration, "web_search", None))
134128

135129

136130
@dataclasses.dataclass(init=False)

src/backend/chat/clients/pydantic_ai.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-positional-argument
251251
session=session,
252252
web_search_enabled=self._is_web_search_enabled and self._is_smart_search_enabled,
253253
)
254+
self._web_search_tool_registered = False
254255

255256
self.conversation_agent = ConversationAgent(
256257
model_hrid=self.model_hrid,
@@ -440,8 +441,7 @@ def _setup_web_search(self, force_web_search: bool) -> bool:
440441
logger.warning("Web search is forced but the feature is disabled, ignoring.")
441442
return False
442443

443-
web_search_tool_name = self.conversation_agent.get_web_search_tool_name()
444-
if not web_search_tool_name:
444+
if not self.conversation_agent.is_web_search_configured():
445445
logger.warning("Web search is forced but no web search tool is available, ignoring.")
446446
return False
447447

@@ -450,9 +450,7 @@ def _setup_web_search(self, force_web_search: bool) -> bool:
450450
@self.conversation_agent.instructions
451451
def force_web_search_prompt() -> str:
452452
"""Dynamic system prompt function to force web search."""
453-
return (
454-
f"You must call the {web_search_tool_name} tool before answering the user request."
455-
)
453+
return "You must call the web_search tool before answering the user request."
456454

457455
return True
458456

@@ -699,6 +697,33 @@ async def summarize(ctx: RunContext, *args, **kwargs) -> ToolReturn:
699697
"""Wrap the document_summarize tool to provide context and add the tool."""
700698
return await document_summarize(ctx, *args, **kwargs)
701699

700+
def _setup_web_search_tool(self) -> None:
701+
"""Register model-specific web search tool when configured."""
702+
if self._web_search_tool_registered:
703+
return
704+
configuration = self.conversation_agent.configuration
705+
if not getattr(configuration, "web_search", None):
706+
return
707+
708+
async def only_if_web_search_enabled(ctx, tool_def):
709+
"""Prepare function to include a tool only if web search is enabled in the context."""
710+
return tool_def if ctx.deps.web_search_enabled else None
711+
712+
web_search_impl = import_string(configuration.web_search)
713+
714+
@self.conversation_agent.tool(
715+
name="web_search",
716+
retries=1,
717+
prepare=only_if_web_search_enabled,
718+
description="Search the web for up-to-date information",
719+
)
720+
@functools.wraps(web_search_impl)
721+
async def web_search(ctx: RunContext, *args, **kwargs) -> ToolReturn:
722+
"""Wrap the web_search tool to provide context and add the tool."""
723+
return await web_search_impl(ctx, *args, **kwargs)
724+
725+
self._web_search_tool_registered = True
726+
702727
async def _handle_input_documents(
703728
self,
704729
input_documents: List[BinaryContent | DocumentUrl],
@@ -1016,6 +1041,7 @@ async def _run_agent( # pylint: disable=too-many-locals
10161041
conversation_has_documents = doc_result.has_documents
10171042

10181043
await self._agent_stop_streaming(force_cache_check=True)
1044+
self._setup_web_search_tool()
10191045
self._setup_web_search(force_web_search)
10201046

10211047
if await self._check_should_enable_rag(conversation_has_documents):

src/backend/chat/llm_configuration.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ class LLModel(BaseModel):
125125
supports_streaming: bool | None = None
126126
system_prompt: SettingEnvValue
127127
tools: list[str]
128+
web_search: SettingEnvValue | None = None
128129

129130
@field_validator("tools", mode="before")
130131
@classmethod
@@ -134,6 +135,14 @@ def validate_tools(cls, value: list[str] | str) -> list[str]:
134135
return _get_setting_or_env_or_value(value)
135136
return value
136137

138+
@field_validator("web_search", mode="before")
139+
@classmethod
140+
def validate_web_search(cls, value: str | None) -> str | None:
141+
"""Convert web_search path if it's a setting or environment variable."""
142+
if isinstance(value, str):
143+
return _get_setting_or_env_or_value(value)
144+
return value
145+
137146
@model_validator(mode="after")
138147
def check_provider_or_provider_name(self) -> Self:
139148
"""

src/backend/chat/tests/agents/test_build_conversation_agent.py

Lines changed: 26 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,12 @@
33
# pylint:disable=protected-access
44

55
import pytest
6-
import responses
76
from freezegun import freeze_time
87
from pydantic_ai import Agent
98
from pydantic_ai.models.openai import OpenAIChatModel
10-
from pydantic_ai.models.test import TestModel
119

1210
from chat.agents.conversation import ConversationAgent
13-
from chat.clients.pydantic_ai import ContextDeps
11+
from chat.llm_configuration import LLModel, LLMProvider
1412

1513

1614
@pytest.fixture(autouse=True)
@@ -87,47 +85,30 @@ def test_add_dynamic_system_prompt():
8785
assert agent._instructions[2]() == "Answer in french."
8886

8987

90-
def test_agent_get_web_search_tool_name(settings):
91-
"""Test the web_search_available method."""
92-
settings.AI_AGENT_TOOLS = ["get_current_weather", "web_search_albert_rag"]
88+
def test_agent_is_web_search_configured():
89+
"""Test whether web search backend is configured on the model."""
9390
agent = ConversationAgent(model_hrid="default-model")
94-
assert agent.get_web_search_tool_name() == "web_search_albert_rag"
95-
96-
settings.AI_AGENT_TOOLS = ["get_current_weather"]
97-
agent = ConversationAgent(model_hrid="default-model")
98-
assert agent.get_web_search_tool_name() is None
99-
100-
settings.AI_AGENT_TOOLS = ["get_current_weather", "web_search_tavily", "web_search_albert_rag"]
101-
agent = ConversationAgent(model_hrid="default-model")
102-
assert agent.get_web_search_tool_name() == "web_search_tavily"
103-
104-
105-
@responses.activate
106-
def test_web_search_tool_avalability(settings):
107-
"""Test the web search tool availability according to context."""
108-
responses.add(
109-
responses.POST,
110-
"https://api.tavily.com/search",
111-
json={"results": []},
112-
status=200,
113-
)
114-
context_deps = ContextDeps(conversation=None, user=None, web_search_enabled=True)
115-
116-
# No tools (context allows web search, but no tool configured)
91+
assert agent.is_web_search_configured() is False
92+
93+
94+
def test_agent_is_web_search_configured_when_defined_in_model_config(settings):
95+
"""Web search is configured when LLModel.web_search is set."""
96+
settings.LLM_CONFIGURATIONS = {
97+
"default-model": LLModel(
98+
hrid="default-model",
99+
model_name="model-123",
100+
human_readable_name="Default Model",
101+
is_active=True,
102+
icon=None,
103+
system_prompt="You are a helpful assistant",
104+
tools=[],
105+
web_search="chat.tools.web_search_brave.web_search_brave_llm_context",
106+
provider=LLMProvider(
107+
hrid="default-provider",
108+
base_url="https://api.llm.com/v1/",
109+
api_key="test-key",
110+
),
111+
),
112+
}
117113
agent = ConversationAgent(model_hrid="default-model")
118-
with agent.override(model=TestModel(), deps=context_deps):
119-
response = agent.run_sync("What tools do you have?")
120-
assert response.output == "success (no tool calls)"
121-
122-
# Tool configured, context allows web search
123-
settings.AI_AGENT_TOOLS = ["web_search_tavily"]
124-
agent = ConversationAgent(model_hrid="default-model") # re-init to pick up new settings
125-
with agent.override(model=TestModel(), deps=context_deps):
126-
response = agent.run_sync("What tools do you have?")
127-
assert response.output == '{"web_search_tavily":[]}'
128-
129-
# Tool configured, context disables web search
130-
context_deps.web_search_enabled = False
131-
with agent.override(model=TestModel(), deps=context_deps):
132-
response = agent.run_sync("What tools do you have?")
133-
assert response.output == "success (no tool calls)"
114+
assert agent.is_web_search_configured() is True

src/backend/chat/tests/clients/pydantic_ai/test_smart_web_search.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ def _llm_config_with_websearch(settings):
2323
is_active=True,
2424
icon=None,
2525
system_prompt="You are an amazing assistant.",
26-
tools=["web_search_brave_with_document_backend"],
26+
tools=[],
27+
web_search="chat.tools.web_search_brave.web_search_brave_llm_context",
2728
provider=LLMProvider(
2829
hrid="unused",
2930
base_url="https://example.com",
@@ -48,6 +49,7 @@ def test_smart_search_disabled_suppresses_tool_at_runtime(_llm_config_with_webse
4849
if not service._is_smart_search_enabled and service._is_web_search_enabled:
4950
service._context_deps.web_search_enabled = False
5051

52+
service._setup_web_search_tool()
5153
with service.conversation_agent.override(model=TestModel(), deps=service._context_deps):
5254
response = service.conversation_agent.run_sync("Search the web for something.")
5355

@@ -65,10 +67,11 @@ def test_smart_search_enabled_tool_is_called(_llm_config_with_websearch):
6567
assert service._is_smart_search_enabled is True
6668
assert service._context_deps.web_search_enabled is True
6769

70+
service._setup_web_search_tool()
6871
with service.conversation_agent.override(model=TestModel(), deps=service._context_deps):
6972
response = service.conversation_agent.run_sync("Search the web for something.")
7073

71-
assert "web_search_brave_with_document_backend" in response.output
74+
assert "web_search" in response.output
7275

7376

7477
def test_force_websearch_overrides_smart_search_disabled(_llm_config_with_websearch):
@@ -82,14 +85,16 @@ def test_force_websearch_overrides_smart_search_disabled(_llm_config_with_websea
8285
assert service._is_smart_search_enabled is False
8386
assert service._context_deps.web_search_enabled is False
8487

88+
# Match _run_agent: register the tool first, then enable deps + forced prompt.
89+
service._setup_web_search_tool()
8590
service._setup_web_search(force_web_search=True)
8691

87-
web_search_tool_name = service.conversation_agent.get_web_search_tool_name()
92+
assert service.conversation_agent.is_web_search_configured() is True
8893
assert service._context_deps.web_search_enabled is True
8994
assert any(
90-
callable(instr) and web_search_tool_name in instr()
95+
callable(instr) and "web_search" in instr()
9196
for instr in service.conversation_agent._instructions
9297
)
9398
with service.conversation_agent.override(model=TestModel(), deps=service._context_deps):
9499
response = service.conversation_agent.run_sync("Search the web for something.")
95-
assert "web_search_brave_with_document_backend" in response.output
100+
assert "web_search" in response.output

0 commit comments

Comments
 (0)