def matches(self, node: BaseNode) -> bool:
    """
    Check whether ``node``'s metadata satisfies this filter.

    The value stored under ``self.key`` in ``node.metadata`` is compared
    against ``self.value`` according to ``self.operator``.

    Args:
        node: Node whose ``metadata`` mapping is inspected.

    Returns:
        True if the metadata value satisfies the filter, False otherwise.
        A missing (or ``None``) metadata value never satisfies the ordering
        operators (GT, LT, GTE, LTE) or CONTAINS, and always satisfies
        IS_EMPTY.

    Raises:
        NotImplementedError: For ANY, ALL, TEXT_MATCH and
            TEXT_MATCH_INSENSITIVE, which are not supported yet.
        ValueError: If ``self.operator`` is not a known ``FilterOperator``.
    """
    op = self.operator
    expected = self.value
    # ``dict.get`` yields None for an absent key, so "key missing" and
    # "value is None" are handled identically by every branch below.
    actual = node.metadata.get(self.key)

    if op == FilterOperator.EQ:
        return actual == expected
    if op == FilterOperator.NE:
        return actual != expected

    # Ordering against None raises TypeError in Python 3; a node that lacks
    # the key simply does not match instead of aborting the whole filter.
    if op == FilterOperator.GT:
        return actual is not None and actual > expected
    if op == FilterOperator.LT:
        return actual is not None and actual < expected
    if op == FilterOperator.GTE:
        return actual is not None and actual >= expected
    if op == FilterOperator.LTE:
        return actual is not None and actual <= expected

    # IN / NIN: the metadata value is looked up in the filter's collection.
    if op == FilterOperator.IN:
        return actual in expected
    if op == FilterOperator.NIN:
        return actual not in expected

    # CONTAINS is the mirror image of IN: the *metadata* value is the
    # container and the filter value is the element searched for.
    if op == FilterOperator.CONTAINS:
        return actual is not None and expected in actual

    if op == FilterOperator.IS_EMPTY:
        if actual is None:
            return True
        if hasattr(actual, "__len__"):
            return len(actual) == 0
        return False

    for unsupported in (
        FilterOperator.ANY,
        FilterOperator.ALL,
        FilterOperator.TEXT_MATCH,
        FilterOperator.TEXT_MATCH_INSENSITIVE,
    ):
        if op == unsupported:
            raise NotImplementedError(
                f"{unsupported.name} operator not implemented yet"
            )

    raise ValueError(f"Unknown filter operator: {self.operator}")
def matches(self, node: BaseNode) -> bool:
    """
    Check whether ``node`` satisfies this group of filters.

    Every entry of ``self.filters`` is queried through its own
    ``matches(node)`` and the results are combined according to
    ``self.condition``. Evaluation is lazy: AND stops at the first
    sub-filter that does not match, OR at the first one that does.

    Args:
        node: Node to test against the sub-filters.

    Returns:
        AND: True if every sub-filter matches (True for no sub-filters).
        OR: True if at least one sub-filter matches (False for none).
        NOT: The negation of the single sub-filter's result.

    Raises:
        ValueError: If the condition is NOT and there is not exactly one
            sub-filter, or if ``self.condition`` is not a known
            ``FilterCondition``.
    """
    condition = self.condition

    if condition == FilterCondition.AND:
        return all(sub_filter.matches(node) for sub_filter in self.filters)

    if condition == FilterCondition.OR:
        return any(sub_filter.matches(node) for sub_filter in self.filters)

    if condition == FilterCondition.NOT:
        # Raised explicitly rather than asserted: ``assert`` statements are
        # removed when Python runs with -O, which would silently skip the
        # check.
        if len(self.filters) != 1:
            raise ValueError("NOT condition must have exactly one sub-filter")
        return not self.filters[0].matches(node)

    raise ValueError(f"Unknown filter condition: {self.condition}")
If not implemented, @@ -426,7 +465,5 @@ async def aquery( """ return self.query(query, **kwargs) - def persist( - self, persist_path: str, fs: Optional[fsspec.AbstractFileSystem] = None - ) -> None: + def persist(self, persist_path: str, fs: Optional[fsspec.AbstractFileSystem] = None) -> None: return None