Skip to content

Commit ab23dc0

Browse files
authored
Enhance type hints of adapters (#7896)
* enhance type hints of adapters
1 parent 00417ae commit ab23dc0

File tree

5 files changed

+79
-71
lines changed

5 files changed

+79
-71
lines changed

dspy/adapters/base.py

Lines changed: 31 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from abc import ABC, abstractmethod
2-
3-
from litellm import ContextWindowExceededError
2+
from typing import Type, Any, Optional, TYPE_CHECKING
43

54
from dspy.adapters.types import History
6-
from dspy.utils.callback import with_callbacks
7-
5+
from dspy.utils.callback import BaseCallback, with_callbacks
6+
from dspy.signatures.signature import Signature
7+
if TYPE_CHECKING:
8+
from dspy.clients.lm import LM
89

910
class Adapter(ABC):
10-
def __init__(self, callbacks=None):
11+
def __init__(self, callbacks: Optional[list[BaseCallback]] = None):
1112
self.callbacks = callbacks or []
1213

1314
def __init_subclass__(cls, **kwargs) -> None:
@@ -17,60 +18,53 @@ def __init_subclass__(cls, **kwargs) -> None:
1718
cls.format = with_callbacks(cls.format)
1819
cls.parse = with_callbacks(cls.parse)
1920

20-
def __call__(self, lm, lm_kwargs, signature, demos, inputs):
21+
def __call__(self, lm: "LM", lm_kwargs: dict[str, Any], signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
2122
inputs_ = self.format(signature, demos, inputs)
2223
inputs_ = dict(prompt=inputs_) if isinstance(inputs_, str) else dict(messages=inputs_)
2324

2425
outputs = lm(**inputs_, **lm_kwargs)
2526
values = []
2627

27-
try:
28-
for output in outputs:
29-
output_logprobs = None
30-
31-
if isinstance(output, dict):
32-
output, output_logprobs = output["text"], output["logprobs"]
28+
for output in outputs:
29+
output_logprobs = None
3330

34-
value = self.parse(signature, output)
31+
if isinstance(output, dict):
32+
output, output_logprobs = output["text"], output["logprobs"]
3533

36-
if set(value.keys()) != set(signature.output_fields.keys()):
37-
raise ValueError(
38-
"Parsed output fields do not match signature output fields. "
39-
f"Expected: {set(signature.output_fields.keys())}, Got: {set(value.keys())}"
40-
)
34+
value = self.parse(signature, output)
4135

42-
if output_logprobs is not None:
43-
value["logprobs"] = output_logprobs
36+
if set(value.keys()) != set(signature.output_fields.keys()):
37+
raise ValueError(
38+
"Parsed output fields do not match signature output fields. "
39+
f"Expected: {set(signature.output_fields.keys())}, Got: {set(value.keys())}"
40+
)
4441

45-
values.append(value)
42+
if output_logprobs is not None:
43+
value["logprobs"] = output_logprobs
4644

47-
return values
45+
values.append(value)
4846

49-
except Exception as e:
50-
if isinstance(e, ContextWindowExceededError):
51-
# On context window exceeded error, we don't want to retry with a different adapter.
52-
raise e
53-
from dspy.adapters.json_adapter import JSONAdapter
47+
return values
5448

55-
if not isinstance(self, JSONAdapter):
56-
return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs)
57-
raise e
5849

5950
@abstractmethod
60-
def format(self, signature, demos, inputs):
51+
def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
6152
raise NotImplementedError
6253

6354
@abstractmethod
64-
def parse(self, signature, completion):
55+
def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
6556
raise NotImplementedError
66-
67-
def format_finetune_data(self, signature, demos, inputs, outputs):
57+
58+
def format_fields(self, signature: Type[Signature], values: dict[str, Any], role: str) -> str:
59+
raise NotImplementedError
60+
61+
def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]) -> dict[str, list[Any]]:
6862
raise NotImplementedError
6963

70-
def format_turn(self, signature, values, role, incomplete=False, is_conversation_history=False):
71-
pass
64+
def format_turn(self, signature: Type[Signature], values, role: str, incomplete: bool = False, is_conversation_history: bool = False) -> dict[str, Any]:
65+
raise NotImplementedError
7266

73-
def format_conversation_history(self, signature, inputs):
67+
def format_conversation_history(self, signature: Type[Signature], inputs: dict[str, Any]) -> list[dict[str, Any]]:
7468
history_field_name = None
7569
for name, field in signature.input_fields.items():
7670
if field.annotation == History:

dspy/adapters/chat_adapter.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
import re
55
import textwrap
66
from collections.abc import Mapping
7-
from itertools import chain
8-
from typing import Any, Dict, Literal, NamedTuple
7+
from typing import Any, Dict, Literal, NamedTuple, Optional, Type
98

109
import pydantic
1110
from pydantic.fields import FieldInfo
11+
from litellm import ContextWindowExceededError
1212

