Skip to content

Commit 9f8c26d

Browse files
authored
Improve ChatAdapter, introduce JsonAdapter, add default retries with the latter. (#1700)
* Improve ChatAdapter's handling of typed values and Pydantic models * Fixes for Literal * Fixes for formatting complex-typed values * Improve ChatAdapter, introduce JsonAdapter, add default retries with the latter. * Minor fixes * Update lock file * Updates for json retries
1 parent 16ba98a commit 9f8c26d

File tree

12 files changed

+1998
-1421
lines changed

12 files changed

+1998
-1421
lines changed

dspy/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,5 @@
7373

7474
# TODO: Consider if this should access settings.lm *or* a list that's shared across all LMs in the program.
7575
def inspect_history(*args, **kwargs):
76-
return settings.lm.inspect_history(*args, **kwargs)
76+
from dspy.clients.lm import GLOBAL_HISTORY, _inspect_history
77+
return _inspect_history(GLOBAL_HISTORY, *args, **kwargs)

dspy/adapters/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .base import Adapter
2-
from .chat_adapter import ChatAdapter
2+
from .chat_adapter import ChatAdapter
3+
from .json_adapter import JsonAdapter

dspy/adapters/base.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,22 @@ def __init_subclass__(cls, **kwargs) -> None:
1313
cls.parse = with_callbacks(cls.parse)
1414

1515
def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
16-
inputs = self.format(signature, demos, inputs)
17-
inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs)
16+
inputs_ = self.format(signature, demos, inputs)
17+
inputs_ = dict(prompt=inputs_) if isinstance(inputs_, str) else dict(messages=inputs_)
1818

19-
outputs = lm(**inputs, **lm_kwargs)
19+
outputs = lm(**inputs_, **lm_kwargs)
2020
values = []
2121

22-
for output in outputs:
23-
value = self.parse(signature, output, _parse_values=_parse_values)
24-
assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
25-
values.append(value)
26-
27-
return values
22+
try:
23+
for output in outputs:
24+
value = self.parse(signature, output, _parse_values=_parse_values)
25+
assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
26+
values.append(value)
27+
return values
28+
29+
except Exception as e:
30+
from .json_adapter import JsonAdapter
31+
if _parse_values and not isinstance(self, JsonAdapter):
32+
return JsonAdapter()(lm, lm_kwargs, signature, demos, inputs, _parse_values=_parse_values)
33+
raise e
34+

dspy/adapters/chat_adapter.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import textwrap
88

99
from pydantic import TypeAdapter
10+
from collections.abc import Mapping
1011
from pydantic.fields import FieldInfo
1112
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin
1213

@@ -269,6 +270,19 @@ def enumerate_fields(fields):
269270
return "\n".join(parts).strip()
270271

271272

273+
def move_type_to_front(d):
274+
# Move the 'type' key to the front of the dictionary, recursively, for LLM readability/adherence.
275+
if isinstance(d, Mapping):
276+
return {k: move_type_to_front(v) for k, v in sorted(d.items(), key=lambda item: (item[0] != 'type', item[0]))}
277+
elif isinstance(d, list):
278+
return [move_type_to_front(item) for item in d]
279+
return d
280+
281+
def prepare_schema(type_):
282+
schema = pydantic.TypeAdapter(type_).json_schema()
283+
schema = move_type_to_front(schema)
284+
return schema
285+
272286
def prepare_instructions(signature: SignatureMeta):
273287
parts = []
274288
parts.append("Your input fields are:\n" + enumerate_fields(signature.input_fields))
@@ -290,7 +304,7 @@ def field_metadata(field_name, field_info):
290304
desc = f"must be one of: {'; '.join([str(x) for x in type_.__args__])}"
291305
else:
292306
desc = "must be pareseable according to the following JSON schema: "
293-
desc += json.dumps(pydantic.TypeAdapter(type_).json_schema())
307+
desc += json.dumps(prepare_schema(type_))
294308

295309
desc = (" " * 8) + f"# note: the value you produce {desc}" if desc else ""
296310
return f"{{{field_name}}}{desc}"

0 commit comments

Comments
 (0)