Skip to content

Commit dbc5c48

Browse files
authored
Merge pull request #1315 from sanders41/composit-embedder
Add composit embedder
2 parents 0d80f3c + c93503e commit dbc5c48

File tree

3 files changed

+43
-7
lines changed

3 files changed

+43
-7
lines changed

meilisearch_python_sdk/index.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
SimilarSearchResults,
2828
)
2929
from meilisearch_python_sdk.models.settings import (
30+
CompositeEmbedder,
3031
Embedders,
3132
Faceting,
3233
FilterableAttributeFeatures,
@@ -8459,7 +8460,12 @@ def _embedder_json_to_embedders_model( # pragma: no cover
84598460

84608461
embedders: dict[
84618462
str,
8462-
OpenAiEmbedder | HuggingFaceEmbedder | OllamaEmbedder | RestEmbedder | UserProvidedEmbedder,
8463+
OpenAiEmbedder
8464+
| HuggingFaceEmbedder
8465+
| OllamaEmbedder
8466+
| RestEmbedder
8467+
| UserProvidedEmbedder
8468+
| CompositeEmbedder,
84638469
] = {}
84648470
for k, v in embedder_json.items():
84658471
if v.get("source") == "openAi":
@@ -8470,6 +8476,8 @@ def _embedder_json_to_embedders_model( # pragma: no cover
84708476
embedders[k] = OllamaEmbedder(**v)
84718477
elif v.get("source") == "rest":
84728478
embedders[k] = RestEmbedder(**v)
8479+
elif v.get("source") == "composit":
8480+
embedders[k] = CompositeEmbedder(**v)
84738481
else:
84748482
embedders[k] = UserProvidedEmbedder(**v)
84758483

@@ -8482,7 +8490,12 @@ def _embedder_json_to_settings_model( # pragma: no cover
84828490
) -> (
84838491
dict[
84848492
str,
8485-
OpenAiEmbedder | HuggingFaceEmbedder | OllamaEmbedder | RestEmbedder | UserProvidedEmbedder,
8493+
OpenAiEmbedder
8494+
| HuggingFaceEmbedder
8495+
| OllamaEmbedder
8496+
| RestEmbedder
8497+
| UserProvidedEmbedder
8498+
| CompositeEmbedder,
84868499
]
84878500
| None
84888501
):
@@ -8491,7 +8504,12 @@ def _embedder_json_to_settings_model( # pragma: no cover
84918504

84928505
embedders: dict[
84938506
str,
8494-
OpenAiEmbedder | HuggingFaceEmbedder | OllamaEmbedder | RestEmbedder | UserProvidedEmbedder,
8507+
OpenAiEmbedder
8508+
| HuggingFaceEmbedder
8509+
| OllamaEmbedder
8510+
| RestEmbedder
8511+
| UserProvidedEmbedder
8512+
| CompositeEmbedder,
84958513
] = {}
84968514
for k, v in embedder_json.items():
84978515
if v.get("source") == "openAi":
@@ -8502,6 +8520,8 @@ def _embedder_json_to_settings_model( # pragma: no cover
85028520
embedders[k] = OllamaEmbedder(**v)
85038521
elif v.get("source") == "rest":
85048522
embedders[k] = RestEmbedder(**v)
8523+
elif v.get("source") == "composit":
8524+
embedders[k] = CompositeEmbedder(**v)
85058525
else:
85068526
embedders[k] = UserProvidedEmbedder(**v)
85078527

meilisearch_python_sdk/models/settings.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,25 @@ class UserProvidedEmbedder(CamelBase):
106106
binary_quantized: bool | None = None
107107

108108

109+
class CompositeEmbedder(CamelBase):
110+
source: str = "composite"
111+
search_embedder: (
112+
OpenAiEmbedder | HuggingFaceEmbedder | OllamaEmbedder | RestEmbedder | UserProvidedEmbedder
113+
)
114+
indexing_embedder: (
115+
OpenAiEmbedder | HuggingFaceEmbedder | OllamaEmbedder | RestEmbedder | UserProvidedEmbedder
116+
)
117+
118+
109119
class Embedders(CamelBase):
110120
embedders: dict[
111121
str,
112-
OpenAiEmbedder | HuggingFaceEmbedder | OllamaEmbedder | RestEmbedder | UserProvidedEmbedder,
122+
OpenAiEmbedder
123+
| HuggingFaceEmbedder
124+
| OllamaEmbedder
125+
| RestEmbedder
126+
| UserProvidedEmbedder
127+
| CompositeEmbedder,
113128
]
114129

115130

@@ -162,7 +177,8 @@ class MeilisearchSettings(CamelBase):
162177
| HuggingFaceEmbedder
163178
| OllamaEmbedder
164179
| RestEmbedder
165-
| UserProvidedEmbedder,
180+
| UserProvidedEmbedder
181+
| CompositeEmbedder,
166182
]
167183
| None
168184
) = None # Optional[Embedders] = None

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ async def enable_edit_by_function(base_url, ssl_verify):
301301
base_url=base_url, headers={"Authorization": f"Bearer {MASTER_KEY}"}, verify=ssl_verify
302302
) as client:
303303
await client.patch("/experimental-features", json={"editDocumentsByFunction": True})
304-
yield
304+
yield
305305

306306

307307
@pytest.fixture(scope="session", autouse=True)
@@ -310,7 +310,7 @@ async def enable_network(base_url, ssl_verify):
310310
base_url=base_url, headers={"Authorization": f"Bearer {MASTER_KEY}"}, verify=ssl_verify
311311
) as client:
312312
await client.patch("/experimental-features", json={"network": True})
313-
yield
313+
yield
314314

315315

316316
@pytest.fixture

0 commit comments

Comments
 (0)