1313
from dspy.adapters.base import Adapter
1414
from dspy.adapters.types.image import try_expand_image_tags
@@ -17,6 +17,9 @@
1717
from dspy.signatures.field import OutputField
1818
from dspy.signatures.signature import Signature, SignatureMeta
1919
from dspy.signatures.utils import get_dspy_field_type
20+
from dspy.adapters.json_adapter import JSONAdapter
21+
from dspy.clients.lm import LM
22+
from dspy.utils.callback import BaseCallback
2023

2124
field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")
2225

@@ -31,7 +34,20 @@ class FieldInfoWithName(NamedTuple):
3134

3235

3336
class ChatAdapter(Adapter):
34-
def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
37+
def __init__(self, callbacks: Optional[list[BaseCallback]] = None):
38+
super().__init__(callbacks)
39+
40+
def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
41+
try:
42+
return super().__call__(lm, lm_kwargs, signature, demos, inputs)
43+
except Exception as e:
44+
if isinstance(e, ContextWindowExceededError):
45+
# On context window exceeded error, we don't want to retry with a different adapter.
46+
raise e
47+
# fallback to JSONAdapter
48+
return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs)
49+
50+
def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
3551
messages: list[dict[str, Any]] = []
3652

3753
# Extract demos where some of the output_fields are not filled in.
@@ -64,7 +80,7 @@ def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict
6480
messages = try_expand_image_tags(messages)
6581
return messages
6682

67-
def parse(self, signature, completion):
83+
def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
6884
sections = [(None, [])]
6985

7086
for line in completion.splitlines():
@@ -92,7 +108,7 @@ def parse(self, signature, completion):
92108
return fields
93109

94110
# TODO(PR): Looks ok?
95-
def format_finetune_data(self, signature, demos, inputs, outputs):
111+
def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]) -> dict[str, list[Any]]:
96112
# Get system + user messages
97113
messages = self.format(signature, demos, inputs)
98114

@@ -105,7 +121,7 @@ def format_finetune_data(self, signature, demos, inputs, outputs):
105121
# Wrap the messages in a dictionary with a "messages" key
106122
return dict(messages=messages)
107123

