6868 "Qdrant/clip-ViT-B-32-text" : np .array ([0.0083 , 0.0103 , - 0.0138 , 0.0199 , - 0.0069 ]),
6969 "thenlper/gte-base" : np .array ([0.0038 , 0.0355 , 0.0181 , 0.0092 , 0.0654 ]),
7070 "jinaai/jina-clip-v1" : np .array ([- 0.0862 , - 0.0101 , - 0.0056 , 0.0375 , - 0.0472 ]),
71+ "google/embeddinggemma-300m" : np .array (
72+ [- 0.08181356 , 0.0214127 , 0.05120273 , - 0.03690156 , - 0.0254504 ]
73+ ),
74+ }
75+
76+
# Per-model instruction prefixes that must be prepended to raw text before
# embedding. EmbeddingGemma expects task-specific prompts for documents
# and queries respectively.
DOC_PREFIXES = {
    "google/embeddinggemma-300m": "title: none | text: ",
}
QUERY_PREFIXES = {
    "google/embeddinggemma-300m": "task: search result | query: ",
}
# Pinned leading values of the canonical query embedding, used as a
# regression check against model/runtime drift (compared with atol=1e-3).
CANONICAL_QUERY_VECTOR_VALUES = {
    "google/embeddinggemma-300m": np.array(
        [-0.22990295, 0.03311195, 0.04290345, -0.03558498, -0.01399477]
    )
}
7288
7389MULTI_TASK_MODELS = ["jinaai/jina-embeddings-v3" ]
@@ -119,6 +135,9 @@ def test_embedding(model_cache, model_name: str) -> None:
119135
120136 with model_cache (model_desc .model ) as model :
121137 docs = ["hello world" , "flag embedding" ]
138+ if model_desc .model in DOC_PREFIXES :
139+ docs = [DOC_PREFIXES [model_desc .model ] + doc for doc in docs ]
140+
122141 embeddings = list (model .embed (docs ))
123142 embeddings = np .stack (embeddings , axis = 0 )
124143 assert embeddings .shape == (2 , dim )
@@ -129,6 +148,39 @@ def test_embedding(model_cache, model_name: str) -> None:
129148 ), model_desc .model
130149
131150
def test_query_embedding(model_cache) -> None:
    """Check query embeddings against pinned canonical vectors.

    Only models with an entry in CANONICAL_QUERY_VECTOR_VALUES are
    exercised. When QUERY_PREFIXES defines a prompt for the model, it is
    prepended to each raw query string before embedding.
    """
    ci_env = os.getenv("CI")
    on_mac = platform.system() == "Darwin"
    manual_run = os.getenv("GITHUB_EVENT_NAME") == "workflow_dispatch"

    for desc in TextEmbedding._list_supported_models():
        name = desc.model
        # Multi-task models need a task id; the quantized nomic model is
        # known to be flaky on macOS runners.
        if name in MULTI_TASK_MODELS or (
            on_mac and name == "nomic-ai/nomic-embed-text-v1.5-Q"
        ):
            continue
        # Without a pinned vector there is nothing to compare against.
        if name not in CANONICAL_QUERY_VECTOR_VALUES:
            continue
        if not should_test_model(desc, "", ci_env, manual_run):
            continue

        with model_cache(name) as model:
            queries = ["hello world", "flag embedding"]
            prefix = QUERY_PREFIXES.get(name)
            if prefix is not None:
                queries = [prefix + query for query in queries]

            vectors = np.stack(list(model.query_embed(queries)), axis=0)
            assert vectors.shape == (2, desc.dim)

            expected = CANONICAL_QUERY_VECTOR_VALUES[name]
            # Compare only the leading components that were pinned.
            assert np.allclose(
                vectors[0, : expected.shape[0]], expected, atol=1e-3
            ), name
183+
132184@pytest .mark .parametrize ("n_dims,model_name" , [(384 , "BAAI/bge-small-en-v1.5" )])
133185def test_batch_embedding (model_cache , n_dims : int , model_name : str ) -> None :
134186 with model_cache (model_name ) as model :
0 commit comments