|
7 | 7 | from dataclasses import dataclass, field |
8 | 8 | from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload |
9 | 9 |
|
10 | | -from pydantic import TypeAdapter, ValidationError |
| 10 | +from pydantic import Json, TypeAdapter, ValidationError |
11 | 11 | from pydantic_core import SchemaValidator, to_json |
12 | 12 | from typing_extensions import Self, TypedDict, TypeVar, assert_never |
13 | 13 |
|
@@ -624,21 +624,33 @@ def __init__( |
624 | 624 | json_schema = self._function_schema.json_schema |
625 | 625 | json_schema['description'] = self._function_schema.description |
626 | 626 | else: |
627 | | - type_adapter: TypeAdapter[Any] |
| 627 | + json_schema_type_adapter: TypeAdapter[Any] |
| 628 | + validation_type_adapter: TypeAdapter[Any] |
628 | 629 | if _utils.is_model_like(output): |
629 | | - type_adapter = TypeAdapter(output) |
| 630 | + json_schema_type_adapter = validation_type_adapter = TypeAdapter(output) |
630 | 631 | else: |
631 | 632 | self.outer_typed_dict_key = 'response' |
| 633 | + output_type: type[OutputDataT] = cast(type[OutputDataT], output) |
| 634 | + |
632 | 635 | response_data_typed_dict = TypedDict( # noqa: UP013 |
633 | 636 | 'response_data_typed_dict', |
634 | | - {'response': cast(type[OutputDataT], output)}, # pyright: ignore[reportInvalidTypeForm] |
| 637 | + {'response': output_type}, # pyright: ignore[reportInvalidTypeForm] |
| 638 | + ) |
| 639 | + json_schema_type_adapter = TypeAdapter(response_data_typed_dict) |
| 640 | + |
| 641 | + # More lenient validator: allow either the native type or a JSON string containing it |
| 642 | + # i.e. `response: OutputDataT | Json[OutputDataT]`, as some models don't follow the schema correctly, |
| 643 | + # e.g. `BedrockConverseModel('us.meta.llama3-2-11b-instruct-v1:0')` |
| 644 | + response_validation_typed_dict = TypedDict( # noqa: UP013 |
| 645 | + 'response_validation_typed_dict', |
| 646 | + {'response': output_type | Json[output_type]}, # pyright: ignore[reportInvalidTypeForm] |
635 | 647 | ) |
636 | | - type_adapter = TypeAdapter(response_data_typed_dict) |
| 648 | + validation_type_adapter = TypeAdapter(response_validation_typed_dict) |
637 | 649 |
|
638 | 650 | # Really a PluggableSchemaValidator, but it's API-compatible |
639 | | - self.validator = cast(SchemaValidator, type_adapter.validator) |
| 651 | + self.validator = cast(SchemaValidator, validation_type_adapter.validator) |
640 | 652 | json_schema = _utils.check_object_json_schema( |
641 | | - type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) |
| 653 | + json_schema_type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) |
642 | 654 | ) |
643 | 655 |
|
644 | 656 | if self.outer_typed_dict_key: |
|
0 commit comments