108-
def format_fields(self, signature, values, role):
124+
def format_fields(self, signature: Type[Signature], values: dict[str, Any], role: str) -> str:
109125
fields_with_values = {
110126
FieldInfoWithName(name=field_name, info=field_info): values.get(
111127
field_name, "Not supplied for this particular example."
@@ -115,7 +131,7 @@ def format_fields(self, signature, values, role):
115131
}
116132
return format_fields(fields_with_values)
117133

118-
def format_turn(self, signature, values, role, incomplete=False, is_conversation_history=False):
134+
def format_turn(self, signature: Type[Signature], values: dict[str, Any], role: str, incomplete: bool = False, is_conversation_history: bool = False) -> dict[str, Any]:
119135
return format_turn(signature, values, role, incomplete, is_conversation_history)
120136

121137

@@ -139,7 +155,7 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
139155
return "\n\n".join(output).strip()
140156

141157

142-
def format_turn(signature, values, role, incomplete=False, is_conversation_history=False):
158+
def format_turn(signature: Type[Signature], values: dict[str, Any], role: str, incomplete=False, is_conversation_history=False):
143159
"""
144160
Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted
145161
so that it can instruct an LLM to generate responses conforming to the specified DSPy signature.
@@ -212,11 +228,6 @@ def type_info(v):
212228
return {"role": role, "content": joined_messages}
213229

214230

215-
def flatten_messages(messages):
216-
"""Flatten nested message lists."""
217-
return list(chain.from_iterable(item if isinstance(item, list) else [item] for item in messages))
218-
219-
220231
def enumerate_fields(fields: dict) -> str:
221232
parts = []
222233
for idx, (k, v) in enumerate(fields.items()):

dspy/adapters/json_adapter.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import textwrap
66
from copy import deepcopy
7-
from typing import Any, Dict, KeysView, Literal, NamedTuple
7+
from typing import Any, Dict, KeysView, Literal, NamedTuple, Type
88

99
import json_repair
1010
import litellm
@@ -16,7 +16,8 @@
1616
from dspy.adapters.types.image import Image
1717
from dspy.adapters.types.history import History
1818
from dspy.adapters.utils import format_field_value, get_annotation_name, parse_value, serialize_for_json
19-
from dspy.signatures.signature import SignatureMeta
19+
from dspy.clients.lm import LM
20+
from dspy.signatures.signature import SignatureMeta, Signature
2021
from dspy.signatures.utils import get_dspy_field_type
2122

2223
logger = logging.getLogger(__name__)
@@ -26,12 +27,11 @@ class FieldInfoWithName(NamedTuple):
2627
name: str
2728
info: FieldInfo
2829

29-
3030
class JSONAdapter(Adapter):
3131
def __init__(self):
3232
pass
3333

34-
def __call__(self, lm, lm_kwargs, signature, demos, inputs):
34+
def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
3535
inputs = self.format(signature, demos, inputs)
3636
inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs)
3737

@@ -66,7 +66,7 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs):
6666

6767
return values
6868

69-
def format(self, signature, demos, inputs):
69+
def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
7070
messages = []
7171

7272
# Extract demos where some of the output_fields are not filled in.
@@ -94,7 +94,7 @@ def format(self, signature, demos, inputs):
9494

9595
return messages
9696

97-
def parse(self, signature, completion):
97+
def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
9898
fields = json_repair.loads(completion)
9999
fields = {k: v for k, v in fields.items() if k in signature.output_fields}
100100

@@ -108,10 +108,7 @@ def parse(self, signature, completion):
108108

109109
return fields
110110

111-
def format_turn(self, signature, values, role, incomplete=False, is_conversation_history=False):
112-
return format_turn(signature, values, role, incomplete, is_conversation_history)
113-
114-
def format_fields(self, signature, values, role):
111+
def format_fields(self, signature: Type[Signature], values: dict[str, Any], role: str) -> str:
115112
fields_with_values = {
116113
FieldInfoWithName(name=field_name, info=field_info): values.get(
117114
field_name, "Not supplied for this particular example."
@@ -121,6 +118,13 @@ def format_fields(self, signature, values, role):
121118
}
122119

123120
return format_fields(role=role, fields_with_values=fields_with_values)
121+
122+
def format_turn(self, signature: Type[Signature], values, role: str, incomplete: bool = False, is_conversation_history: bool = False) -> dict[str, Any]:
123+
return format_turn(signature, values, role, incomplete, is_conversation_history)
124+
125+
def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]) -> dict[str, list[Any]]:
126+
# TODO: implement format_finetune_data method in JSONAdapter
127+
raise NotImplementedError
124128

125129

126130
def _format_field_value(field_info: FieldInfo, value: Any) -> str:

dspy/clients/lm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import uuid
77
from datetime import datetime
88
from hashlib import sha256
9-
from typing import Any, Dict, List, Literal, Optional, cast
9+
from typing import Any, Dict, List, Literal, Optional, cast, TYPE_CHECKING
1010

1111
import litellm
1212
import pydantic
@@ -17,11 +17,12 @@
1717
from litellm import RetryPolicy
1818

1919
import dspy
20-
from dspy.adapters.base import Adapter
2120
from dspy.clients.openai import OpenAIProvider
2221
from dspy.clients.provider import Provider, TrainingJob
2322
from dspy.clients.utils_finetune import TrainDataFormat
2423
from dspy.utils.callback import BaseCallback, with_callbacks
24+
if TYPE_CHECKING:
25+
from dspy.adapters.base import Adapter
2526

2627
from .base_lm import BaseLM
2728

@@ -219,7 +220,7 @@ def infer_provider(self) -> Provider:
219220
# providers in this file. Is this okay?
220221
return Provider()
221222

222-
def infer_adapter(self) -> Adapter:
223+
def infer_adapter(self) -> "Adapter":
223224
import dspy
224225

225226
if dspy.settings.adapter:

dspy/predict/predict.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from dspy.primitives.program import Module
1010
from dspy.signatures.signature import ensure_signature
1111
from dspy.utils.callback import with_callbacks
12+
from dspy.dsp.utils import settings
13+
from dspy.adapters.chat_adapter import ChatAdapter
1214

1315

1416
class Predict(Module, Parameter):
@@ -71,16 +73,14 @@ def __call__(self, **kwargs):
7173
return self.forward(**kwargs)
7274

7375
def forward(self, **kwargs):
74-
import dspy
75-
7676
# Extract the three privileged keyword arguments.
7777
assert "new_signature" not in kwargs, "new_signature is no longer a valid keyword argument."
7878
signature = ensure_signature(kwargs.pop("signature", self.signature))
7979
demos = kwargs.pop("demos", self.demos)
8080
config = dict(**self.config, **kwargs.pop("config", {}))
8181

8282
# Get the right LM to use.
83-
lm = kwargs.pop("lm", self.lm) or dspy.settings.lm
83+
lm = kwargs.pop("lm", self.lm) or settings.lm
8484
assert isinstance(lm, BaseLM), "No LM is loaded."
8585

8686
# If temperature is 0.0 but its n > 1, set temperature to 0.7.
@@ -96,9 +96,7 @@ def forward(self, **kwargs):
9696
missing = [k for k in signature.input_fields if k not in kwargs]
9797
print(f"WARNING: Not all input fields were provided to module. Present: {present}. Missing: {missing}.")
9898

99-
import dspy
100-
101-
adapter = dspy.settings.adapter or dspy.ChatAdapter()
99+
adapter = settings.adapter or ChatAdapter()
102100
completions = adapter(
103101
lm,
104102
lm_kwargs=config,
@@ -109,8 +107,8 @@ def forward(self, **kwargs):
109107

110108
pred = Prediction.from_completions(completions, signature=signature)
111109

112-
if kwargs.pop("_trace", True) and dspy.settings.trace is not None:
113-
trace = dspy.settings.trace
110+
if kwargs.pop("_trace", True) and settings.trace is not None:
111+
trace = settings.trace
114112
trace.append((self, {**kwargs}, pred))
115113

116114
return pred

0 commit comments

Comments
 (0)