Skip to content

Commit cc368f8

Browse files
authored
Merge pull request #1595 from mikeedjones/feat/remove-get_convo-asserts
Remove get_convo asserts from tests
2 parents d403737 + ead37b6 commit cc368f8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+2863
-734
lines changed

dsp/modules/dummy_lm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
# This testing module was moved in PR #735 to patch Arize Phoenix logging
8-
class DummyLM(LM):
8+
class DSPDummyLM(LM):
99
"""Dummy language model for unit testing purposes."""
1010

1111
def __init__(self, answers: Union[list[str], dict[str, str]], follow_examples: bool = False):
@@ -61,7 +61,7 @@ def basic_request(self, prompt, n=1, **kwargs) -> dict[str, list[dict[str, str]]
6161
},
6262
)
6363

64-
RED, GREEN, RESET = "\033[91m", "\033[92m", "\033[0m"
64+
RED, _, RESET = "\033[91m", "\033[92m", "\033[0m"
6565
print("=== DummyLM ===")
6666
print(prompt, end="")
6767
print(f"{RED}{answer}{RESET}")

dsp/utils/settings.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,30 @@
11
import threading
2+
from copy import deepcopy
23
from contextlib import contextmanager
34

45
from dsp.utils.utils import dotdict
56

7+
DEFAULT_CONFIG = dotdict(
8+
lm=None,
9+
adapter=None,
10+
rm=None,
11+
branch_idx=0,
12+
reranker=None,
13+
compiled_lm=None,
14+
force_reuse_cached_compilation=False,
15+
compiling=False,
16+
skip_logprobs=False,
17+
trace=[],
18+
release=0,
19+
bypass_assert=False,
20+
bypass_suggest=False,
21+
assert_failures=0,
22+
suggest_failures=0,
23+
langchain_history=[],
24+
experimental=False,
25+
backoff_time=10,
26+
)
27+
628

729
class Settings:
830
"""DSP configuration settings."""
@@ -25,27 +47,9 @@ def __new__(cls):
2547
# TODO: remove first-class support for re-ranker and potentially combine with RM to form a pipeline of sorts
2648
# eg: RetrieveThenRerankPipeline(RetrievalModel, Reranker)
2749
# downstream operations like dsp.retrieve would use configs from the defined pipeline.
28-
config = dotdict(
29-
lm=None,
30-
adapter=None,
31-
rm=None,
32-
branch_idx=0,
33-
reranker=None,
34-
compiled_lm=None,
35-
force_reuse_cached_compilation=False,
36-
compiling=False, # TODO: can probably be removed
37-
skip_logprobs=False,
38-
trace=[],
39-
release=0,
40-
bypass_assert=False,
41-
bypass_suggest=False,
42-
assert_failures=0,
43-
suggest_failures=0,
44-
langchain_history=[],
45-
experimental=False,
46-
backoff_time = 10
47-
)
48-
cls._instance.__append(config)
50+
51+
# make a deepcopy of the default config to avoid modifying the default config
52+
cls._instance.__append(deepcopy(DEFAULT_CONFIG))
4953

5054
return cls._instance
5155

dspy/adapters/chat_adapter.py

Lines changed: 59 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
import re
21
import ast
32
import json
3+
import re
44
import textwrap
5+
from typing import get_args, get_origin
56

6-
from pydantic import TypeAdapter
77
import pydantic
8+
from pydantic import TypeAdapter
9+
810
from .base import Adapter
9-
from typing import get_origin, get_args
1011

11-
field_header_pattern = re.compile(r'\[\[ ## (\w+) ## \]\]')
12+
field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")
1213

1314

