Skip to content

Commit 5e57bb2

Browse files
authored
Merge pull request #1256 from stanfordnlp/chat_templates_v2
Chat templates v2
2 parents 7b13a02 + 8797a23 commit 5e57bb2

File tree

19 files changed

+710
-529
lines changed

19 files changed

+710
-529
lines changed

dsp/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from .modules import *
2-
from .primitives import *
3-
from .templates import *
4-
from .utils import settings
1+
from .modules import * # noqa
2+
from .primitives import * # noqa
3+
from .adapters import * # noqa
4+
from .utils import settings # noqa
55

66
"""
77
TODO:

dsp/adapters/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .base_template import * # noqa
2+
from .template import * # noqa
3+
from .experimental_adapter import * # noqa
4+
from .utils import * # noqa

dsp/templates/template_v3.py renamed to dsp/adapters/base_template.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from collections import namedtuple
12
from typing import Callable
23

3-
from dsp.templates import Field, TemplateV2, format_answers, passages2text
4+
from .utils import format_answers, passages2text
5+
6+
Field = namedtuple("Field", "name separator input_variable output_variable description")
47

58

69
class Type:
@@ -19,7 +22,7 @@ def __eq__(self, __value: object) -> bool:
1922
return isinstance(__value, Type) and self.__dict__ == __value.__dict__
2023

2124

22-
class Template(TemplateV2):
25+
class BaseTemplate:
2326
"""A template datatype that represents the structure of communicate with the LM."""
2427

2528
def __init__(self, instructions: str, **kwargs):
@@ -61,9 +64,7 @@ def __eq__(self, other):
6164
v1, v2 = self.kwargs[k], other.kwargs[k]
6265
if not v1 == v2:
6366
print(k, v1, v2)
64-
6567

66-
# print("here?", self.instructions == other.instructions, self.kwargs == other.kwargs)
6768
return self.instructions == other.instructions and self.kwargs == other.kwargs
6869

6970
def __str__(self) -> str:
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
from typing import Any, Union
2+
3+
import dsp
4+
from dsp.primitives.demonstrate import Example
5+
6+
from .base_template import BaseTemplate
7+
8+
9+
class ExperimentalAdapter(BaseTemplate):
    """Prompt template that ends the prompt with an explicit request for the output fields.

    The class reads ``self.fields``, ``self.format_handlers`` and ``self.instructions``,
    none of which are defined here.
    NOTE(review): presumably these are set up by ``BaseTemplate.__init__`` — confirm there.
    """

    @staticmethod
    def _clean_value(text: str) -> str:
        """Trim whitespace and trailing dashes from an extracted field value."""
        # NOTE(review): ``str.rstrip("---")`` takes a *set of characters*, so it removes any
        # run of trailing "-" (a value such as "A-" becomes "A"), not only a literal "---"
        # separator. Kept unchanged because altering it would change parsing results.
        return text.strip().rstrip("---").strip()

    @staticmethod
    def _legacy_value_start(field_name: str, raw_pred: str) -> int:
        """Return the offset just past the words of ``field_name`` found in ``raw_pred``.

        Each word is searched from the beginning of ``raw_pred``; scanning stops at the
        first word that is not found. Returns 0 when the first word is absent.
        """
        start_pos = 0
        for part in field_name.split():
            pos = raw_pred.find(part)
            if pos == -1:
                break
            start_pos = pos + len(part)
        return start_pos

    def query(self, example: Example, is_demo: bool = False) -> str:
        """Retrieves the input variables from the example and formats them into a query string.

        Args:
            example: Mapping from variable names to values. It may be mutated: when
                ``is_demo`` is false, the first field after the last filled one is set to "".
            is_demo: When true, the example is rendered as-is and no empty trailing
                field is added.

        Returns:
            The rendered fields joined by blank lines.

        Raises:
            AssertionError: If ``is_demo`` is false and no field has a non-empty value.
        """
        # If not a demo, find the last field that doesn't have a value set in `example` and set it to ""
        # This creates the "Output:" prefix at the end of the prompt.
        if not is_demo:
            has_value = [
                field.input_variable in example
                and example[field.input_variable] is not None
                and example[field.input_variable] != ""
                for field in self.fields
            ]

            if not any(has_value):
                # Explicit raise (same exception type as the former `assert False`) so the
                # check is not stripped when Python runs with -O.
                raise AssertionError("No input variables found in the example")

            for i in range(1, len(has_value)):
                if has_value[i - 1] and not any(has_value[i:]):
                    example[self.fields[i].input_variable] = ""
                    break

        result: list[str] = []
        for field in self.fields:
            variable = field.input_variable
            if variable not in example or example[variable] is None:
                continue

            if variable in self.format_handlers:
                format_handler = self.format_handlers[variable]
            else:

                def format_handler(x):
                    return str(x).strip()

            formatted_value = format_handler(example[variable])
            # A multi-line value is moved onto its own line when the field's separator is a single space.
            separator = "\n" if field.separator == " " and "\n" in formatted_value else field.separator

            result.append(f"{field.name}{separator}{formatted_value}")

        return "\n\n".join([r for r in result if r])

    def guidelines(self, show_guidelines=True) -> str:
        """Returns the task guidelines as described in the lm prompt.

        Returns "" when ``show_guidelines`` is false or when ``dsp.settings`` has a falsy
        ``show_guidelines`` attribute; otherwise a "Follow the following format." header
        followed by every field rendered with its description as the value.
        """
        if (not show_guidelines) or (hasattr(dsp.settings, "show_guidelines") and not dsp.settings.show_guidelines):
            return ""

        example = dsp.Example()
        for field in self.fields:
            example[field.input_variable] = field.description
        example.augmented = True

        return "Follow the following format.\n\n" + self.query(example)

    def extract(
        self,
        example: Union[Example, dict[str, Any]],
        raw_pred: str,
    ) -> Example:
        """Extracts the answer from the LM raw prediction using the template structure

        Args:
            example (Union[Example, dict[str, Any]]): Contains the input variables that raw_pred was completed on.
            raw_pred (str): LM generated string

        Returns:
            Example: The example with the output variables filled in
        """
        example = dsp.Example(example)

        # Normalise the prediction: strip every line, drop empty lines, and re-join the
        # remaining lines so that consecutive lines end up separated by a blank line.
        adjusted_parts = []
        for part in raw_pred.strip().split("\n"):
            trimmed_part = part.strip()
            if trimmed_part:
                adjusted_parts.append("\n" + trimmed_part if adjusted_parts else trimmed_part)
        raw_pred = "\n".join(adjusted_parts)

        # Index of the first field that has no value in `example` yet.
        idx = 0
        while idx < len(self.fields):
            variable = self.fields[idx].input_variable
            if variable not in example or example[variable] is None:
                break
            idx += 1

        # NOTE(review): imported locally in the original too — presumably to avoid a
        # circular import between dsp and dspy; confirm before moving to module level.
        import dspy

        idx = min(idx, len(self.fields) - 1)
        while raw_pred != "" and idx < len(self.fields):
            field = self.fields[idx]

            if idx < len(self.fields) - 1:
                # The current value ends where the next field's name starts on a new line.
                next_field_name = "\n" + self.fields[idx + 1].name
                offset = raw_pred.find(next_field_name)

                if offset < 0:
                    # The next field never appears: the rest of the text belongs to this field.
                    example[field.output_variable] = self._clean_value(raw_pred)
                    raw_pred = ""
                    idx += 1
                    break

                remainder = raw_pred[offset + len(next_field_name) :]
                if dspy.settings.release >= 20231003:
                    example[field.output_variable] = self._clean_value(raw_pred[:offset])
                    raw_pred = self._clean_value(remainder)
                else:
                    # Older releases skip a leading repetition of the field name first.
                    start_pos = self._legacy_value_start(field.name, raw_pred)
                    example[field.output_variable] = self._clean_value(raw_pred[start_pos:offset])
                    raw_pred = remainder.strip()
                idx += 1

            else:
                assert idx == len(self.fields) - 1, (idx, len(self.fields))

                if dspy.settings.release >= 20231003:
                    example[field.output_variable] = self._clean_value(raw_pred)
                else:
                    start_pos = self._legacy_value_start(field.name, raw_pred)
                    example[field.output_variable] = raw_pred[start_pos:].strip()

                break

        return example

    def __call__(self, example, show_guidelines=True) -> str:
        """Builds the full prompt for ``example``.

        The prompt consists of the instructions, raw demos, guidelines, augmented demos and
        the query joined by "---" separators, followed by a sentence asking the LM to
        provide the fields that are missing from ``example``. When ``dsp.settings.query_only``
        is set and truthy, only the query is returned.

        Args:
            example: Input values plus a ``demos`` attribute; it is copied into a new
                ``dsp.Example`` before use.
            show_guidelines: Forwarded to ``guidelines``.

        Returns:
            The prompt string.

        Raises:
            AssertionError: If the last field's variable is already present in ``example``.
        """
        example = dsp.Example(example)

        # Labels (field name up to the first ":") of the fields the LM has to produce.
        # De-duplication is keyed on the variable name; the original compared variable
        # names against this list of labels, which could wrongly drop a field.
        output_fields = []
        seen_variables = set()
        for field in self.fields:
            variable = field.input_variable
            if variable not in example and variable not in seen_variables:
                seen_variables.add(variable)
                output_fields.append(field.name.split(":")[0])

        if hasattr(dsp.settings, "query_only") and dsp.settings.query_only:
            return self.query(example)

        # The training data should not contain the output variable.
        # Raised explicitly (same exception type as the former `assert`) so the check also
        # runs under -O; without it `output_fields` could be empty further down.
        last_variable = self.fields[-1].input_variable
        if last_variable in example:
            raise AssertionError(f"Output variable {last_variable} should not be supplied for querying the LM.")

        rdemos = [
            self.query(demo, is_demo=True)
            for demo in example.demos
            if (
                (not demo.get("augmented", False))
                and (  # validate that the training example has the same primitive input var as the template
                    last_variable in demo and demo[last_variable] is not None
                )
            )
        ]

        ademos = [self.query(demo, is_demo=True) for demo in example.demos if demo.get("augmented", False)]

        # Move the rdemos to ademos if rdemo has all the fields filled in
        remaining_rdemos = []
        promoted_demos = []
        for rdemo in rdemos:
            if all((field.name in rdemo) for field in self.fields if field.input_variable in example):
                promoted_demos.append(rdemo)
            else:
                remaining_rdemos.append(rdemo)

        ademos = promoted_demos + ademos
        rdemos = remaining_rdemos

        example["augmented"] = True

        query = self.query(example)
        parts = [self.instructions, *rdemos, self.guidelines(show_guidelines), *ademos, query]

        prompt = "\n\n---\n\n".join([p.strip() for p in parts if p])
        # Cut the prompt at its last newline, i.e. drop the final line of the rendered query.
        prompt_ = prompt[: prompt.rfind("\n")].strip()

        s_or_not = "s" if len(output_fields) > 1 else ""
        only_or_not = "only " if len(output_fields) == 1 else ""

        if len(output_fields) > 1:
            joiner = ", then " if len(output_fields) > 2 else " then "
            field_list = ", ".join(output_fields[:-1]) + joiner + output_fields[-1]
        else:
            field_list = output_fields[0]

        prompt_ += f"\n\nPlease provide the output field{s_or_not} {field_list}. Do so immediately, without additional content before or after, and precisely as the format above shows. Begin with {only_or_not}the field {output_fields[0]}."
        return prompt_.strip()
202+

dsp/templates/template_v2.py renamed to dsp/adapters/template.py

Lines changed: 2 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,12 @@
1-
import re
2-
from collections import namedtuple
31
from typing import Any, Union
42

53
import dsp
64
from dsp.primitives.demonstrate import Example
75

8-
from .utils import format_answers, passages2text
6+
from .base_template import BaseTemplate
97

10-
Field = namedtuple("Field", "name separator input_variable output_variable description")
11-
12-
# TODO: de-duplicate with dsp/templates/template.py
13-
14-
15-
class TemplateV2:
16-
def __init__(
17-
self,
18-
template,
19-
format_handlers={
20-
"passages": passages2text,
21-
"context": passages2text,
22-
"answer": format_answers,
23-
"answers": format_answers,
24-
},
25-
):
26-
self.format_handlers = format_handlers
27-
28-
template = template.strip()
29-
30-
self.instructions = re.search("(.*)\n", template).group(1)
31-
template = template[len(self.instructions) :].strip()
32-
33-
self.fields = []
34-
while len(template) > 0:
35-
match = re.search("(.*)(\s){(.*)}\s(.*\${.*})", template)
36-
if match is not None:
37-
name = match.group(1)
38-
separator = match.group(2)
39-
variable = match.group(3)
40-
description = match.group(4)
41-
else:
42-
match = re.search("(.*)(\s){(.*)}", template)
43-
if match is not None:
44-
name = match.group(1)
45-
separator = match.group(2)
46-
variable = match.group(3)
47-
description = None
48-
else:
49-
raise ValueError("Could not parse template")
50-
51-
var_match = re.match("(.*) -> (.*)", variable)
52-
if var_match is not None:
53-
input_variable = var_match.group(1)
54-
output_variable = var_match.group(2)
55-
else:
56-
input_variable = variable
57-
output_variable = variable
58-
59-
self.fields.append(
60-
Field(
61-
name=name,
62-
separator=separator,
63-
input_variable=input_variable,
64-
output_variable=output_variable,
65-
description=description,
66-
),
67-
)
68-
69-
template = template[len(match.group(0)) :].strip()
708

9+
class Template(BaseTemplate):
7110
def query(self, example: Example, is_demo: bool = False) -> str:
7211
"""Retrieves the input variables from the example and formats them into a query string."""
7312
result: list[str] = []

dsp/templates/utils.py renamed to dsp/adapters/utils.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,29 @@ def passages2text(passages: Union[str, list, tuple]) -> str:
1717
return "\n".join([f"[{idx+1}] «{txt}»" for idx, txt in enumerate(passages)])
1818

1919

20-
def passages2textV2(passages: Union[str, list, tuple]) -> str:
21-
"""Formats the given one or more passages into a single structured string."""
22-
if isinstance(passages, str):
23-
return passages
24-
25-
assert type(passages) in [list, tuple]
26-
27-
def psg2text(psg):
28-
try:
29-
title, snippet = psg.split("|", 1)
30-
return f"Title: {title.strip()} | Snippet: «{snippet.strip()}»"
31-
except Exception:
32-
pass
20+
# def passages2textV2(passages: Union[str, list, tuple]) -> str:
21+
# """Formats the given one or more passages into a single structured string."""
22+
# if isinstance(passages, str):
23+
# return passages
24+
25+
# assert type(passages) in [list, tuple]
26+
27+
# def psg2text(psg):
28+
# try:
29+
# title, snippet = psg.split("|", 1)
30+
# return f"Title: {title.strip()} | Snippet: «{snippet.strip()}»"
31+
# except Exception:
32+
# pass
3333

34-
return f"«{psg}»"
34+
# return f"«{psg}»"
3535

36-
if len(passages) == 0:
37-
return "N/A"
36+
# if len(passages) == 0:
37+
# return "N/A"
3838

39-
if len(passages) == 1:
40-
return psg2text(passages[0])
39+
# if len(passages) == 1:
40+
# return psg2text(passages[0])
4141

42-
return "\n".join([f"[{idx+1}] {psg2text(txt)}" for idx, txt in enumerate(passages)])
42+
# return "\n".join([f"[{idx+1}] {psg2text(txt)}" for idx, txt in enumerate(passages)])
4343

4444

4545
def format_answers(answers: Union[str, list]) -> Optional[str]:

0 commit comments

Comments
 (0)