Skip to content

Commit 2f3834a

Browse files
fix lint for chat adapter (#7834)
1 parent 9cf2d40 commit 2f3834a

File tree

1 file changed

+26
-20
lines changed

1 file changed

+26
-20
lines changed

dspy/adapters/chat_adapter.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,17 @@
55
import textwrap
66
from collections.abc import Mapping
77
from itertools import chain
8-
98
from typing import Any, Dict, Literal, NamedTuple
109

1110
import pydantic
1211
from pydantic.fields import FieldInfo
1312

1413
from dspy.adapters.base import Adapter
15-
from dspy.adapters.utils import parse_value, format_field_value, get_annotation_name
14+
from dspy.adapters.image_utils import try_expand_image_tags
15+
from dspy.adapters.utils import format_field_value, get_annotation_name, parse_value
1616
from dspy.signatures.field import OutputField
1717
from dspy.signatures.signature import Signature, SignatureMeta
1818
from dspy.signatures.utils import get_dspy_field_type
19-
from dspy.adapters.image_utils import try_expand_image_tags
2019

2120
field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")
2221

@@ -99,9 +98,6 @@ def format_finetune_data(self, signature, demos, inputs, outputs):
9998
# Wrap the messages in a dictionary with a "messages" key
10099
return dict(messages=messages)
101100

102-
def format_turn(self, signature, values, role, incomplete=False):
103-
return format_turn(signature, values, role, incomplete)
104-
105101
def format_fields(self, signature, values, role):
106102
fields_with_values = {
107103
FieldInfoWithName(name=field_name, info=field_info): values.get(
@@ -152,7 +148,9 @@ def format_turn(signature, values, role, incomplete=False):
152148
"""
153149
if role == "user":
154150
fields = signature.input_fields
155-
message_prefix = "This is an example of the task, though some input or output fields are not supplied." if incomplete else ""
151+
message_prefix = (
152+
"This is an example of the task, though some input or output fields are not supplied." if incomplete else ""
153+
)
156154
else:
157155
# Add the completed field for the assistant turn
158156
fields = {**signature.output_fields, BuiltInCompletedOutputFieldInfo.name: BuiltInCompletedOutputFieldInfo.info}
@@ -167,31 +165,40 @@ def format_turn(signature, values, role, incomplete=False):
167165
messages.append(message_prefix)
168166

169167
field_messages = format_fields(
170-
{FieldInfoWithName(name=k, info=v): values.get(k, "Not supplied for this particular example.")
171-
for k, v in fields.items()},
168+
{
169+
FieldInfoWithName(name=k, info=v): values.get(k, "Not supplied for this particular example.")
170+
for k, v in fields.items()
171+
},
172172
)
173173
messages.append(field_messages)
174174

175+
def type_info(v):
176+
if v.annotation is not str:
177+
return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})"
178+
else:
179+
return ""
180+
175181
# Add output field instructions for user messages
176182
if role == "user" and signature.output_fields:
177-
type_info = lambda v: f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" if v.annotation is not str else ""
178-
field_instructions = "Respond with the corresponding output fields, starting with the field " + \
179-
", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items()) + \
180-
", and then ending with the marker for `[[ ## completed ## ]]`."
183+
field_instructions = (
184+
"Respond with the corresponding output fields, starting with the field "
185+
+ ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
186+
+ ", and then ending with the marker for `[[ ## completed ## ]]`."
187+
)
181188
messages.append(field_instructions)
182189
joined_messages = "\n\n".join(msg for msg in messages)
183190
return {"role": role, "content": joined_messages}
184191

192+
185193
def flatten_messages(messages):
186194
"""Flatten nested message lists."""
187-
return list(chain.from_iterable(
188-
item if isinstance(item, list) else [item] for item in messages
189-
))
195+
return list(chain.from_iterable(item if isinstance(item, list) else [item] for item in messages))
196+
190197

191198
def enumerate_fields(fields: dict) -> str:
192199
parts = []
193200
for idx, (k, v) in enumerate(fields.items()):
194-
parts.append(f"{idx+1}. `{k}`")
201+
parts.append(f"{idx + 1}. `{k}`")
195202
parts[-1] += f" ({get_annotation_name(v.annotation)})"
196203
parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else ""
197204

@@ -207,8 +214,8 @@ def move_type_to_front(d):
207214
return d
208215

209216

210-
def prepare_schema(type_):
211-
schema = pydantic.TypeAdapter(type_).json_schema()
217+
def prepare_schema(field_type):
218+
schema = pydantic.TypeAdapter(field_type).json_schema()
212219
schema = move_type_to_front(schema)
213220
return schema
214221

@@ -237,7 +244,6 @@ def field_metadata(field_name, field_info):
237244
f"must exactly match (no extra characters) one of: {'; '.join([str(x) for x in field_type.__args__])}"
238245
)
239246
else:
240-
# desc = "must be pareseable according to the following JSON schema: "
241247
desc = "must adhere to the JSON schema: "
242248
desc += json.dumps(prepare_schema(field_type), ensure_ascii=False)
243249

0 commit comments

Comments
 (0)