Skip to content

Commit ff2fa71

Browse files
simple formattting (#7979)
1 parent 5cc3dcb commit ff2fa71

File tree

2 files changed

+55
-16
lines changed

2 files changed

+55
-16
lines changed

dspy/adapters/chat_adapter.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,14 @@ class ChatAdapter(Adapter):
3737
def __init__(self, callbacks: Optional[list[BaseCallback]] = None):
3838
super().__init__(callbacks)
3939

40-
def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
40+
def __call__(
41+
self,
42+
lm: LM,
43+
lm_kwargs: dict[str, Any],
44+
signature: Type[Signature],
45+
demos: list[dict[str, Any]],
46+
inputs: dict[str, Any],
47+
) -> list[dict[str, Any]]:
4148
try:
4249
return super().__call__(lm, lm_kwargs, signature, demos, inputs)
4350
except Exception as e:
@@ -46,8 +53,10 @@ def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature]
4653
raise e
4754
# fallback to JSONAdapter
4855
return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs)
49-
50-
def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
56+
57+
def format(
58+
self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]
59+
) -> list[dict[str, Any]]:
5160
messages: list[dict[str, Any]] = []
5261

5362
# Extract demos where some of the output_fields are not filled in.
@@ -88,7 +97,7 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
8897
if match:
8998
# If the header pattern is found, split the rest of the line as content
9099
header = match.group(1)
91-
remaining_content = line[match.end():].strip()
100+
remaining_content = line[match.end() :].strip()
92101
sections.append((header, [remaining_content] if remaining_content else []))
93102
else:
94103
sections[-1][1].append(line)
@@ -111,7 +120,9 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
111120
return fields
112121

113122
# TODO(PR): Looks ok?
114-
def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]) -> dict[str, list[Any]]:
123+
def format_finetune_data(
124+
self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]
125+
) -> dict[str, list[Any]]:
115126
# Get system + user messages
116127
messages = self.format(signature, demos, inputs)
117128

@@ -134,7 +145,14 @@ def format_fields(self, signature: Type[Signature], values: dict[str, Any], role
134145
}
135146
return format_fields(fields_with_values)
136147

137-
def format_turn(self, signature: Type[Signature], values: dict[str, Any], role: str, incomplete: bool = False, is_conversation_history: bool = False) -> dict[str, Any]:
148+
def format_turn(
149+
self,
150+
signature: Type[Signature],
151+
values: dict[str, Any],
152+
role: str,
153+
incomplete: bool = False,
154+
is_conversation_history: bool = False,
155+
) -> dict[str, Any]:
138156
return format_turn(signature, values, role, incomplete, is_conversation_history)
139157

140158

@@ -158,7 +176,9 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
158176
return "\n\n".join(output).strip()
159177

160178

161-
def format_turn(signature: Type[Signature], values: dict[str, Any], role: str, incomplete=False, is_conversation_history=False):
179+
def format_turn(
180+
signature: Type[Signature], values: dict[str, Any], role: str, incomplete=False, is_conversation_history=False
181+
):
162182
"""
163183
Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted
164184
so that it can instruct an LLM to generate responses conforming to the specified DSPy signature.

dspy/adapters/json_adapter.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
from pydantic.fields import FieldInfo
1414

1515
from dspy.adapters.base import Adapter
16-
from dspy.adapters.types.image import try_expand_image_tags
1716
from dspy.adapters.types.history import History
17+
from dspy.adapters.types.image import try_expand_image_tags
1818
from dspy.adapters.utils import format_field_value, get_annotation_name, parse_value, serialize_for_json
1919
from dspy.clients.lm import LM
20-
from dspy.signatures.signature import SignatureMeta, Signature
20+
from dspy.signatures.signature import Signature, SignatureMeta
2121
from dspy.signatures.utils import get_dspy_field_type
2222

2323
logger = logging.getLogger(__name__)
@@ -27,11 +27,19 @@ class FieldInfoWithName(NamedTuple):
2727
name: str
2828
info: FieldInfo
2929

30+
3031
class JSONAdapter(Adapter):
3132
def __init__(self):
3233
pass
3334

34-
def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
35+
def __call__(
36+
self,
37+
lm: LM,
38+
lm_kwargs: dict[str, Any],
39+
signature: Type[Signature],
40+
demos: list[dict[str, Any]],
41+
inputs: dict[str, Any],
42+
) -> list[dict[str, Any]]:
3543
inputs = self.format(signature, demos, inputs)
3644
inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs)
3745

@@ -66,7 +74,9 @@ def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature]
6674

6775
return values
6876

69-
def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
77+
def format(
78+
self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]
79+
) -> list[dict[str, Any]]:
7080
messages = []
7181

7282
# Extract demos where some of the output_fields are not filled in.
@@ -118,11 +128,20 @@ def format_fields(self, signature: Type[Signature], values: dict[str, Any], role
118128
if field_name in values
119129
}
120130
return format_fields(role=role, fields_with_values=fields_with_values)
121-
122-
def format_turn(self, signature: Type[Signature], values, role: str, incomplete: bool = False, is_conversation_history: bool = False) -> dict[str, Any]:
131+
132+
def format_turn(
133+
self,
134+
signature: Type[Signature],
135+
values,
136+
role: str,
137+
incomplete: bool = False,
138+
is_conversation_history: bool = False,
139+
) -> dict[str, Any]:
123140
return format_turn(signature, values, role, incomplete, is_conversation_history)
124-
125-
def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]) -> dict[str, list[Any]]:
141+
142+
def format_finetune_data(
143+
self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]
144+
) -> dict[str, list[Any]]:
126145
# TODO: implement format_finetune_data method in JSONAdapter
127146
raise NotImplementedError
128147

@@ -136,7 +155,7 @@ def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -
136155
Args:
137156
role: The role of the message ('user' or 'assistant')
138157
fields_with_values: A dictionary mapping information about a field to its corresponding value.
139-
158+
140159
Returns:
141160
The joined formatted values of the fields, represented as a string.
142161
"""

0 commit comments

Comments
 (0)