Skip to content

Commit 8082c19

Browse files
ccurmeshinxi
authored andcommitted
tests[patch]: improve coverage of structured output tests (#29478)
1 parent 0445156 commit 8082c19

File tree

1 file changed

+42
-42
lines changed

1 file changed

+42
-42
lines changed

libs/standard-tests/langchain_tests/integration_tests/chat_models.py

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import base64
22
import json
3-
from typing import List, Optional, cast
3+
from typing import Any, List, Literal, Optional, cast
44

55
import httpx
66
import pytest
@@ -29,7 +29,9 @@
2929
from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
3030

3131

32-
def _get_joke_class() -> type[BaseModel]:
32+
def _get_joke_class(
33+
schema_type: Literal["pydantic", "typeddict", "json_schema"],
34+
) -> Any:
3335
"""
3436
:private:
3537
"""
@@ -40,7 +42,28 @@ class Joke(BaseModel):
4042
setup: str = Field(description="question to set up a joke")
4143
punchline: str = Field(description="answer to resolve the joke")
4244

43-
return Joke
45+
def validate_joke(result: Any) -> bool:
46+
return isinstance(result, Joke)
47+
48+
class JokeDict(TypedDict):
49+
"""Joke to tell user."""
50+
51+
setup: Annotated[str, ..., "question to set up a joke"]
52+
punchline: Annotated[str, ..., "answer to resolve the joke"]
53+
54+
def validate_joke_dict(result: Any) -> bool:
55+
return all(key in ["setup", "punchline"] for key in result.keys())
56+
57+
if schema_type == "pydantic":
58+
return Joke, validate_joke
59+
60+
elif schema_type == "typeddict":
61+
return JokeDict, validate_joke_dict
62+
63+
elif schema_type == "json_schema":
64+
return Joke.model_json_schema(), validate_joke_dict
65+
else:
66+
raise ValueError("Invalid schema type")
4467

4568

4669
class _MagicFunctionSchema(BaseModel):
@@ -1151,7 +1174,8 @@ def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None:
11511174
assert tool_call["args"].get("answer_style")
11521175
assert tool_call["type"] == "tool_call"
11531176

1154-
def test_structured_output(self, model: BaseChatModel) -> None:
1177+
@pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
1178+
def test_structured_output(self, model: BaseChatModel, schema_type: str) -> None:
11551179
"""Test to verify structured output is generated both on invoke and stream.
11561180
11571181
This test is optional and should be skipped if the model does not support
@@ -1181,29 +1205,19 @@ def has_tool_calling(self) -> bool:
11811205
if not self.has_tool_calling:
11821206
pytest.skip("Test requires tool calling.")
11831207

1184-
Joke = _get_joke_class()
1185-
# Pydantic class
1186-
chat = model.with_structured_output(Joke, **self.structured_output_kwargs)
1208+
schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type]
1209+
chat = model.with_structured_output(schema, **self.structured_output_kwargs)
11871210
result = chat.invoke("Tell me a joke about cats.")
1188-
assert isinstance(result, Joke)
1211+
validation_function(result)
11891212

11901213
for chunk in chat.stream("Tell me a joke about cats."):
1191-
assert isinstance(chunk, Joke)
1214+
validation_function(chunk)
1215+
assert chunk
11921216

1193-
# Schema
1194-
chat = model.with_structured_output(
1195-
Joke.model_json_schema(), **self.structured_output_kwargs
1196-
)
1197-
result = chat.invoke("Tell me a joke about cats.")
1198-
assert isinstance(result, dict)
1199-
assert set(result.keys()) == {"setup", "punchline"}
1200-
1201-
for chunk in chat.stream("Tell me a joke about cats."):
1202-
assert isinstance(chunk, dict)
1203-
assert isinstance(chunk, dict) # for mypy
1204-
assert set(chunk.keys()) == {"setup", "punchline"}
1205-
1206-
async def test_structured_output_async(self, model: BaseChatModel) -> None:
1217+
@pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
1218+
async def test_structured_output_async(
1219+
self, model: BaseChatModel, schema_type: str
1220+
) -> None:
12071221
"""Test to verify structured output is generated both on invoke and stream.
12081222
12091223
This test is optional and should be skipped if the model does not support
@@ -1233,28 +1247,14 @@ def has_tool_calling(self) -> bool:
12331247
if not self.has_tool_calling:
12341248
pytest.skip("Test requires tool calling.")
12351249

1236-
Joke = _get_joke_class()
1237-
1238-
# Pydantic class
1239-
chat = model.with_structured_output(Joke, **self.structured_output_kwargs)
1250+
schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type]
1251+
chat = model.with_structured_output(schema, **self.structured_output_kwargs)
12401252
result = await chat.ainvoke("Tell me a joke about cats.")
1241-
assert isinstance(result, Joke)
1253+
validation_function(result)
12421254

12431255
async for chunk in chat.astream("Tell me a joke about cats."):
1244-
assert isinstance(chunk, Joke)
1245-
1246-
# Schema
1247-
chat = model.with_structured_output(
1248-
Joke.model_json_schema(), **self.structured_output_kwargs
1249-
)
1250-
result = await chat.ainvoke("Tell me a joke about cats.")
1251-
assert isinstance(result, dict)
1252-
assert set(result.keys()) == {"setup", "punchline"}
1253-
1254-
async for chunk in chat.astream("Tell me a joke about cats."):
1255-
assert isinstance(chunk, dict)
1256-
assert isinstance(chunk, dict) # for mypy
1257-
assert set(chunk.keys()) == {"setup", "punchline"}
1256+
validation_function(chunk)
1257+
assert chunk
12581258

12591259
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.")
12601260
def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:

0 commit comments

Comments
 (0)