1415
class ChatAdapter(Adapter):
@@ -21,9 +22,11 @@ def format(self, signature, demos, inputs):
2122
# Extract demos where some of the output_fields are not filled in.
2223
incomplete_demos = [demo for demo in demos if not all(k in demo for k in signature.fields)]
2324
complete_demos = [demo for demo in demos if demo not in incomplete_demos]
24-
incomplete_demos = [demo for demo in incomplete_demos \
25-
if any(k in demo for k in signature.input_fields) and \
26-
any(k in demo for k in signature.output_fields)]
25+
incomplete_demos = [
26+
demo
27+
for demo in incomplete_demos
28+
if any(k in demo for k in signature.input_fields) and any(k in demo for k in signature.output_fields)
29+
]
2730

2831
demos = incomplete_demos + complete_demos
2932

@@ -32,44 +35,52 @@ def format(self, signature, demos, inputs):
3235
for demo in demos:
3336
messages.append(format_turn(signature, demo, role="user", incomplete=demo in incomplete_demos))
3437
messages.append(format_turn(signature, demo, role="assistant", incomplete=demo in incomplete_demos))
35-
38+
3639
messages.append(format_turn(signature, inputs, role="user"))
3740

3841
return messages
39-
42+
4043
def parse(self, signature, completion, _parse_values=True):
4144
sections = [(None, [])]
4245

4346
for line in completion.splitlines():
4447
match = field_header_pattern.match(line.strip())
45-
if match: sections.append((match.group(1), []))
46-
else: sections[-1][1].append(line)
48+
if match:
49+
sections.append((match.group(1), []))
50+
else:
51+
sections[-1][1].append(line)
4752

48-
sections = [(k, '\n'.join(v).strip()) for k, v in sections]
53+
sections = [(k, "\n".join(v).strip()) for k, v in sections]
4954

5055
fields = {}
5156
for k, v in sections:
5257
if (k not in fields) and (k in signature.output_fields):
5358
try:
5459
fields[k] = parse_value(v, signature.output_fields[k].annotation) if _parse_values else v
5560
except Exception as e:
56-
raise ValueError(f"Error parsing field {k}: {e}.\n\n\t\tOn attempting to parse the value\n```\n{v}\n```")
61+
raise ValueError(
62+
f"Error parsing field {k}: {e}.\n\n\t\tOn attempting to parse the value\n```\n{v}\n```"
63+
)
5764

5865
if fields.keys() != signature.output_fields.keys():
5966
raise ValueError(f"Expected {signature.output_fields.keys()} but got {fields.keys()}")
6067

6168
return fields
6269

70+
6371
def format_blob(blob):
64-
if '\n' not in blob and "«" not in blob and "»" not in blob: return f"«{blob}»"
72+
if "\n" not in blob and "«" not in blob and "»" not in blob:
73+
return f"«{blob}»"
6574

66-
modified_blob = blob.replace('\n', '\n ')
75+
modified_blob = blob.replace("\n", "\n ")
6776
return f"«««\n {modified_blob}\n»»»"
6877

6978

7079
def format_list(items):
71-
if len(items) == 0: return "N/A"
72-
if len(items) == 1: return format_blob(items[0])
80+
if len(items) == 0:
81+
return "N/A"
82+
if len(items) == 1:
83+
return format_blob(items[0])
7384

7485
return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(items)])
7586

@@ -89,82 +100,90 @@ def format_fields(fields):
89100
v = _format_field_value(v)
90101
output.append(f"[[ ## {k} ## ]]\n{v}")
91102

92-
return '\n\n'.join(output).strip()
93-
103+
return "\n\n".join(output).strip()
104+
94105

95106
def parse_value(value, annotation):
96-
if annotation is str: return str(value)
107+
if annotation is str:
108+
return str(value)
97109
parsed_value = value
98110
if isinstance(value, str):
99-
try: parsed_value = json.loads(value)
111+
try:
112+
parsed_value = json.loads(value)
100113
except json.JSONDecodeError:
101-
try: parsed_value = ast.literal_eval(value)
102-
except (ValueError, SyntaxError): parsed_value = value
114+
try:
115+
parsed_value = ast.literal_eval(value)
116+
except (ValueError, SyntaxError):
117+
parsed_value = value
103118
return TypeAdapter(annotation).validate_python(parsed_value)
104119

105120

