diff --git a/python/packages/foundry/agent_framework_foundry/_agent.py b/python/packages/foundry/agent_framework_foundry/_agent.py index 7001e0bf81e..d290407edb9 100644 --- a/python/packages/foundry/agent_framework_foundry/_agent.py +++ b/python/packages/foundry/agent_framework_foundry/_agent.py @@ -365,6 +365,11 @@ async def _prepare_options( # ``agent_name`` instead, so skip there. See issue #5582. if not self.allow_preview: extra_body.setdefault("agent_reference", _build_agent_reference(self.agent_name, self.agent_version)) + should_strip_model = _uses_foundry_agent_session(conversation_id) or ( + conversation_id is None and options.get("model") is None + ) + if should_strip_model: + run_options.pop("model", None) if extra_body: run_options["extra_body"] = extra_body diff --git a/python/packages/foundry/tests/foundry/test_foundry_agent.py b/python/packages/foundry/tests/foundry/test_foundry_agent.py index 672a7aba690..bba5942eb6e 100644 --- a/python/packages/foundry/tests/foundry/test_foundry_agent.py +++ b/python/packages/foundry/tests/foundry/test_foundry_agent.py @@ -8,6 +8,7 @@ from types import SimpleNamespace from typing import Any from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 import pytest from agent_framework import ( @@ -21,8 +22,10 @@ tool, ) from agent_framework_openai._chat_client import RawOpenAIChatClient +from azure.ai.projects import models as projects_models from azure.core.exceptions import ResourceNotFoundError from azure.identity import AzureCliCredential +from azure.identity.aio import AzureCliCredential as AsyncAzureCliCredential from agent_framework_foundry._agent import ( FoundryAgent, @@ -30,6 +33,7 @@ RawFoundryAgentChatClient, _FoundryAgentChatClient, ) +from agent_framework_foundry._chat_client import FoundryChatClient skip_if_foundry_agent_integration_tests_disabled = pytest.mark.skipif( os.getenv("FOUNDRY_PROJECT_ENDPOINT", "") in ("", "https://test-project.services.ai.azure.com/") @@ -37,10 +41,34 @@ reason="No real FOUNDRY_PROJECT_ENDPOINT or FOUNDRY_AGENT_NAME provided; skipping integration tests.", ) +_FOUNDRY_AZURE_AI_SEARCH_MODEL_ENV_VARS = ( + "FOUNDRY_AZURE_AI_SEARCH_MODEL", + "OPENAI_MODEL", + "AZURE_OPENAI_MODEL", + "AZURE_OPENAI_CHAT_MODEL", + "FOUNDRY_MODEL", +) + + +def _get_foundry_azure_ai_search_model() -> str | None: + """Return the model/deployment to use for local Azure AI Search integration validation.""" + return next((os.environ[key] for key in _FOUNDRY_AZURE_AI_SEARCH_MODEL_ENV_VARS if os.getenv(key)), None) + + +skip_if_foundry_azure_ai_search_integration_tests_disabled = pytest.mark.skipif( + os.getenv("FOUNDRY_PROJECT_ENDPOINT", "") in ("", "https://test-project.services.ai.azure.com/") + or os.getenv("AZURE_SEARCH_INDEX_NAME", "") == "" + or _get_foundry_azure_ai_search_model() is None, + reason="No live Foundry project, Azure Search index, or model provided for Azure AI Search integration tests.", +) + _FOUNDRY_AGENT_ENV_VARS = ( "FOUNDRY_PROJECT_ENDPOINT", "FOUNDRY_AGENT_NAME", "FOUNDRY_AGENT_VERSION", + "FOUNDRY_AZURE_AI_SEARCH_AGENT_NAME", + "FOUNDRY_AZURE_AI_SEARCH_AGENT_VERSION", + "FOUNDRY_AZURE_AI_SEARCH_MODEL", ) @@ -261,7 +289,7 @@ def my_func() -> str: async def test_raw_foundry_agent_chat_client_prepare_options_strips_client_side_fields() -> None: - """Test that _prepare_options strips tool-loop fields but preserves model for non-session requests.""" + """Test that _prepare_options strips client-side fields for Prompt Agent requests.""" mock_project = MagicMock() mock_openai = MagicMock() @@ -293,14 +321,12 @@ def my_func() -> str: options={"tools": [my_func]}, ) - # model is preserved for non-session (PromptAgent) requests - assert result["model"] == "gpt-4.1" + assert "model" not in result assert "tools" not in result assert "tool_choice" not in result assert "parallel_tool_calls" not in result # agent_reference is required so the Responses API can resolve model server-side; see #5582. assert result == { - "model": "gpt-4.1", "extra_body": {"agent_reference": {"name": "test-agent", "type": "agent_reference"}}, } @@ -336,6 +362,31 @@ async def test_raw_foundry_agent_chat_client_prepare_options_strips_model_for_ho assert result["extra_body"]["agent_reference"] == {"name": "test-agent", "type": "agent_reference"} +async def test_raw_foundry_agent_chat_client_prepare_options_preserves_explicit_model_first_turn() -> None: + """First-turn calls should keep an explicit caller-supplied model override.""" + + mock_project = MagicMock() + mock_project.get_openai_client.return_value = MagicMock() + + client = RawFoundryAgentChatClient( + project_client=mock_project, + agent_name="test-agent", + ) + + with patch( + "agent_framework_openai._chat_client.RawOpenAIChatClient._prepare_options", + new_callable=AsyncMock, + return_value={"model": "gpt-4.1"}, + ): + result = await client._prepare_options( + messages=[Message(role="user", contents="hi")], + options={"model": "gpt-4.1"}, + ) + + assert result["model"] == "gpt-4.1" + assert result["extra_body"] == {"agent_reference": {"name": "test-agent", "type": "agent_reference"}} + + async def test_raw_foundry_agent_chat_client_prepare_options_injects_agent_reference_first_turn() -> None: """First-turn (no conversation_id) Prompt Agent calls must carry agent_reference in extra_body. @@ -451,6 +502,7 @@ async def test_raw_foundry_agent_chat_client_prepare_options_respects_caller_age options={"extra_body": {"agent_reference": caller_reference}}, ) + assert "model" not in result assert result["extra_body"]["agent_reference"] == caller_reference @@ -1003,6 +1055,92 @@ async def test_foundry_agent_custom_client_run() -> None: assert "response test" in response.text.lower() +@pytest.mark.flaky +@pytest.mark.integration +@skip_if_foundry_azure_ai_search_integration_tests_disabled +async def test_foundry_agent_azure_ai_search_streaming_citation_get_url() -> None: + """Live regression for Foundry server-side Azure AI Search streaming output.""" + credential = AsyncAzureCliCredential() + project_client: Any | None = None + agent_created = False + agent_name = f"af-5995-{uuid4().hex[:12]}" + query = os.getenv("FOUNDRY_AZURE_AI_SEARCH_QUERY") or "Search the knowledge base for hotels and cite one result." + model = _get_foundry_azure_ai_search_model() + assert model is not None + + try: + from azure.ai.projects.aio import AIProjectClient + + project_client = AIProjectClient( + endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], + credential=credential, + allow_preview=True, + ) + try: + search_connection = await project_client.connections.get_default( + projects_models.ConnectionType.AZURE_AI_SEARCH + ) + except Exception as exc: + pytest.skip(f"No default Azure AI Search connection is configured in the Foundry project: {exc}") + if not search_connection.id: + pytest.skip("Default Azure AI Search connection does not expose an id.") + + tool = FoundryChatClient.get_azure_ai_search_tool( + index_connection_id=search_connection.id, + index_name=os.environ["AZURE_SEARCH_INDEX_NAME"], + query_type="simple", + top_k=3, + ) + definition = projects_models.PromptAgentDefinition( + model=model, + instructions="You must use Azure AI Search for every answer and cite retrieved documents.", + tools=[tool], + tool_choice="required", + ) + await project_client.agents.create_version(agent_name, definition=definition) + agent_created = True + + async with FoundryAgent(project_client=project_client, agent_name=agent_name, allow_preview=False) as agent: + stream = agent.run(query, stream=True) + async for _ in stream: + pass + response = await stream.get_final_response() + + raw_events = [] + for raw_agent_update in response.raw_representation or []: + raw_chat_update = getattr(raw_agent_update, "raw_representation", raw_agent_update) + raw_events.append(getattr(raw_chat_update, "raw_representation", raw_chat_update)) + + live_get_urls = [ + get_url for event in raw_events for get_url in RawOpenAIChatClient._extract_azure_ai_search_get_urls(event) + ] + assert live_get_urls, "Expected the live Azure AI Search stream to include get_urls." + + citations = [ + annotation + for message in response.messages + for content in message.contents + for annotation in (content.annotations or []) + if annotation.get("type") == "citation" + ] + doc_citations = [ + annotation + for annotation in citations + if isinstance(annotation.get("title"), str) and annotation["title"].startswith("doc_") + ] + if doc_citations: + assert any( + isinstance((annotation.get("additional_properties") or {}).get("get_url"), str) + for annotation in doc_citations + ), "Expected doc_N citations to be enriched with additional_properties.get_url." + finally: + if project_client is not None: + if agent_created: + await project_client.agents.delete(agent_name, force=True) + await project_client.close() + await credential.close() + + def test_parse_chunk_surfaces_oauth_consent_request() -> None: """An oauth_consent_request output item surfaces as Content with consent_link.""" diff --git a/python/packages/openai/agent_framework_openai/_chat_client.py b/python/packages/openai/agent_framework_openai/_chat_client.py index 14237e06a02..f4844ee624c 100644 --- a/python/packages/openai/agent_framework_openai/_chat_client.py +++ b/python/packages/openai/agent_framework_openai/_chat_client.py @@ -120,6 +120,9 @@ OPENAI_LOCAL_SHELL_COMMAND_PARTS_KEY = "openai.local_shell_command_parts" OPENAI_SHELL_OUTPUT_TYPE_SHELL_CALL = "shell_call_output" OPENAI_SHELL_OUTPUT_TYPE_LOCAL_SHELL_CALL = "local_shell_call_output" +_AZURE_AI_SEARCH_CALL_OUTPUT_TYPE = "azure_ai_search_call_output" +_AZURE_AI_SEARCH_OUTPUT_EVENT_TYPES = {"response.output_item.added", "response.output_item.done"} +_AZURE_AI_SEARCH_OUTPUT_EVENT_PREFIX = "response.azure_ai_search_call_output." # Internal marker emitted by `_prepare_content_for_openai` for an # `mcp_server_tool_result` Content. The Responses API expects an `mcp_call` @@ -750,6 +753,18 @@ async def _get_response() -> ChatResponse: return _get_response() + @override + def _finalize_response_updates( + self, + updates: Sequence[ChatResponseUpdate], + *, + response_format: Any | None = None, + ) -> ChatResponse[Any]: + """Finalize streamed updates and add post-stream Azure AI Search citation metadata.""" + self._enrich_streamed_azure_ai_search_citations(updates) + self._enrich_mcp_search_citations([content for update in updates for content in update.contents]) + return super()._finalize_response_updates(updates, response_format=response_format) + @classmethod def _extract_served_model(cls, headers: Any) -> str | None: """Return the Azure OpenAI ``x-ms-served-model`` response header value when present. @@ -1948,6 +1963,182 @@ def _serialize_provider_payload(value: Any) -> Any: return [RawOpenAIChatClient._serialize_provider_payload(item) for item in value] # type: ignore[reportUnknownVariableType] return value + @staticmethod + def _parse_azure_ai_search_output_payload(output: Any) -> Mapping[str, Any] | None: + """Parse an Azure AI Search tool output payload from a streamed Responses event.""" + if isinstance(output, str): + try: + output = json.loads(output) + except json.JSONDecodeError: + logger.debug("Unable to parse Azure AI Search call output JSON.", exc_info=True) + return None + + output = RawOpenAIChatClient._serialize_provider_payload(output) + if isinstance(output, Mapping): + return cast("Mapping[str, Any]", output) + return None + + @staticmethod + def _extract_azure_ai_search_output_payload(event: Any) -> Mapping[str, Any] | None: + """Return Azure AI Search output payload from either a top-level event or its nested item.""" + payload = RawOpenAIChatClient._parse_azure_ai_search_output_payload(getattr(event, "output", None)) + if payload is not None: + return payload + + item = getattr(event, "item", None) + if getattr(item, "type", None) == _AZURE_AI_SEARCH_CALL_OUTPUT_TYPE: + return RawOpenAIChatClient._parse_azure_ai_search_output_payload(getattr(item, "output", None)) + return None + + @staticmethod + def _extract_azure_ai_search_get_urls(event: Any) -> list[str]: + """Extract per-document Azure AI Search REST URLs from a streamed Responses event.""" + event_type = getattr(event, "type", None) + if event_type not in _AZURE_AI_SEARCH_OUTPUT_EVENT_TYPES and not ( + isinstance(event_type, str) and event_type.startswith(_AZURE_AI_SEARCH_OUTPUT_EVENT_PREFIX) + ): + return [] + + payload = RawOpenAIChatClient._extract_azure_ai_search_output_payload(event) + if payload is None: + return [] + + get_urls = payload.get("get_urls") + if not isinstance(get_urls, Sequence) or isinstance(get_urls, (str, bytes, bytearray)): + return [] + + urls: list[str] = [] + for url in cast("Sequence[object]", get_urls): + if isinstance(url, str) and url: + urls.append(url) + return urls + + @staticmethod + def _azure_ai_search_doc_index(annotation: Annotation) -> int | None: + """Return the document index encoded in a Foundry Azure AI Search `doc_N` citation title.""" + title = annotation.get("title") + if not isinstance(title, str) or not title.startswith("doc_"): + return None + index_text = title.removeprefix("doc_") + if not index_text.isdigit(): + return None + return int(index_text) + + @classmethod + def _enrich_streamed_azure_ai_search_citations(cls, updates: Sequence[ChatResponseUpdate]) -> None: + """Enrich streamed Azure AI Search citation annotations with per-document REST URLs.""" + # Azure AI Search citations are numbered with global `doc_N` ordinals across the + # whole streamed response, so concatenate `get_urls` in event order before resolving them. + get_urls: list[str] = [] + for update in updates: + get_urls.extend(cls._extract_azure_ai_search_get_urls(update.raw_representation)) + if not get_urls: + return + + for update in updates: + for content in update.contents: + if content.type != "text" or not content.annotations: + continue + for annotation in content.annotations: + if annotation.get("type") != "citation" or annotation.get("file_id"): + continue + + doc_index = cls._azure_ai_search_doc_index(annotation) + if doc_index is None or doc_index >= len(get_urls): + continue + + additional_properties = annotation.get("additional_properties") + if not isinstance(additional_properties, dict): + additional_properties = {} + annotation["additional_properties"] = additional_properties + if "get_url" in additional_properties: + continue + + additional_properties["get_url"] = get_urls[doc_index] + + @staticmethod + def _extract_mcp_search_documents_from_text(text: str) -> dict[str, Mapping[str, Any]]: + """Extract MCP search-index document metadata JSON objects from hosted-MCP output text.""" + documents: dict[str, Mapping[str, Any]] = {} + decoder = json.JSONDecoder() + start = 0 + while True: + object_start = text.find("{", start) + if object_start < 0: + return documents + try: + value, offset = decoder.raw_decode(text[object_start:]) + except json.JSONDecodeError: + start = object_start + 1 + continue + if isinstance(value, Mapping): + document = cast("Mapping[str, Any]", value) + document_id = document.get("id") + if isinstance(document_id, str) and document_id: + documents[document_id] = document + start = object_start + offset + + @classmethod + def _extract_mcp_search_documents_from_content(cls, content: Content) -> dict[str, Mapping[str, Any]]: + """Extract MCP search-index document metadata from an MCP tool-result content item.""" + documents: dict[str, Mapping[str, Any]] = {} + if content.type != "mcp_server_tool_result": + return documents + + def _add_from_output(output: Any) -> None: + if isinstance(output, str): + documents.update(cls._extract_mcp_search_documents_from_text(output)) + return + if isinstance(output, Content): + if output.type == "text" and isinstance(output.text, str): + documents.update(cls._extract_mcp_search_documents_from_text(output.text)) + return + if isinstance(output, Sequence) and not isinstance(output, (str, bytes, bytearray)): + for item in cast("Sequence[object]", output): + _add_from_output(item) + + _add_from_output(content.output) + raw_output = getattr(content.raw_representation, "output", None) + _add_from_output(raw_output) + return documents + + @staticmethod + def _mcp_search_document_id(annotation: Annotation) -> str | None: + """Return the document id encoded in an `mcp://searchindex/` citation.""" + for key in ("url", "title"): + value = annotation.get(key) + if isinstance(value, str) and value.startswith("mcp://searchindex/"): + return value.removeprefix("mcp://searchindex/").split("?", 1)[0].split("#", 1)[0] + return None + + @classmethod + def _enrich_mcp_search_citations(cls, contents: Sequence[Content]) -> None: + """Add MCP search-index document metadata to matching citation annotations.""" + documents: dict[str, Mapping[str, Any]] = {} + for content in contents: + documents.update(cls._extract_mcp_search_documents_from_content(content)) + if not documents: + return + + for content in contents: + if content.type != "text" or not content.annotations: + continue + for annotation in content.annotations: + document_id = cls._mcp_search_document_id(annotation) + if document_id is None: + continue + document = documents.get(document_id) + if document is None: + continue + additional_properties = annotation.setdefault("additional_properties", {}) + additional_properties.setdefault("mcp_document_id", document_id) + document_title = document.get("title") + if isinstance(document_title, str) and document_title: + additional_properties.setdefault("document_title", document_title) + source = document.get("source") + if isinstance(source, str) and source: + additional_properties.setdefault("source", source) + @staticmethod def _get_search_tool_name(item_type: str) -> str: """Map OpenAI search output item types to unified content tool names.""" @@ -2365,7 +2556,10 @@ def _parse_response_from_openai( # Set continuation_token when background operation is still in progress if response.status and response.status in ("in_progress", "queued"): args["continuation_token"] = OpenAIContinuationToken(response_id=response.id) - return ChatResponse(**args) + chat_response = ChatResponse(**args) + if chat_response.messages: + self._enrich_mcp_search_citations(chat_response.messages[0].contents) + return chat_response def _parse_chunk_from_openai( self, @@ -2789,7 +2983,8 @@ def _parse_chunk_from_openai( case "web_search_call" | "file_search_call": contents.append(self._parse_search_tool_call_content(event_item)) case _: - logger.debug("Unparsed event of type: %s: %s", event.type, event) + if getattr(event_item, "type", None) != _AZURE_AI_SEARCH_CALL_OUTPUT_TYPE: + logger.debug("Unparsed event of type: %s: %s", event.type, event) case ( "response.web_search_call.in_progress" | "response.web_search_call.searching" @@ -2958,8 +3153,11 @@ def _get_ann_value(key: str) -> Any: ) elif getattr(done_item, "type", None) in ("web_search_call", "file_search_call"): contents.append(self._parse_search_tool_result_content(done_item)) + elif getattr(done_item, "type", None) == _AZURE_AI_SEARCH_CALL_OUTPUT_TYPE: + pass case _: - logger.debug("Unparsed event of type: %s: %s", event.type, event) + if not isinstance(event.type, str) or not event.type.startswith(_AZURE_AI_SEARCH_OUTPUT_EVENT_PREFIX): + logger.debug("Unparsed event of type: %s: %s", event.type, event) return ChatResponseUpdate( contents=contents, diff --git a/python/packages/openai/tests/openai/test_openai_chat_client.py b/python/packages/openai/tests/openai/test_openai_chat_client.py index 2992e0f41d5..6e57a362f56 100644 --- a/python/packages/openai/tests/openai/test_openai_chat_client.py +++ b/python/packages/openai/tests/openai/test_openai_chat_client.py @@ -4,6 +4,7 @@ import inspect import json import os +from collections.abc import AsyncGenerator from datetime import datetime, timezone from pathlib import Path from typing import Annotated, Any @@ -12,6 +13,7 @@ import pytest from agent_framework import ( Agent, + Annotation, ChatOptions, ChatResponse, ChatResponseUpdate, @@ -3585,6 +3587,393 @@ def test_streaming_annotation_added_with_url_citation() -> None: assert region["end_index"] == 112 +def _make_url_citation_event( + *, + title: str, + get_url: str | None = None, + url: str = "https://example.search.windows.net/", +) -> MagicMock: + event = MagicMock() + event.type = "response.output_text.annotation.added" + event.annotation_index = 0 + event.annotation = { + "type": "url_citation", + "url": url, + "title": title, + "start_index": 100, + "end_index": 112, + } + if get_url is not None: + event.annotation["get_url"] = get_url + return event + + +def _make_mcp_call_done_event(output: str) -> MagicMock: + event = MagicMock() + event.type = "response.output_item.done" + event.item = MagicMock() + event.item.type = "mcp_call" + event.item.id = "mcp_test" + event.item.call_id = None + event.item.output = output + return event + + +def _make_azure_ai_search_output_event( + output: Any, + *, + event_type: str = "response.output_item.done", + top_level_output: bool = False, +) -> MagicMock: + event = MagicMock() + event.type = event_type + if top_level_output: + event.output = output + return event + event.item = MagicMock() + event.item.type = "azure_ai_search_call_output" + event.item.output = output + return event + + +def test_streaming_azure_ai_search_output_enriches_final_url_citation_get_url() -> None: + """Azure AI Search get_urls are resolved onto doc_N citation annotations after streaming completes.""" + client = OpenAIChatClient(model="test-model", api_key="test-key") + chat_options = ChatOptions() + function_call_ids: dict[int, tuple[str, str]] = {} + get_url = "https://example.search.windows.net/indexes/my-index/docs/doc-123?api-version=2024-07-01" + + citation_update = client._parse_chunk_from_openai( + _make_url_citation_event(title="doc_0"), + chat_options, + function_call_ids, + ) + search_update = client._parse_chunk_from_openai( + _make_azure_ai_search_output_event(json.dumps({"documents": [{"id": "doc-123"}], "get_urls": [get_url]})), + chat_options, + function_call_ids, + ) + + assert search_update.contents == [] + + response = client._finalize_response_updates([citation_update, search_update]) + + annotation = response.messages[0].contents[0].annotations[0] + assert annotation["additional_properties"]["get_url"] == get_url + + +async def test_streaming_azure_ai_search_output_enriches_mapped_agent_response() -> None: + """Finalization mutates collected chat updates so mapped agent streams receive the enriched citation too.""" + from agent_framework import AgentResponse + from agent_framework._types import ResponseStream, map_chat_to_agent_update + + client = OpenAIChatClient(model="test-model", api_key="test-key") + chat_options = ChatOptions() + function_call_ids: dict[int, tuple[str, str]] = {} + get_url = "https://example.search.windows.net/indexes/my-index/docs/doc-123?api-version=2024-07-01" + updates = [ + client._parse_chunk_from_openai( + _make_url_citation_event(title="doc_0"), + chat_options, + function_call_ids, + ), + client._parse_chunk_from_openai( + _make_azure_ai_search_output_event(json.dumps({"get_urls": [get_url]})), + chat_options, + function_call_ids, + ), + ] + + async def _stream() -> AsyncGenerator[ChatResponseUpdate, None]: + for update in updates: + yield update + + chat_stream = ResponseStream(_stream(), finalizer=client._finalize_response_updates) + agent_stream = chat_stream.map( + transform=lambda update: map_chat_to_agent_update(update, agent_name="test-agent"), + finalizer=AgentResponse.from_updates, + ) + + async for _ in agent_stream: + pass + response = await agent_stream.get_final_response() + + annotation = response.messages[0].contents[0].annotations[0] + assert annotation["additional_properties"]["get_url"] == get_url + + +def test_streaming_azure_ai_search_output_does_not_overwrite_existing_get_url() -> None: + """If the annotation already contains get_url, the later Search output does not replace it.""" + client = OpenAIChatClient(model="test-model", api_key="test-key") + chat_options = ChatOptions() + function_call_ids: dict[int, tuple[str, str]] = {} + existing_get_url = "https://example.search.windows.net/indexes/my-index/docs/existing?api-version=2024-07-01" + + citation_update = client._parse_chunk_from_openai( + _make_url_citation_event(title="doc_0", get_url=existing_get_url), + chat_options, + function_call_ids, + ) + search_update = client._parse_chunk_from_openai( + _make_azure_ai_search_output_event( + json.dumps({"get_urls": ["https://example.search.windows.net/indexes/my-index/docs/replacement"]}) + ), + chat_options, + function_call_ids, + ) + + response = client._finalize_response_updates([citation_update, search_update]) + + annotation = response.messages[0].contents[0].annotations[0] + assert annotation["additional_properties"]["get_url"] == existing_get_url + + +def test_streaming_azure_ai_search_output_uses_global_doc_index_across_search_events() -> None: + """Azure AI Search `doc_N` URLs are resolved against the concatenated stream order of all search events.""" + client = OpenAIChatClient(model="test-model", api_key="test-key") + chat_options = ChatOptions() + function_call_ids: dict[int, tuple[str, str]] = {} + + citation_update = client._parse_chunk_from_openai( + _make_url_citation_event(title="doc_2"), + chat_options, + function_call_ids, + ) + search_update_one = client._parse_chunk_from_openai( + _make_azure_ai_search_output_event(json.dumps({"get_urls": ["https://example.search.windows.net/docs/one"]})), + chat_options, + function_call_ids, + ) + search_update_two = client._parse_chunk_from_openai( + _make_azure_ai_search_output_event( + json.dumps({ + "get_urls": [ + "https://example.search.windows.net/docs/two", + "https://example.search.windows.net/docs/three", + ] + }) + ), + chat_options, + function_call_ids, + ) + + response = client._finalize_response_updates([citation_update, search_update_one, search_update_two]) + + annotation = response.messages[0].contents[0].annotations[0] + assert annotation["additional_properties"]["get_url"] == "https://example.search.windows.net/docs/three" + + +def test_streaming_azure_ai_search_output_normalizes_non_dict_additional_properties() -> None: + """Existing non-dict additional_properties should be normalized before enriching get_url.""" + client = OpenAIChatClient(model="test-model", api_key="test-key") + chat_options = ChatOptions() + function_call_ids: dict[int, tuple[str, str]] = {} + get_url = "https://example.search.windows.net/indexes/my-index/docs/doc-123?api-version=2024-07-01" + + citation_update = client._parse_chunk_from_openai( + _make_url_citation_event(title="doc_0"), + chat_options, + function_call_ids, + ) + citation_update.contents[0].annotations[0]["additional_properties"] = None + search_update = client._parse_chunk_from_openai( + _make_azure_ai_search_output_event(json.dumps({"get_urls": [get_url]})), + chat_options, + function_call_ids, + ) + + response = client._finalize_response_updates([citation_update, search_update]) + + annotation = response.messages[0].contents[0].annotations[0] + assert annotation["additional_properties"] == {"get_url": get_url} + + +def test_streaming_azure_ai_search_output_does_not_create_additional_properties_for_unusable_citation() -> None: + """Unenrichable Azure AI Search citations should keep their original annotation shape.""" + update = ChatResponseUpdate( + contents=[ + Content.from_text( + text="hello", + annotations=[Annotation(type="citation", title="source_0", url="https://example.invalid")], + ) + ], + raw_representation=_make_azure_ai_search_output_event( + json.dumps({"get_urls": ["https://example.search.windows.net/indexes/my-index/docs/doc-0"]}) + ), + ) + + RawOpenAIChatClient._enrich_streamed_azure_ai_search_citations([update]) + + annotation = update.contents[0].annotations[0] + assert annotation.get("additional_properties") is None + + +def test_extract_azure_ai_search_get_urls_accepts_dedicated_output_event() -> None: + """Dedicated response.azure_ai_search_call_output.* events should yield get_urls too.""" + get_url = "https://example.search.windows.net/indexes/my-index/docs/doc-123?api-version=2024-07-01" + event = _make_azure_ai_search_output_event( + json.dumps({"get_urls": [get_url]}), + event_type="response.azure_ai_search_call_output.done", + top_level_output=True, + ) + + assert RawOpenAIChatClient._extract_azure_ai_search_get_urls(event) == [get_url] + + +def test_parse_chunk_from_openai_ignores_dedicated_azure_ai_search_events() -> None: + """Dedicated Azure AI Search events should be treated as intentional no-op updates.""" + client = OpenAIChatClient(model="test-model", api_key="test-key") + chat_options = ChatOptions() + function_call_ids: dict[int, tuple[str, str]] = {} + event = _make_azure_ai_search_output_event( + json.dumps({"get_urls": ["https://example.search.windows.net/indexes/my-index/docs/doc-0"]}), + event_type="response.azure_ai_search_call_output.done", + top_level_output=True, + ) + + with patch("agent_framework_openai._chat_client.logger.debug") as mock_debug: + update = client._parse_chunk_from_openai(event, chat_options, function_call_ids) + + assert update.contents == [] + mock_debug.assert_not_called() + + +@pytest.mark.parametrize( + ("title", "output"), + [ + ("doc_2", json.dumps({"get_urls": ["https://example.search.windows.net/indexes/my-index/docs/doc-0"]})), + ("source_0", json.dumps({"get_urls": ["https://example.search.windows.net/indexes/my-index/docs/doc-0"]})), + ("doc_0", json.dumps({"documents": [{"id": "doc-0"}]})), + ("doc_0", "{not-json"), + ], +) +def test_streaming_azure_ai_search_output_ignores_unusable_get_url_data(title: str, output: str) -> None: + """Malformed or non-matching Azure AI Search output leaves citations unchanged.""" + client = OpenAIChatClient(model="test-model", api_key="test-key") + chat_options = ChatOptions() + function_call_ids: dict[int, tuple[str, str]] = {} + + citation_update = client._parse_chunk_from_openai( + _make_url_citation_event(title=title), + chat_options, + function_call_ids, + ) + search_update = client._parse_chunk_from_openai( + _make_azure_ai_search_output_event(output), + chat_options, + function_call_ids, + ) + + response = client._finalize_response_updates([citation_update, search_update]) + + annotation = response.messages[0].contents[0].annotations[0] + assert "get_url" not in annotation["additional_properties"] + + +def test_streaming_mcp_searchindex_citation_enriched_from_mcp_output() -> None: + """MCP search-index citations are enriched from retrieved document metadata in MCP output.""" + client = OpenAIChatClient(model="test-model", api_key="test-key") + chat_options = ChatOptions() + function_call_ids: dict[int, tuple[str, str]] = {} + document_id = "inspection_procedures_p1_c0" + mcp_output = f""" +Retrieved 1 document. + +【4:1†source】 +{{ + "id": "{document_id}", + "content": "Inspection Procedures content", + "title": "Inspection Procedures", + "source": "inspection_procedures.pdf" +}} +""" + + mcp_update = client._parse_chunk_from_openai( + _make_mcp_call_done_event(mcp_output), + chat_options, + function_call_ids, + ) + citation_update = client._parse_chunk_from_openai( + _make_url_citation_event( + title=f"mcp://searchindex/{document_id}", + url=f"mcp://searchindex/{document_id}", + ), + chat_options, + function_call_ids, + ) + + response = client._finalize_response_updates([mcp_update, citation_update]) + + annotation = next( + annotation + for message in response.messages + for content in message.contents + for annotation in (content.annotations or []) + ) + assert annotation["additional_properties"]["mcp_document_id"] == document_id + assert annotation["additional_properties"]["document_title"] == "Inspection Procedures" + assert annotation["additional_properties"]["source"] == "inspection_procedures.pdf" + + +def test_parse_response_enriches_mcp_searchindex_citation_from_mcp_output() -> None: + """Non-streaming Responses output also gets MCP search-index document metadata.""" + client = OpenAIChatClient(model="test-model", api_key="test-key") + document_id = "ticket_management_policy_p1_c0" + + mock_mcp_item = MagicMock() + mock_mcp_item.type = "mcp_call" + mock_mcp_item.id = "mcp_123" + mock_mcp_item.call_id = None + mock_mcp_item.name = "knowledge_base_retrieve" + mock_mcp_item.server_label = "knowledge-base" + mock_mcp_item.arguments = '{"queries":["ticket policy"]}' + mock_mcp_item.output = f""" +Retrieved 1 document. + +【14:1†source】 +{{ + "id": "{document_id}", + "content": "Ticket Management Policy content", + "title": "Ticket Management Policy", + "source": "ticket_management_policy.pdf" +}} +""" + + mock_annotation = MagicMock() + mock_annotation.type = "url_citation" + mock_annotation.title = f"mcp://searchindex/{document_id}" + mock_annotation.url = f"mcp://searchindex/{document_id}" + mock_annotation.start_index = 221 + mock_annotation.end_index = 233 + + mock_message_content = MagicMock() + mock_message_content.type = "output_text" + mock_message_content.text = "All tickets must be acknowledged within 1 hour.【14:1†source】" + mock_message_content.annotations = [mock_annotation] + mock_message_content.logprobs = None + + mock_message_item = MagicMock() + mock_message_item.type = "message" + mock_message_item.content = [mock_message_content] + + mock_response = MagicMock() + mock_response.id = "response_123" + mock_response.model = "test-model" + mock_response.created_at = 1000000000 + mock_response.metadata = {} + mock_response.output = [mock_mcp_item, mock_message_item] + mock_response.usage = None + mock_response.status = "completed" + mock_response.conversation = None + + response = client._parse_response_from_openai(mock_response, options={}) + + annotation = response.messages[0].contents[-1].annotations[0] + assert annotation["additional_properties"]["mcp_document_id"] == document_id + assert annotation["additional_properties"]["document_title"] == "Ticket Management Policy" + assert annotation["additional_properties"]["source"] == "ticket_management_policy.pdf" + + def test_streaming_annotation_added_with_url_citation_no_url() -> None: """Test streaming annotation added event with url_citation but missing url is ignored.""" client = OpenAIChatClient(model="test-model", api_key="test-key")