Skip to content

Commit 5a2580b

Browse files
Support Pydantic field constraint in DSPy (#7980)
* added the ability to process metadata fields and add them directly to the prompt and added docstrings throughout the file * quick fix to get_annotation_name * fixed some small formatting issues * adapted the chat adapter to add metadata processing * added the metadata formatting to the json adapter * addressed Zach's comments * restart * add field constraint * change format --------- Co-authored-by: gilad12-coder <[email protected]> Co-authored-by: Gilad Morad <[email protected]>
1 parent ff2fa71 commit 5a2580b

File tree

5 files changed

+103
-3
lines changed

5 files changed

+103
-3
lines changed

dspy/adapters/chat_adapter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,9 @@ def enumerate_fields(fields: dict) -> str:
257257
parts.append(f"{idx + 1}. `{k}`")
258258
parts[-1] += f" ({get_annotation_name(v.annotation)})"
259259
parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else ""
260-
260+
parts[-1] += (
261+
f"\nConstraints: {v.json_schema_extra['constraints']}" if v.json_schema_extra.get("constraints") else ""
262+
)
261263
return "\n".join(parts).strip()
262264

263265

dspy/adapters/json_adapter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,9 @@ def enumerate_fields(fields):
252252
parts.append(f"{idx+1}. `{k}`")
253253
parts[-1] += f" ({get_annotation_name(v.annotation)})"
254254
parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else ""
255+
parts[-1] += (
256+
f"\nConstraints: {v.json_schema_extra['constraints']}" if v.json_schema_extra.get("constraints") else ""
257+
)
255258

256259
return "\n".join(parts).strip()
257260

dspy/signatures/field.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
11
import pydantic
2+
23
# The following arguments can be used in DSPy InputField and OutputField in addition
34
# to the standard pydantic.Field arguments. We just hope pydantic doesn't add these,
45
# as it would give a name clash.
56
DSPY_FIELD_ARG_NAMES = ["desc", "prefix", "format", "parser", "__dspy_field_type"]
67

8+
PYDANTIC_CONSTRAINT_MAP = {
9+
"gt": "greater than: ",
10+
"ge": "greater than or equal to: ",
11+
"lt": "less than: ",
12+
"le": "less than or equal to: ",
13+
"min_length": "minimum length: ",
14+
"max_length": "maximum length: ",
15+
"multiple_of": "a multiple of the given number: ",
16+
"allow_inf_nan": "allow 'inf', '-inf', 'nan' values: ",
17+
}
18+
719

820
def move_kwargs(**kwargs):
921
# Pydantic doesn't allow arbitrary arguments to be given to fields,
@@ -21,10 +33,24 @@ def move_kwargs(**kwargs):
2133
# Also copy over the pydantic "description" if no dspy "desc" is given.
2234
if "description" in kwargs and "desc" not in json_schema_extra:
2335
json_schema_extra["desc"] = kwargs["description"]
36+
constraints = _translate_pydantic_field_constraints(**kwargs)
37+
if constraints:
38+
json_schema_extra["constraints"] = constraints
2439
pydantic_kwargs["json_schema_extra"] = json_schema_extra
2540
return pydantic_kwargs
2641

2742

43+
def _translate_pydantic_field_constraints(**kwargs):
    """Translate Pydantic constraint keyword arguments into a human-readable string.

    Only keywords that appear in ``PYDANTIC_CONSTRAINT_MAP`` are considered; every
    other keyword is ignored. Constraints are emitted in the order in which the
    keywords were passed.

    Args:
        **kwargs: The keyword arguments given to a DSPy field.

    Returns:
        A comma-separated description such as
        ``"greater than or equal to: 5, less than or equal to: 10"``, or an empty
        string when no recognised constraint is set.
    """
    return ", ".join(
        f"{PYDANTIC_CONSTRAINT_MAP[key]}{value}"
        for key, value in kwargs.items()
        # ``None`` is how Pydantic spells "constraint not set"; rendering it would
        # put a misleading "greater than: None" style hint into the prompt.
        if key in PYDANTIC_CONSTRAINT_MAP and value is not None
    )
52+
53+
2854
def InputField(**kwargs):
    """Build a ``pydantic.Field`` tagged as a DSPy input field.

    The keyword arguments are passed through ``move_kwargs`` together with
    ``__dspy_field_type="input"`` before being handed to ``pydantic.Field``.
    """
    field_kwargs = move_kwargs(**kwargs, __dspy_field_type="input")
    return pydantic.Field(**field_kwargs)
3056

tests/predict/test_predict.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_reset_method():
3636
def test_lm_after_dump_and_load_state():
3737
predict_instance = Predict("input -> output")
3838
lm = dspy.LM(
39-
model="openai/gpt-4o-mini",
39+
model="openai/gpt-4o-mini",
4040
model_type="chat",
4141
temperature=1,
4242
max_tokens=100,
@@ -229,11 +229,12 @@ class CustomSignature2(dspy.Signature):
229229
new_instance.load(file_path)
230230
assert new_instance.signature.dump_state() == original_instance.signature.dump_state()
231231

232+
232233
@pytest.mark.parametrize("filename", ["model.json", "model.pkl"])
233234
def test_lm_field_after_dump_and_load_state(tmp_path, filename):
234235
file_path = tmp_path / filename
235236
lm = dspy.LM(
236-
model="openai/gpt-4o-mini",
237+
model="openai/gpt-4o-mini",
237238
model_type="chat",
238239
temperature=1,
239240
max_tokens=100,
@@ -487,3 +488,53 @@ class MySignature(dspy.Signature):
487488
assert "what's the capital of france?" in messages[1]["content"]
488489
assert "paris" in messages[2]["content"]
489490
assert "are you sure that's correct" in messages[3]["content"]
491+
492+
493+
@pytest.mark.parametrize("adapter_type", ["chat", "json"])
def test_field_constraints(adapter_type):
    """Check that field constraints are rendered into the system message.

    A recording LM captures the messages built by the adapter, so the test can
    inspect the generated instructions without calling a real model.
    """

    class SpyLM(dspy.LM):
        """LM stub that records each call and returns a canned completion."""

        def __init__(self, *args, return_json=False, **kwargs):
            super().__init__(*args, **kwargs)
            self.calls = []
            self.return_json = return_json

        def __call__(self, prompt=None, messages=None, **kwargs):
            self.calls.append({"prompt": prompt, "messages": messages, "kwargs": kwargs})
            if self.return_json:
                return ["{'score':'0.5', 'count':'2'}"]
            return ["[[ ## score ## ]]\n0.5\n[[ ## count ## ]]\n2"]

    class ConstrainedSignature(dspy.Signature):
        """Test signature with constrained fields."""

        # Input with length and value constraints
        text: str = dspy.InputField(min_length=5, max_length=100, desc="Input text")
        number: int = dspy.InputField(gt=0, lt=10, desc="A number between 0 and 10")

        # Output with multiple constraints
        score: float = dspy.OutputField(ge=0.0, le=1.0, desc="Score between 0 and 1")
        count: int = dspy.OutputField(multiple_of=2, desc="Even number count")

    program = Predict(ConstrainedSignature)

    # Build exactly one LM per run; the canned reply format has to match the
    # adapter under test so that the completion can be parsed.
    if adapter_type == "chat":
        lm = SpyLM("dummy_model")
        adapter = dspy.ChatAdapter()
    else:
        lm = SpyLM("dummy_model", return_json=True)
        adapter = dspy.JSONAdapter()
    dspy.settings.configure(adapter=adapter, lm=lm)

    # Call the predictor to trigger instruction generation
    program(text="hello world", number=5)

    # Get the system message containing the instructions
    system_message = lm.calls[0]["messages"][0]["content"]

    # Verify constraints are included in the field descriptions
    expected_fragments = [
        "minimum length: 5",
        "maximum length: 100",
        "greater than: 0",
        "less than: 10",
        "greater than or equal to: 0.0",
        "less than or equal to: 1.0",
        "a multiple of the given number: 2",
    ]
    for fragment in expected_fragments:
        assert fragment in system_message, f"missing constraint text: {fragment!r}"

tests/signatures/test_signature.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,21 @@ def test_make_signature_from_string():
391391
assert sig.output_fields["output1"].annotation == List[str]
392392
assert "output2" in sig.output_fields
393393
assert sig.output_fields["output2"].annotation == Union[int, str]
394+
395+
396+
def test_signature_field_with_constraints():
    """Constraints given to ``OutputField`` are stored under ``json_schema_extra["constraints"]``."""

    class MySignature(Signature):
        inputs: str = InputField()
        outputs1: str = OutputField(min_length=5, max_length=10)
        outputs2: int = OutputField(ge=5, le=10)

    # String field: length bounds.
    assert "outputs1" in MySignature.output_fields
    length_text = MySignature.output_fields["outputs1"].json_schema_extra["constraints"]
    for expected in ("minimum length: 5", "maximum length: 10"):
        assert expected in length_text

    # Integer field: inclusive numeric bounds.
    assert "outputs2" in MySignature.output_fields
    range_text = MySignature.output_fields["outputs2"].json_schema_extra["constraints"]
    for expected in ("greater than or equal to: 5", "less than or equal to: 10"):
        assert expected in range_text

0 commit comments

Comments
 (0)