Skip to content

Commit 7e78199

Browse files
Add dspy.Embedding (#1735)
* Add embedding model * force return type to be numpy array --------- Co-authored-by: Omar Khattab <[email protected]>
1 parent e4e7e0b commit 7e78199

File tree

4 files changed

+155
-8
lines changed

4 files changed

+155
-8
lines changed

dspy/clients/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,15 @@
11
from .lm import LM
22
from .base_lm import BaseLM, inspect_history
3+
from .embedding import Embedding
4+
import litellm
5+
import os
6+
from pathlib import Path
7+
from litellm.caching import Cache
8+
9+
# Directory for litellm's on-disk cache: $DSPY_CACHEDIR when it is set and non-empty,
# otherwise a ".dspy_cache" folder under the user's home directory.
DISK_CACHE_DIR = os.getenv("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")

# Import-time side effects: these assignments configure the litellm module globally.
litellm.telemetry = False
litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk")

# accessed at run time by litellm; i.e., fine to keep after import
# (setdefault leaves a value already present in the environment untouched)
os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True")

dspy/clients/embedding.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import litellm
2+
import numpy as np
3+
4+
5+
class Embedding:
    """DSPy embedding class.

    Computes embeddings for text inputs behind a single interface that covers both:

    1. Hosted embedding models (e.g. OpenAI's text-embedding-3-small) via litellm integration
    2. Custom embedding functions that you provide

    For hosted models, pass the model name as a string (e.g. "openai/text-embedding-3-small"). The call is
    delegated to `litellm.embedding`, which also receives the `caching` flag.

    For custom embedding models, pass a callable that:
    - Takes a list of strings as input.
    - Returns embeddings as either:
        - A 2D numpy array of floats
        - A 2D list of floats
    - Each row should represent one embedding vector

    Whatever the callable returns is converted to a float32 numpy array.

    Args:
        model: The embedding model to use. This can be either a string (representing the name of the hosted
            embedding model, must be an embedding model supported by litellm) or a callable that represents a
            custom embedding model. The type is checked when the instance is called, not at construction.

    Examples:
        Example 1: Using a hosted model.

        ```python
        import dspy

        embedder = dspy.Embedding("openai/text-embedding-3-small")
        embeddings = embedder(["hello", "world"])

        assert embeddings.shape == (2, 1536)
        ```

        Example 2: Using a custom function.

        ```python
        import dspy
        import numpy as np

        def my_embedder(texts):
            return np.random.rand(len(texts), 10)

        embedder = dspy.Embedding(my_embedder)
        embeddings = embedder(["hello", "world"])

        assert embeddings.shape == (2, 10)
        ```
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, inputs, caching: bool = True, **kwargs) -> np.ndarray:
        """Compute embeddings for the given inputs.

        Args:
            inputs: The inputs to compute embeddings for, can be a single string or a list of strings.
            caching: Whether to cache the embedding response. Only used with a hosted embedding model; it is
                not forwarded to a custom callable.
            kwargs: Additional keyword arguments to pass to the embedding model.

        Returns:
            A float32 numpy array of embeddings, one embedding per row (2-D when the model returns one
            vector per input).

        Raises:
            ValueError: If `model` is neither a string nor a callable.
        """
        # A bare string is treated as a batch of one so both backends always receive a list.
        if isinstance(inputs, str):
            inputs = [inputs]

        model = self.model

        # Hosted model: litellm performs the request; each response item carries its vector under "embedding".
        if isinstance(model, str):
            response = litellm.embedding(model=model, input=inputs, caching=caching, **kwargs)
            vectors = [item["embedding"] for item in response.data]
            return np.array(vectors, dtype=np.float32)

        # Custom model: call it directly and coerce the output (list or ndarray) to float32.
        if callable(model):
            return np.array(model(inputs, **kwargs), dtype=np.float32)

        raise ValueError(f"`model` in `dspy.Embedding` must be a string or a callable, but got {type(self.model)}.")

dspy/clients/lm.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,19 @@
55
import uuid
66
from concurrent.futures import ThreadPoolExecutor
77
from datetime import datetime
8-
from pathlib import Path
98
from typing import Any, Dict, List, Literal, Optional
109

1110
import litellm
1211
import ujson
13-
from litellm.caching import Cache
1412

1513
from dspy.clients.finetune import FinetuneJob, TrainingMethod
1614
from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class
1715
from dspy.utils.callback import BaseCallback, with_callbacks
1816

19-
DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
20-
litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk")
21-
litellm.telemetry = False
22-
23-
if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
24-
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
2517

2618
logger = logging.getLogger(__name__)
2719

20+
2821
class LM(BaseLM):
2922
"""
3023
A language model supporting chat or text completion requests for use with DSPy modules.

tests/clients/test_embedding.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import pytest
2+
from unittest.mock import Mock, patch
3+
import numpy as np
4+
5+
from dspy.clients.embedding import Embedding
6+
7+
8+
# Mock response format similar to litellm's embedding response.
9+
class MockEmbeddingResponse:
    """Minimal stand-in for the object returned by litellm's embedding call.

    Args:
        embeddings: Iterable of embedding vectors; each is wrapped as an ``{"embedding": vector}`` record
            in ``data``.
    """

    def __init__(self, embeddings):
        records = []
        for vector in embeddings:
            records.append({"embedding": vector})
        self.data = records
        # Fixed placeholder metadata; the values are constant for every mock response.
        self.usage = dict(prompt_tokens=0, completion_tokens=0, total_tokens=0)
        self.model = "mock_model"
        self.object = "list"
15+
16+
17+
def test_litellm_embedding():
    """A string model is routed through litellm.embedding and the vectors come back row for row."""
    model_name = "text-embedding-ada-002"
    texts = ["hello", "world"]
    vectors = [
        [0.1, 0.2, 0.3],  # embedding for "hello"
        [0.4, 0.5, 0.6],  # embedding for "world"
    ]
    # Response object shaped like litellm's embedding response.
    fake_response = MockEmbeddingResponse(vectors)

    with patch("litellm.embedding", return_value=fake_response) as mock_litellm:
        # Build the embedder and invoke it while litellm is patched.
        result = Embedding(model_name)(texts)

        # litellm must be hit exactly once, with caching enabled by default.
        mock_litellm.assert_called_once_with(model=model_name, input=texts, caching=True)

    assert len(result) == len(texts)
    np.testing.assert_allclose(result, vectors)
38+
39+
40+
def test_callable_embedding():
    """A callable model receives the inputs as given and its output is returned as a float32 array."""
    inputs = ["hello", "world", "test"]

    expected_embeddings = [
        [0.1, 0.2, 0.3],  # embedding for "hello"
        [0.4, 0.5, 0.6],  # embedding for "world"
        [0.7, 0.8, 0.9],  # embedding for "test"
    ]

    received = []

    def mock_embedding_fn(texts):
        # Deterministic stand-in: records what it was given and returns the fixed vectors above.
        received.append(texts)
        return expected_embeddings

    # Create embedding instance with callable
    embedding = Embedding(mock_embedding_fn)
    result = embedding(inputs)

    # The callable is invoked exactly once with the caller's list of strings.
    assert received == [inputs]

    # The plain Python list returned by the callable is converted to a float32 numpy array.
    assert isinstance(result, np.ndarray)
    assert result.dtype == np.float32
    assert result.shape == (len(inputs), 3)
    np.testing.assert_allclose(result, expected_embeddings)
58+
59+
60+
def test_invalid_model_type():
    """A model that is neither a string nor a callable is rejected with ValueError."""
    bad_model = 123  # Invalid model type
    with pytest.raises(ValueError):
        # Construction and the call both sit inside the block, so the test passes whichever one raises.
        Embedding(bad_model)(["test"])

0 commit comments

Comments
 (0)