Skip to content

Commit ca37072

Browse files
fix: Use constrained decoding for OpenAIResponses structured_predict (#20808)
1 parent 59b4cc6 commit ca37072

File tree

3 files changed

+188
-52
lines changed

3 files changed

+188
-52
lines changed

llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/responses.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -963,17 +963,26 @@ def structured_predict(
963963
llm_kwargs: Optional[Dict[str, Any]] = None,
964964
**prompt_args: Any,
965965
) -> Model:
966-
"""Structured predict."""
967-
llm_kwargs = llm_kwargs or {}
966+
"""Structured predict using constrained decoding via responses.parse.
968967
969-
llm_kwargs["tool_choice"] = (
970-
"required" if "tool_choice" not in llm_kwargs else llm_kwargs["tool_choice"]
968+
Uses `text_format` with `tool_choice="none"` to guarantee JSON schema
969+
adherence at the API level, rather than best-effort function calling.
970+
"""
971+
messages = prompt.format_messages(**prompt_args)
972+
message_dicts = to_openai_message_dicts(
973+
messages, model=self.model, is_responses_api=True
971974
)
972-
# by default structured prediction uses function calling to extract structured outputs
973-
# here we force tool_choice to be required
974-
return super().structured_predict(
975-
output_cls, prompt, llm_kwargs=llm_kwargs, **prompt_args
975+
response = self._client.responses.parse(
976+
model=self.model,
977+
input=message_dicts,
978+
text_format=output_cls,
979+
tool_choice="none",
980+
store=self.store,
981+
**(llm_kwargs or {}),
976982
)
983+
if response.output_parsed is not None:
984+
return response.output_parsed
985+
raise ValueError("Failed to produce a structured response from the model.")
977986

978987
@dispatcher.span
979988
async def astructured_predict(
@@ -983,17 +992,26 @@ async def astructured_predict(
983992
llm_kwargs: Optional[Dict[str, Any]] = None,
984993
**prompt_args: Any,
985994
) -> Model:
986-
"""Structured predict."""
987-
llm_kwargs = llm_kwargs or {}
995+
"""Async structured predict using constrained decoding via responses.parse.
988996
989-
llm_kwargs["tool_choice"] = (
990-
"required" if "tool_choice" not in llm_kwargs else llm_kwargs["tool_choice"]
997+
Uses `text_format` with `tool_choice="none"` to guarantee JSON schema
998+
adherence at the API level, rather than best-effort function calling.
999+
"""
1000+
messages = prompt.format_messages(**prompt_args)
1001+
message_dicts = to_openai_message_dicts(
1002+
messages, model=self.model, is_responses_api=True
9911003
)
992-
# by default structured prediction uses function calling to extract structured outputs
993-
# here we force tool_choice to be required
994-
return await super().astructured_predict(
995-
output_cls, prompt, llm_kwargs=llm_kwargs, **prompt_args
1004+
response = await self._aclient.responses.parse(
1005+
model=self.model,
1006+
input=message_dicts,
1007+
text_format=output_cls,
1008+
tool_choice="none",
1009+
store=self.store,
1010+
**(llm_kwargs or {}),
9961011
)
1012+
if response.output_parsed is not None:
1013+
return response.output_parsed
1014+
raise ValueError("Failed to produce a structured response from the model.")
9971015

9981016
@dispatcher.span
9991017
def stream_structured_predict(

llama-index-integrations/llms/llama-index-llms-openai/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ dev = [
2727

2828
[project]
2929
name = "llama-index-llms-openai"
30-
version = "0.6.21"
30+
version = "0.6.22"
3131
description = "llama-index llms openai integration"
3232
authors = [{name = "llama-index"}]
3333
requires-python = ">=3.9,<4.0"

llama-index-integrations/llms/llama-index-llms-openai/tests/test_openai_responses.py

Lines changed: 153 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,14 @@
4040
@pytest.fixture
4141
def default_responses_llm():
4242
"""Create a default OpenAIResponses instance with mocked clients."""
43-
with (
44-
patch("llama_index.llms.openai.responses.SyncOpenAI"),
45-
patch("llama_index.llms.openai.responses.AsyncOpenAI"),
46-
):
47-
llm = OpenAIResponses(
48-
model="gpt-4o-mini",
49-
api_key="fake-api-key",
50-
api_base="https://api.openai.com/v1",
51-
api_version="2023-05-15",
52-
)
43+
with patch("llama_index.llms.openai.responses.SyncOpenAI"):
44+
with patch("llama_index.llms.openai.responses.AsyncOpenAI"):
45+
llm = OpenAIResponses(
46+
model="gpt-4o-mini",
47+
api_key="fake-api-key",
48+
api_base="https://api.openai.com/v1",
49+
api_version="2023-05-15",
50+
)
5351
return llm
5452

5553

@@ -68,21 +66,19 @@ def test_init_and_properties(default_responses_llm):
6866

6967
def test_get_model_name():
7068
"""Test different model name formats are properly handled."""
71-
with (
72-
patch("llama_index.llms.openai.responses.SyncOpenAI"),
73-
patch("llama_index.llms.openai.responses.AsyncOpenAI"),
74-
):
75-
# Standard model
76-
llm = OpenAIResponses(model="gpt-4o-mini")
77-
assert llm._get_model_name() == "gpt-4o-mini"
69+
with patch("llama_index.llms.openai.responses.SyncOpenAI"):
70+
with patch("llama_index.llms.openai.responses.AsyncOpenAI"):
71+
# Standard model
72+
llm = OpenAIResponses(model="gpt-4o-mini")
73+
assert llm._get_model_name() == "gpt-4o-mini"
7874

79-
# Legacy fine-tuning format
80-
llm = OpenAIResponses(model="ft-model:gpt-4")
81-
assert llm._get_model_name() == "ft-model"
75+
# Legacy fine-tuning format
76+
llm = OpenAIResponses(model="ft-model:gpt-4")
77+
assert llm._get_model_name() == "ft-model"
8278

83-
# New fine-tuning format
84-
llm = OpenAIResponses(model="ft:gpt-4:org:custom:id")
85-
assert llm._get_model_name() == "gpt-4"
79+
# New fine-tuning format
80+
llm = OpenAIResponses(model="ft:gpt-4:org:custom:id")
81+
assert llm._get_model_name() == "gpt-4"
8682

8783

8884
def test_get_model_kwargs(default_responses_llm):
@@ -144,12 +140,10 @@ def test_parse_response_output():
144140
)
145141
]
146142

147-
with (
148-
patch("llama_index.llms.openai.responses.SyncOpenAI"),
149-
patch("llama_index.llms.openai.responses.AsyncOpenAI"),
150-
):
151-
llm = OpenAIResponses(model="gpt-4o-mini")
152-
chat_response = llm._parse_response_output(output)
143+
with patch("llama_index.llms.openai.responses.SyncOpenAI"):
144+
with patch("llama_index.llms.openai.responses.AsyncOpenAI"):
145+
llm = OpenAIResponses(model="gpt-4o-mini")
146+
chat_response = llm._parse_response_output(output)
153147

154148
assert chat_response.message.role == MessageRole.ASSISTANT
155149
assert len(chat_response.message.blocks) == 1
@@ -330,12 +324,10 @@ def test_get_tool_calls_from_response():
330324
)
331325
]
332326

333-
with (
334-
patch("llama_index.llms.openai.responses.SyncOpenAI"),
335-
patch("llama_index.llms.openai.responses.AsyncOpenAI"),
336-
):
337-
llm = OpenAIResponses(model="gpt-4o-mini")
338-
tool_selections = llm.get_tool_calls_from_response(chat_response)
327+
with patch("llama_index.llms.openai.responses.SyncOpenAI"):
328+
with patch("llama_index.llms.openai.responses.AsyncOpenAI"):
329+
llm = OpenAIResponses(model="gpt-4o-mini")
330+
tool_selections = llm.get_tool_calls_from_response(chat_response)
339331

340332
assert len(tool_selections) == 1
341333
assert tool_selections[0].tool_id == "123"
@@ -552,6 +544,132 @@ async def test_astream_complete_with_api():
552544
assert responses[-1].text is not None
553545

554546

547+
def test_structured_predict_uses_responses_parse(default_responses_llm):
548+
"""Test that structured_predict uses responses.parse with text_format for constrained decoding."""
549+
550+
class Person(BaseModel):
551+
name: str = Field(description="The person's name")
552+
age: int = Field(description="The person's age")
553+
554+
llm = default_responses_llm
555+
mock_response = MagicMock()
556+
mock_response.output_parsed = Person(name="Alice", age=25)
557+
llm._client.responses.parse = MagicMock(return_value=mock_response)
558+
559+
result = llm.structured_predict(
560+
output_cls=Person,
561+
prompt=PromptTemplate(
562+
"Create a profile for a person named {name} who is {age} years old"
563+
),
564+
name="Alice",
565+
age=25,
566+
)
567+
568+
assert isinstance(result, Person)
569+
assert result.name == "Alice"
570+
assert result.age == 25
571+
572+
call_kwargs = llm._client.responses.parse.call_args
573+
assert call_kwargs.kwargs["text_format"] is Person
574+
assert call_kwargs.kwargs["tool_choice"] == "none"
575+
assert call_kwargs.kwargs["model"] == "gpt-4o-mini"
576+
577+
578+
def test_structured_predict_raises_on_none_output(default_responses_llm):
579+
"""Test that structured_predict raises ValueError when output_parsed is None."""
580+
581+
class Person(BaseModel):
582+
name: str = Field(description="The person's name")
583+
age: int = Field(description="The person's age")
584+
585+
llm = default_responses_llm
586+
mock_response = MagicMock()
587+
mock_response.output_parsed = None
588+
llm._client.responses.parse = MagicMock(return_value=mock_response)
589+
590+
with pytest.raises(ValueError, match="Failed to produce a structured response"):
591+
llm.structured_predict(
592+
output_cls=Person,
593+
prompt=PromptTemplate("Create a profile for a person"),
594+
)
595+
596+
597+
@pytest.mark.asyncio
598+
async def test_astructured_predict_uses_responses_parse(default_responses_llm):
599+
"""Test that astructured_predict uses async responses.parse with text_format."""
600+
from unittest.mock import AsyncMock
601+
602+
class Person(BaseModel):
603+
name: str = Field(description="The person's name")
604+
age: int = Field(description="The person's age")
605+
606+
llm = default_responses_llm
607+
mock_response = MagicMock()
608+
mock_response.output_parsed = Person(name="Bob", age=30)
609+
llm._aclient.responses.parse = AsyncMock(return_value=mock_response)
610+
611+
result = await llm.astructured_predict(
612+
output_cls=Person,
613+
prompt=PromptTemplate(
614+
"Create a profile for a person named {name} who is {age} years old"
615+
),
616+
name="Bob",
617+
age=30,
618+
)
619+
620+
assert isinstance(result, Person)
621+
assert result.name == "Bob"
622+
assert result.age == 30
623+
624+
call_kwargs = llm._aclient.responses.parse.call_args
625+
assert call_kwargs.kwargs["text_format"] is Person
626+
assert call_kwargs.kwargs["tool_choice"] == "none"
627+
assert call_kwargs.kwargs["model"] == "gpt-4o-mini"
628+
629+
630+
@pytest.mark.asyncio
631+
async def test_astructured_predict_raises_on_none_output(default_responses_llm):
632+
"""Test that astructured_predict raises ValueError when output_parsed is None."""
633+
from unittest.mock import AsyncMock
634+
635+
class Person(BaseModel):
636+
name: str = Field(description="The person's name")
637+
age: int = Field(description="The person's age")
638+
639+
llm = default_responses_llm
640+
mock_response = MagicMock()
641+
mock_response.output_parsed = None
642+
llm._aclient.responses.parse = AsyncMock(return_value=mock_response)
643+
644+
with pytest.raises(ValueError, match="Failed to produce a structured response"):
645+
await llm.astructured_predict(
646+
output_cls=Person,
647+
prompt=PromptTemplate("Create a profile for a person"),
648+
)
649+
650+
651+
def test_structured_predict_passes_llm_kwargs(default_responses_llm):
652+
"""Test that structured_predict forwards llm_kwargs to responses.parse."""
653+
654+
class Person(BaseModel):
655+
name: str = Field(description="The person's name")
656+
age: int = Field(description="The person's age")
657+
658+
llm = default_responses_llm
659+
mock_response = MagicMock()
660+
mock_response.output_parsed = Person(name="Alice", age=25)
661+
llm._client.responses.parse = MagicMock(return_value=mock_response)
662+
663+
llm.structured_predict(
664+
output_cls=Person,
665+
prompt=PromptTemplate("Create a profile for a person"),
666+
llm_kwargs={"temperature": 0.5},
667+
)
668+
669+
call_kwargs = llm._client.responses.parse.call_args
670+
assert call_kwargs.kwargs["temperature"] == 0.5
671+
672+
555673
@pytest.mark.skipif(SKIP_OPENAI_TESTS, reason="OpenAI API key not available")
556674
def test_structured_prediction_with_api():
557675
"""Test structured prediction with real API call."""

0 commit comments

Comments
 (0)