Skip to content

Commit b2af969

Browse files
committed
fix: fix local inference on list of prefetches (#965)
1 parent eef5e25 commit b2af969

File tree

4 files changed

+23
-6
lines changed

4 files changed

+23
-6
lines changed

qdrant_client/async_qdrant_client.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,9 @@ async def query_points(
527527
"""
528528
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
529529
query = self._resolve_query(query)
530-
requires_inference = self._inference_inspector.inspect([query, prefetch])
530+
requires_inference = self._inference_inspector.inspect(query)
531+
if not requires_inference:
532+
requires_inference = self._inference_inspector.inspect(prefetch)
531533
if requires_inference and (not self.cloud_inference):
532534
query = (
533535
next(
@@ -691,7 +693,9 @@ async def query_points_groups(
691693
"""
692694
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
693695
query = self._resolve_query(query)
694-
requires_inference = self._inference_inspector.inspect([query, prefetch])
696+
requires_inference = self._inference_inspector.inspect(query)
697+
if not requires_inference:
698+
requires_inference = self._inference_inspector.inspect(prefetch)
695699
if requires_inference and (not self.cloud_inference):
696700
query = (
697701
next(

qdrant_client/embed/type_inspector.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ def inspect(self, points: Union[Iterable[BaseModel], BaseModel]) -> bool:
4545
self.parser.parse_model(point.__class__)
4646
if self._inspect_model(point):
4747
return True
48+
else:
49+
return False
4850
return False
4951

5052
def _inspect_model(self, model: BaseModel, paths: Optional[list[FieldPath]] = None) -> bool:

qdrant_client/qdrant_client.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,9 @@ def query_points(
556556
# If the query contains unprocessed documents, we need to embed them and
557557
# replace the original query with the embedded vectors.
558558
query = self._resolve_query(query)
559-
requires_inference = self._inference_inspector.inspect([query, prefetch])
559+
requires_inference = self._inference_inspector.inspect(query)
560+
if not requires_inference:
561+
requires_inference = self._inference_inspector.inspect(prefetch)
560562
if requires_inference and not self.cloud_inference:
561563
query = (
562564
next(
@@ -725,7 +727,9 @@ def query_points_groups(
725727
# If the query contains unprocessed documents, we need to embed them and
726728
# replace the original query with the embedded vectors.
727729
query = self._resolve_query(query)
728-
requires_inference = self._inference_inspector.inspect([query, prefetch])
730+
requires_inference = self._inference_inspector.inspect(query)
731+
if not requires_inference:
732+
requires_inference = self._inference_inspector.inspect(prefetch)
729733
if requires_inference and not self.cloud_inference:
730734
query = (
731735
next(

tests/embed_tests/test_inspectors.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,13 @@ def test_inspect_prefetch_types():
246246
paths = inspector_embed.inspect(doc_prefetch)
247247
assert len(paths) == 1 and paths[0].as_str_list() == ["query"]
248248

249+
no_query_list_prefetch_with_doc = models.Prefetch(
250+
query=None, prefetch=[models.Prefetch(query=None), models.Prefetch(query=doc)]
251+
)
252+
assert inspector.inspect(no_query_list_prefetch_with_doc)
253+
paths = inspector_embed.inspect(no_query_list_prefetch_with_doc)
254+
assert len(paths) == 1 and paths[0].as_str_list() == ["prefetch.query"]
255+
249256
nested_prefetch = models.Prefetch(
250257
query=None,
251258
prefetch=models.Prefetch(query=doc),
@@ -276,8 +283,8 @@ def test_inspect_prefetch_types():
276283
"prefetch.prefetch.query",
277284
}
278285

279-
assert inspector.inspect([None, deep_nested_prefetch])
280-
paths = inspector_embed.inspect([None, deep_nested_prefetch])
286+
assert inspector.inspect([none_prefetch, deep_nested_prefetch])
287+
paths = inspector_embed.inspect([none_prefetch, deep_nested_prefetch])
281288
assert len(paths) == 1 and set(paths[0].as_str_list()) == {
282289
"prefetch.prefetch.prefetch.query",
283290
"prefetch.prefetch.query",

0 commit comments

Comments
 (0)