Commit 43a67f8

Update Custom LM code
Closes #1601
1 parent: bae2ad8

1 file changed: +28 -12 lines

docs/docs/building-blocks/1-language_models.md
@@ -155,31 +155,47 @@ If there is not much overlap in features between your LM and LiteLLM it's better
 Let's create an LM for Gemini using `google-generativeai` package from scratch:

 ```python
-from dspy import LM
+import os
+import dspy
 import google.generativeai as genai

-
-class GeminiLM:
-    def __init__(self, model, api_key, endpoint):
-        genai.configure(api_key=os.environ["API_KEY"] or api_key)
-        self.model = genai.GenerativeModel(model)
+class GeminiLM(dspy.LM):
+    def __init__(self, model, api_key=None, endpoint=None, **kwargs):
+        genai.configure(api_key=os.environ["GEMINI_API_KEY"] or api_key)

         self.endpoint = endpoint
         self.history = []
-        super().__init__(model)
-
+
+        super().__init__(model, **kwargs)
+        self.model = genai.GenerativeModel(model)

     def __call__(self, prompt=None, messages=None, **kwargs):
-        if isinstance(prompt, str):
-            prompt = [prompt]
+        # Custom chat model working for text completion model
+        prompt = '\n\n'.join([x['content'] for x in messages] + ['BEGIN RESPONSE:'])

         completions = self.model.generate_content(prompt)
         self.history.append({"prompt": prompt, "completions": completions})
-
+
+        # Must return a list of strings
+        return [completions.candidates[0].content.parts[0].text]

     def inspect_history(self):
         for interaction in self.history:
             print(f"Prompt: {interaction['prompt']} -> Completions: {interaction['completions']}")
+
+lm = GeminiLM("gemini-1.5-flash", temperature=0)
+dspy.configure(lm=lm)
+
+qa = dspy.ChainOfThought("question->answer")
+qa(question="What is the capital of France?")
+```
+
+**Output:**
+```text
+Prediction(
+    reasoning='France is a country in Western Europe. Its capital city is Paris.',
+    answer='Paris'
+)
 ```

 The above example is the simplest form of LM. You can add more options to tweak generation config and even control the generated output based on your requirement.
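The last docs line above mentions tweaking the generation config; here is a minimal sketch of how that might look, building on the `GeminiLM` from this diff. Assumptions worth flagging: `genai.GenerationConfig` and the `generation_config=` parameter of `generate_content` come from the `google-generativeai` package; whether `dspy.LM` stores its constructor kwargs on `self.kwargs` varies by version, so the sketch guards that lookup with `getattr`; and `ConfigurableGeminiLM` is a hypothetical name. It also uses `os.environ.get()`, which returns `None` rather than raising `KeyError` when the variable is unset, so an explicit `api_key` argument can take over.

```python
import os

import dspy
import google.generativeai as genai


class ConfigurableGeminiLM(dspy.LM):
    """Hypothetical variant of GeminiLM that forwards sampling options to Gemini."""

    def __init__(self, model, api_key=None, endpoint=None, **kwargs):
        # .get() returns None instead of raising KeyError when the variable
        # is unset, letting an explicit api_key argument take over.
        genai.configure(api_key=os.environ.get("GEMINI_API_KEY") or api_key)

        self.endpoint = endpoint
        self.history = []

        super().__init__(model, **kwargs)
        self.model = genai.GenerativeModel(model)

    def __call__(self, prompt=None, messages=None, **kwargs):
        prompt = '\n\n'.join([x['content'] for x in messages] + ['BEGIN RESPONSE:'])

        # Merge constructor-time settings with per-call overrides; the getattr
        # guard assumes dspy.LM may or may not expose self.kwargs.
        settings = {**getattr(self, "kwargs", {}), **kwargs}
        config = genai.GenerationConfig(
            temperature=settings.get("temperature", 0.0),
            max_output_tokens=settings.get("max_tokens", 1024),
        )

        completions = self.model.generate_content(prompt, generation_config=config)
        self.history.append({"prompt": prompt, "completions": completions})

        # Must return a list of strings
        return [completions.candidates[0].content.parts[0].text]
```

With this variant, `ConfigurableGeminiLM("gemini-1.5-flash", temperature=0, max_tokens=512)` would apply both settings to every call.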
@@ -323,4 +339,4 @@ The above example is a simple Adapter that converts the input to uppercase befor

 ### Overriding `__call__` method

-To gain control over usage of format and parse and even more fine-grained control over the flow of input from signature to outputs you can override `__call__` method and implement your custom flow. Although for most cases only implementing `parse` and `format` function will be fine.
+To gain control over usage of format and parse and even more fine-grained control over the flow of input from signature to outputs you can override `__call__` method and implement your custom flow. Although for most cases only implementing `parse` and `format` function will be fine.
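As a sketch of the `__call__` override this hunk describes: the adapter method signatures used below (`__call__(lm, lm_kwargs, signature, demos, inputs)`, `format(...)` returning chat messages, `parse(...)` returning a dict of output-field values) follow recent DSPy releases and may differ in yours, and `LoggingAdapter` is a hypothetical name for illustration.

```python
import dspy


class LoggingAdapter(dspy.ChatAdapter):
    """Overrides __call__ to control the whole format -> LM -> parse flow."""

    def __call__(self, lm, lm_kwargs, signature, demos, inputs):
        # format() builds the chat messages that will be sent to the LM.
        messages = self.format(signature, demos, inputs)
        print(f"Sending {len(messages)} messages to {lm.model}")

        # Call the LM directly instead of deferring to the base class.
        outputs = lm(messages=messages, **lm_kwargs)

        # parse() maps each raw completion back onto the signature's
        # output fields.
        parsed = [self.parse(signature, output) for output in outputs]
        print(f"Parsed output fields: {[sorted(p) for p in parsed]}")
        return parsed


dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"), adapter=LoggingAdapter())
```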
