11import base64
22import json
3- from typing import List , Optional , cast
3+ from typing import Any , List , Literal , Optional , cast
44
55import httpx
66import pytest
2929from langchain_tests .utils .pydantic import PYDANTIC_MAJOR_VERSION
3030
3131
32- def _get_joke_class () -> type [BaseModel ]:
32+ def _get_joke_class (
33+ schema_type : Literal ["pydantic" , "typeddict" , "json_schema" ],
34+ ) -> Any :
3335 """
3436 :private:
3537 """
@@ -40,7 +42,28 @@ class Joke(BaseModel):
4042 setup : str = Field (description = "question to set up a joke" )
4143 punchline : str = Field (description = "answer to resolve the joke" )
4244
43- return Joke
45+ def validate_joke (result : Any ) -> bool :
46+ return isinstance (result , Joke )
47+
48+ class JokeDict (TypedDict ):
49+ """Joke to tell user."""
50+
51+ setup : Annotated [str , ..., "question to set up a joke" ]
52+ punchline : Annotated [str , ..., "answer to resolve the joke" ]
53+
54+ def validate_joke_dict (result : Any ) -> bool :
55+ return all (key in ["setup" , "punchline" ] for key in result .keys ())
56+
57+ if schema_type == "pydantic" :
58+ return Joke , validate_joke
59+
60+ elif schema_type == "typeddict" :
61+ return JokeDict , validate_joke_dict
62+
63+ elif schema_type == "json_schema" :
64+ return Joke .model_json_schema (), validate_joke_dict
65+ else :
66+ raise ValueError ("Invalid schema type" )
4467
4568
4669class _MagicFunctionSchema (BaseModel ):
@@ -1151,7 +1174,8 @@ def test_bind_runnables_as_tools(self, model: BaseChatModel) -> None:
11511174 assert tool_call ["args" ].get ("answer_style" )
11521175 assert tool_call ["type" ] == "tool_call"
11531176
1154- def test_structured_output (self , model : BaseChatModel ) -> None :
1177+ @pytest .mark .parametrize ("schema_type" , ["pydantic" , "typeddict" , "json_schema" ])
1178+ def test_structured_output (self , model : BaseChatModel , schema_type : str ) -> None :
11551179 """Test to verify structured output is generated both on invoke and stream.
11561180
11571181 This test is optional and should be skipped if the model does not support
@@ -1181,29 +1205,19 @@ def has_tool_calling(self) -> bool:
11811205 if not self .has_tool_calling :
11821206 pytest .skip ("Test requires tool calling." )
11831207
1184- Joke = _get_joke_class ()
1185- # Pydantic class
1186- chat = model .with_structured_output (Joke , ** self .structured_output_kwargs )
1208+ schema , validation_function = _get_joke_class (schema_type ) # type: ignore[arg-type]
1209+ chat = model .with_structured_output (schema , ** self .structured_output_kwargs )
11871210 result = chat .invoke ("Tell me a joke about cats." )
1188- assert isinstance (result , Joke )
1211+ validation_function (result )
11891212
11901213 for chunk in chat .stream ("Tell me a joke about cats." ):
1191- assert isinstance (chunk , Joke )
1214+ validation_function (chunk )
1215+ assert chunk
11921216
1193- # Schema
1194- chat = model .with_structured_output (
1195- Joke .model_json_schema (), ** self .structured_output_kwargs
1196- )
1197- result = chat .invoke ("Tell me a joke about cats." )
1198- assert isinstance (result , dict )
1199- assert set (result .keys ()) == {"setup" , "punchline" }
1200-
1201- for chunk in chat .stream ("Tell me a joke about cats." ):
1202- assert isinstance (chunk , dict )
1203- assert isinstance (chunk , dict ) # for mypy
1204- assert set (chunk .keys ()) == {"setup" , "punchline" }
1205-
1206- async def test_structured_output_async (self , model : BaseChatModel ) -> None :
1217+ @pytest .mark .parametrize ("schema_type" , ["pydantic" , "typeddict" , "json_schema" ])
1218+ async def test_structured_output_async (
1219+ self , model : BaseChatModel , schema_type : str
1220+ ) -> None :
12071221 """Test to verify structured output is generated both on invoke and stream.
12081222
12091223 This test is optional and should be skipped if the model does not support
@@ -1233,28 +1247,14 @@ def has_tool_calling(self) -> bool:
12331247 if not self .has_tool_calling :
12341248 pytest .skip ("Test requires tool calling." )
12351249
1236- Joke = _get_joke_class ()
1237-
1238- # Pydantic class
1239- chat = model .with_structured_output (Joke , ** self .structured_output_kwargs )
1250+ schema , validation_function = _get_joke_class (schema_type ) # type: ignore[arg-type]
1251+ chat = model .with_structured_output (schema , ** self .structured_output_kwargs )
12401252 result = await chat .ainvoke ("Tell me a joke about cats." )
1241- assert isinstance (result , Joke )
1253+ validation_function (result )
12421254
12431255 async for chunk in chat .astream ("Tell me a joke about cats." ):
1244- assert isinstance (chunk , Joke )
1245-
1246- # Schema
1247- chat = model .with_structured_output (
1248- Joke .model_json_schema (), ** self .structured_output_kwargs
1249- )
1250- result = await chat .ainvoke ("Tell me a joke about cats." )
1251- assert isinstance (result , dict )
1252- assert set (result .keys ()) == {"setup" , "punchline" }
1253-
1254- async for chunk in chat .astream ("Tell me a joke about cats." ):
1255- assert isinstance (chunk , dict )
1256- assert isinstance (chunk , dict ) # for mypy
1257- assert set (chunk .keys ()) == {"setup" , "punchline" }
1256+ validation_function (chunk )
1257+ assert chunk
12581258
12591259 @pytest .mark .skipif (PYDANTIC_MAJOR_VERSION != 2 , reason = "Test requires pydantic 2." )
12601260 def test_structured_output_pydantic_2_v1 (self , model : BaseChatModel ) -> None :
0 commit comments