
Commit ca023be

feat: MUVERA embeddings (#542)
* Implement MuveraEmbedding
* Add random generator parameter for reproducibility in MuveraEmbedding
* Document random_seed parameter
* Remove unnecessary module docstring from muvera_embedding.py
* refactor: clean up constructor parameters and improve formatting in MuveraEmbedding
* refactor: rename muvera_embedding.py to muvera.py and update related references
* feat: enhance MuveraEmbedding with multi-vector model support and improve parameter defaults
* feat: add embedding_size property to MuveraEmbedding
* feat: update MuveraPostprocessor to use model description for embedding size and add Jupyter notebook for MUVERA usage
* fix: fix types, doctest, rename variables, refactor (#545)
* fix: fix types, doctest, rename variables, refactor
* fix: fix python3.9 compatibility
* fix: make get_output_dimension protected
* Optimize muvera (#551)
* vectorize operations
* fix: fill empty clusters with dataset vectors
* rollback get_output_dimension
* fix: fix type hints
* fix: review comments
* tests: add tests

---------

Co-authored-by: George <[email protected]>
1 parent d1ddc81 commit ca023be

File tree

3 files changed: +405 -0 lines changed

fastembed/postprocess/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from fastembed.postprocess.muvera import Muvera

__all__ = ["Muvera"]

fastembed/postprocess/muvera.py

Lines changed: 364 additions & 0 deletions
@@ -0,0 +1,364 @@
from typing import Union

import numpy as np

from fastembed.common.types import NumpyArray
from fastembed.late_interaction.late_interaction_embedding_base import (
    LateInteractionTextEmbeddingBase,
)
from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
    LateInteractionMultimodalEmbeddingBase,
)


MultiVectorModel = Union[LateInteractionTextEmbeddingBase, LateInteractionMultimodalEmbeddingBase]
MAX_HAMMING_DISTANCE = 65  # 64 bits + 1
POPCOUNT_LUT = np.array([bin(x).count("1") for x in range(256)], dtype=np.uint8)


def hamming_distance_matrix(ids: np.ndarray) -> np.ndarray:
    """Compute the full pairwise Hamming distance matrix over integer ids.

    Args:
        ids: shape (n,) - array of integer ids whose bit patterns are compared

    Returns:
        np.ndarray (n, n) - Hamming distance matrix
    """
    n = len(ids)
    xor_vals = np.bitwise_xor(ids[:, None], ids[None, :])  # (n, n) uint64
    bytes_view = xor_vals.view(np.uint8).reshape(n, n, 8)  # (n, n, 8)
    return POPCOUNT_LUT[bytes_view].sum(axis=2)

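# Illustrative example (added comment, not part of the original diff):
# POPCOUNT_LUT[b] holds the number of set bits in byte b, e.g. POPCOUNT_LUT[7] == 3.
# For ids [0, 1, 2, 3] (binary 00, 01, 10, 11), hamming_distance_matrix returns
# [[0, 1, 1, 2],
#  [1, 0, 2, 1],
#  [1, 2, 0, 1],
#  [2, 1, 1, 0]]
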
class SimHashProjection:
    """
    SimHash projection component for MUVERA clustering.

    This class implements locality-sensitive hashing using random hyperplanes
    to partition the vector space into 2^k_sim clusters. Each vector is assigned
    to a cluster based on which side of the k_sim random hyperplanes it falls on.

    Attributes:
        k_sim (int): Number of SimHash functions (hyperplanes)
        dim (int): Dimensionality of input vectors
        simhash_vectors (np.ndarray): Random hyperplane normal vectors of shape (dim, k_sim)
    """

    def __init__(self, k_sim: int, dim: int, random_generator: np.random.Generator):
        """
        Initialize SimHash projection with random hyperplanes.

        Args:
            k_sim (int): Number of SimHash functions, determines 2^k_sim clusters
            dim (int): Dimensionality of input vectors
            random_generator (np.random.Generator): Random number generator for reproducibility
        """
        self.k_sim = k_sim
        self.dim = dim
        # Generate k_sim random hyperplanes (normal vectors) from a standard normal distribution
        self.simhash_vectors = random_generator.normal(size=(dim, k_sim))

    def get_cluster_ids(self, vectors: np.ndarray) -> np.ndarray:
        """
        Compute cluster IDs for the given vectors using SimHash.

        Each cluster ID is determined by computing the dot product of a vector
        with each hyperplane normal vector, taking the sign, and interpreting
        the resulting binary string as an integer.

        Args:
            vectors (np.ndarray): Input vectors of shape (n, dim)

        Returns:
            np.ndarray: Cluster IDs in range [0, 2^k_sim - 1]
        """
        dot_product = (
            vectors @ self.simhash_vectors
        )  # (token_num, dim) x (dim, k_sim) -> (token_num, k_sim)
        cluster_ids = (dot_product > 0) @ (1 << np.arange(self.k_sim))
        return cluster_ids

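# Illustrative example (added comment, not part of the original diff):
# with k_sim = 3, a vector whose dot products with the three hyperplanes have
# signs (+, -, +) yields the boolean row [True, False, True]; multiplying by
# the bit weights (1 << np.arange(3)) == [1, 2, 4] packs it into
# cluster id 1 * 1 + 0 * 2 + 1 * 4 = 5.
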
class Muvera:
    """
    MUVERA (MUlti-VEctor REtrieval Algorithm) implementation.

    This class creates Fixed Dimensional Encodings (FDEs) from variable-length
    sequences of vectors by using SimHash clustering and random projections.
    The process involves:
    1. Clustering vectors using multiple SimHash projections
    2. Computing cluster centers (with different strategies for docs vs queries)
    3. Applying random projections for dimensionality reduction
    4. Concatenating results from all projections

    Attributes:
        k_sim (int): Number of SimHash functions per projection
        dim (int): Input vector dimensionality
        dim_proj (int): Output dimensionality after random projection
        r_reps (int): Number of random projection repetitions
        simhash_projections (List[SimHashProjection]): SimHash instances for clustering
        dim_reduction_projections (np.ndarray): Random projection matrices of shape
            (r_reps, dim, dim_proj)
    """

    def __init__(
        self,
        dim: int,
        k_sim: int = 5,
        dim_proj: int = 16,
        r_reps: int = 20,
        random_seed: int = 42,
    ):
        """
        Initialize MUVERA with the specified parameters.

        Args:
            dim (int): Dimensionality of individual input vectors
            k_sim (int, optional): Number of SimHash functions (creates 2^k_sim clusters).
                Defaults to 5.
            dim_proj (int, optional): Dimensionality after random projection (must be <= dim).
                Defaults to 16.
            r_reps (int, optional): Number of random projection repetitions for robustness.
                Defaults to 20.
            random_seed (int, optional): Seed for the random number generator to ensure
                reproducible results. Defaults to 42.

        Raises:
            ValueError: If dim_proj > dim (cannot project to a higher dimensionality)
        """
        if dim_proj > dim:
            raise ValueError(
                f"Cannot project to a higher dimensionality (dim_proj={dim_proj} > dim={dim})"
            )

        self.k_sim = k_sim
        self.dim = dim
        self.dim_proj = dim_proj
        self.r_reps = r_reps
        # Create r_reps independent SimHash projections for robustness
        generator = np.random.default_rng(random_seed)
        self.simhash_projections = [
            SimHashProjection(k_sim=self.k_sim, dim=self.dim, random_generator=generator)
            for _ in range(r_reps)
        ]
        # Random projection matrices with entries from {-1, +1} for each repetition
        self.dim_reduction_projections = generator.choice([-1, 1], size=(r_reps, dim, dim_proj))

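    # Illustrative note (added comment, not part of the original diff): each
    # projection matrix above holds Rademacher (+/-1) entries; combined with the
    # 1/sqrt(dim_proj) scaling applied in process(), this forms a random
    # projection that approximately preserves inner products.
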
    @classmethod
    def from_multivector_model(
        cls,
        model: MultiVectorModel,
        k_sim: int = 5,
        dim_proj: int = 16,
        r_reps: int = 20,  # noqa[naming]
        random_seed: int = 42,
    ) -> "Muvera":
        """
        Create a Muvera instance from a multi-vector embedding model.

        This class method provides a convenient way to initialize a Muvera instance
        compatible with a given multi-vector model by automatically extracting
        the embedding dimensionality from the model.

        Args:
            model (MultiVectorModel): A late interaction text or multimodal embedding model
                that provides multi-vector embeddings. Must have an `embedding_size`
                attribute specifying the dimensionality of individual vectors.
            k_sim (int, optional): Number of SimHash functions (creates 2^k_sim clusters).
                Defaults to 5.
            dim_proj (int, optional): Dimensionality after random projection (must be <=
                the model's embedding_size). Defaults to 16.
            r_reps (int, optional): Number of random projection repetitions for robustness.
                Defaults to 20.
            random_seed (int, optional): Seed for the random number generator to ensure
                reproducible results. Defaults to 42.

        Returns:
            Muvera: A configured MUVERA instance ready to process embeddings from the given model.

        Raises:
            ValueError: If dim_proj > model.embedding_size (cannot project to a higher
                dimensionality)

        Example:
            >>> from fastembed import LateInteractionTextEmbedding
            >>> model = LateInteractionTextEmbedding(model_name="colbert-ir/colbertv2.0")
            >>> muvera = Muvera.from_multivector_model(
            ...     model=model,
            ...     k_sim=6,
            ...     dim_proj=32
            ... )
            >>> # Now use the postprocessor with embeddings from the model
            >>> embeddings = np.array(list(model.embed(["sample text"])))
            >>> fde = muvera.process_document(embeddings[0])
        """
        return cls(
            dim=model.embedding_size,
            k_sim=k_sim,
            dim_proj=dim_proj,
            r_reps=r_reps,
            random_seed=random_seed,
        )

    def _get_output_dimension(self) -> int:
        """
        Get the output dimension of the MUVERA encoding.

        Returns:
            int: Output dimension (r_reps * num_partitions * dim_proj),
                where num_partitions = 2^k_sim
        """
        num_partitions = 2**self.k_sim
        return self.r_reps * num_partitions * self.dim_proj

    @property
    def embedding_size(self) -> int:
        return self._get_output_dimension()

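    # Illustrative arithmetic (added comment, not part of the original diff):
    # with the defaults k_sim=5, dim_proj=16, r_reps=20, every FDE has
    # 20 * 2**5 * 16 = 10240 dimensions, regardless of how many vectors
    # a document or query contains.
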
    def process_document(self, vectors: NumpyArray) -> NumpyArray:
        """
        Encode a document's vectors into a Fixed Dimensional Encoding (FDE).

        Uses document-specific settings: normalizes cluster centers by vector count
        and fills empty clusters using Hamming distance-based selection.

        Args:
            vectors (NumpyArray): Document vectors of shape (n_tokens, dim)

        Returns:
            NumpyArray: Fixed dimensional encoding of shape (r_reps * 2^k_sim * dim_proj,)
        """
        return self.process(vectors, fill_empty_clusters=True, normalize_by_count=True)

    def process_query(self, vectors: NumpyArray) -> NumpyArray:
        """
        Encode a query's vectors into a Fixed Dimensional Encoding (FDE).

        Uses query-specific settings: no normalization by count and no empty
        cluster filling, to preserve query vector magnitudes.

        Args:
            vectors (NumpyArray): Query vectors of shape (n_tokens, dim)

        Returns:
            NumpyArray: Fixed dimensional encoding of shape (r_reps * 2^k_sim * dim_proj,)
        """
        return self.process(vectors, fill_empty_clusters=False, normalize_by_count=False)

    def process(
        self,
        vectors: NumpyArray,
        fill_empty_clusters: bool = True,
        normalize_by_count: bool = True,
    ) -> NumpyArray:
        """
        Core encoding method that transforms variable-length vector sequences into FDEs.

        The encoding process:
        1. For each of r_reps random projections:
            a. Assign vectors to clusters using SimHash
            b. Compute cluster centers (sum of vectors in each cluster)
            c. Optionally normalize by cluster size
            d. Fill empty clusters using Hamming distance if requested
            e. Apply random projection for dimensionality reduction
            f. Flatten cluster centers into a vector
        2. Concatenate all projection results

        Args:
            vectors (np.ndarray): Input vectors of shape (n_vectors, dim)
            fill_empty_clusters (bool): Whether to fill empty clusters using nearest
                vectors based on Hamming distance of cluster IDs
            normalize_by_count (bool): Whether to normalize cluster centers by the
                number of vectors assigned to each cluster

        Returns:
            np.ndarray: Fixed dimensional encoding of shape (r_reps * b * dim_proj,),
                where b = 2^k_sim is the number of clusters

        Raises:
            AssertionError: If input vectors don't have the expected dimensionality
        """
        assert (
            vectors.shape[1] == self.dim
        ), f"Expected vectors of shape (n, {self.dim}), got {vectors.shape}"

        # Store results from each random projection
        output_vectors = []

        # Number of space partitions in SimHash
        num_partitions = 2**self.k_sim
        cluster_center_ids = np.arange(num_partitions)
        precomputed_hamming_matrix = (
            hamming_distance_matrix(cluster_center_ids) if fill_empty_clusters else None
        )

        for projection_index, simhash in enumerate(self.simhash_projections):
            # Initialize cluster centers and track which vectors fall into each cluster
            cluster_centers = np.zeros((num_partitions, self.dim))
            cluster_center_id_to_vectors: dict[int, list[int]] = {
                cluster_center_id: [] for cluster_center_id in cluster_center_ids
            }
            cluster_vector_counts = None
            empty_mask = None

            # Assign each vector to its cluster and accumulate cluster centers
            vector_cluster_ids = simhash.get_cluster_ids(vectors)
            for cluster_id, (vec_idx, vec) in zip(vector_cluster_ids, enumerate(vectors)):
                cluster_centers[cluster_id] += vec
                cluster_center_id_to_vectors[cluster_id].append(vec_idx)

            if normalize_by_count or fill_empty_clusters:
                cluster_vector_counts = np.bincount(vector_cluster_ids, minlength=num_partitions)
                empty_mask = cluster_vector_counts == 0

            if normalize_by_count:
                assert empty_mask is not None
                assert cluster_vector_counts is not None
                non_empty_mask = ~empty_mask
                cluster_centers[non_empty_mask] /= cluster_vector_counts[non_empty_mask][:, None]

            # Fill empty clusters using vectors from the nearest non-empty cluster
            # (by Hamming distance between cluster ids)
            if fill_empty_clusters:
                assert empty_mask is not None
                assert precomputed_hamming_matrix is not None
                masked_hamming = np.where(
                    empty_mask[None, :], MAX_HAMMING_DISTANCE, precomputed_hamming_matrix
                )
                nearest_non_empty = np.argmin(masked_hamming, axis=1)
                fill_vectors = np.array(
                    [
                        vectors[cluster_center_id_to_vectors[cluster_id][0]]
                        for cluster_id in nearest_non_empty[empty_mask]
                    ]
                ).reshape(-1, self.dim)
                cluster_centers[empty_mask] = fill_vectors

            # Apply random projection for dimensionality reduction if needed
            if self.dim_proj < self.dim:
                dim_reduction_projection = self.dim_reduction_projections[
                    projection_index
                ]  # Projection matrix for this repetition
                projected_centers = (1 / np.sqrt(self.dim_proj)) * (
                    cluster_centers @ dim_reduction_projection
                )

                # Flatten cluster centers into a single vector and add to output
                output_vectors.append(projected_centers.flatten())
                continue

            # If no projection is needed (dim_proj == dim), use the original cluster centers
            output_vectors.append(cluster_centers.flatten())

        # Concatenate results from all r_reps projections into the final FDE
        return np.concatenate(output_vectors)

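# Illustrative trace (added comment, not part of the original diff): in the
# smoke test below, each repetition clusters the (100, 128) vectors into
# 2**4 = 16 partitions, projects the (16, 128) centers down to (16, 8), and
# flattens them into 128 values; concatenating r_reps = 20 repetitions yields
# a (2560,) FDE.
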
if __name__ == "__main__":
    v_arrs = np.random.randn(10, 100, 128)
    muvera = Muvera(dim=128, k_sim=4, dim_proj=8, r_reps=20, random_seed=42)

    for v_arr in v_arrs:
        muvera.process(v_arr)  # type: ignore
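
For reference, here is a minimal end-to-end sketch of the API this commit adds. It assumes only the `Muvera` export from `fastembed.postprocess` shown in the diff above; the synthetic arrays stand in for per-token embeddings from a late-interaction model, and the final dot product reflects MUVERA's intent that a single inner product between FDEs approximates the multi-vector (MaxSim-style) similarity.

import numpy as np

from fastembed.postprocess import Muvera

# Synthetic multi-vector "document" and "query": any (n_tokens, dim) arrays work
rng = np.random.default_rng(0)
doc_vectors = rng.normal(size=(100, 128))
query_vectors = rng.normal(size=(32, 128))

muvera = Muvera(dim=128, k_sim=5, dim_proj=16, r_reps=20, random_seed=42)

# Documents and queries are encoded asymmetrically (see process_document / process_query)
doc_fde = muvera.process_document(doc_vectors)  # shape: (10240,)
query_fde = muvera.process_query(query_vectors)  # shape: (10240,)
assert doc_fde.shape == (muvera.embedding_size,)

# One dot product between fixed-size FDEs approximates the late-interaction score
score = float(doc_fde @ query_fde)
print(score)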