106-
def format_turn(signature, values, role, incomplete=False):
121+
def format_turn(signature, values, role, incomplete=False):
107122
content = []
108123

109124
if role == "user":
110125
field_names = signature.input_fields.keys()
111126
if incomplete:
112127
content.append("This is an example of the task, though some input or output fields are not supplied.")
113128
else:
114-
field_names, values = list(signature.output_fields.keys()) + ['completed'], {**values, 'completed': ''}
129+
field_names, values = list(signature.output_fields.keys()) + ["completed"], {**values, "completed": ""}
115130

116131
if not incomplete:
117132
if not set(values).issuperset(set(field_names)):
118133
raise ValueError(f"Expected {field_names} but got {values.keys()}")
119-
134+
120135
content.append(format_fields({k: values.get(k, "Not supplied for this particular example.") for k in field_names}))
121136

122137
if role == "user":
123-
content.append("Respond with the corresponding output fields, starting with the field " +
124-
", then ".join(f"`{f}`" for f in signature.output_fields) +
125-
", and then ending with the marker for `completed`.")
138+
content.append(
139+
"Respond with the corresponding output fields, starting with the field "
140+
+ ", then ".join(f"`{f}`" for f in signature.output_fields)
141+
+ ", and then ending with the marker for `completed`."
142+
)
126143

127-
return {"role": role, "content": '\n\n'.join(content).strip()}
144+
return {"role": role, "content": "\n\n".join(content).strip()}
128145

129146

130147
def get_annotation_name(annotation):
131148
origin = get_origin(annotation)
132149
args = get_args(annotation)
133150
if origin is None:
134-
if hasattr(annotation, '__name__'):
151+
if hasattr(annotation, "__name__"):
135152
return annotation.__name__
136153
else:
137154
return str(annotation)
138155
else:
139-
args_str = ', '.join(get_annotation_name(arg) for arg in args)
140-
return f"{origin.__name__}[{args_str}]"
156+
args_str = ", ".join(get_annotation_name(arg) for arg in args)
157+
return f"{get_annotation_name(origin)}[{args_str}]"
158+
141159

142160
def enumerate_fields(fields):
143161
parts = []
144162
for idx, (k, v) in enumerate(fields.items()):
145163
parts.append(f"{idx+1}. `{k}`")
146164
parts[-1] += f" ({get_annotation_name(v.annotation)})"
147-
parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra['desc'] != f'${{{k}}}' else ''
165+
parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else ""
166+
167+
return "\n".join(parts).strip()
148168

149-
return '\n'.join(parts).strip()
150169

151170
def prepare_instructions(signature):
152171
parts = []
153172
parts.append("Your input fields are:\n" + enumerate_fields(signature.input_fields))
154173
parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields))
155174
parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")
156175

157-
parts.append(format_fields({f : f"{{{f}}}" for f in signature.input_fields}))
158-
parts.append(format_fields({f : f"{{{f}}}" for f in signature.output_fields}))
159-
parts.append(format_fields({'completed' : ""}))
176+
parts.append(format_fields({f: f"{{{f}}}" for f in signature.input_fields}))
177+
parts.append(format_fields({f: f"{{{f}}}" for f in signature.output_fields}))
178+
parts.append(format_fields({"completed": ""}))
160179

161180
instructions = textwrap.dedent(signature.instructions)
162-
objective = ('\n' + ' ' * 8).join([''] + instructions.splitlines())
181+
objective = ("\n" + " " * 8).join([""] + instructions.splitlines())
163182
parts.append(f"In adhering to this structure, your objective is: {objective}")
164183

165184
# parts.append("You will receive some input fields in each interaction. " +
166185
# "Respond only with the corresponding output fields, starting with the field " +
167186
# ", then ".join(f"`{f}`" for f in signature.output_fields) +
168187
# ", and then ending with the marker for `completed`.")
169188

170-
return '\n\n'.join(parts).strip()
189+
return "\n\n".join(parts).strip()

0 commit comments

Comments
 (0)