|
490 | 490 | " num_dimensions: int,\n", |
491 | 491 | " distance_type: str,\n", |
492 | 492 | " id_type: str,\n", |
493 | | - " time_partition_interval: Optional[timedelta]) -> None:\n", |
| 493 | + " time_partition_interval: Optional[timedelta],\n", |
| 494 | + " infer_filters: bool) -> None:\n", |
494 | 495 | " \"\"\"\n", |
495 | 496 | " Initializes a base Vector object to generate queries for vector clients.\n", |
496 | 497 | "\n", |
|
522 | 523 | "\n", |
523 | 524 | " self.id_type = id_type.lower()\n", |
524 | 525 | " self.time_partition_interval = time_partition_interval\n", |
| 526 | + " self.infer_filters = infer_filters\n", |
525 | 527 | "\n", |
526 | 528 | " def _quote_ident(self, ident):\n", |
527 | 529 | " \"\"\"\n", |
|
713 | 715 | " raise ValueError(\"Unknown filter type: {filter_type}\".format(filter_type=type(filter)))\n", |
714 | 716 | "\n", |
715 | 717 | " return (where, params)\n", |
| 718 | + " \n", |
| 719 | + " def _parse_datetime(self, input_datetime):\n", |
| 720 | + " \"\"\"\n", |
| 721 | + " Parse a datetime object or string representation of a datetime.\n", |
| 722 | + "\n", |
| 723 | + " Args:\n", |
| 724 | + " input_datetime (datetime, str, or None): A datetime object, a string accepted by datetime.fromisoformat, or None.\n", |
| 725 | + "\n", |
| 726 | + " Returns:\n", |
| 727 | + " datetime: Parsed datetime object, or None if the input is None.\n", |
| 728 | + "\n", |
| 729 | + " Raises:\n", |
| 730 | + " ValueError: If a string input cannot be parsed as a datetime, or if the input is neither a datetime nor a string.\n", |
| 731 | + " \"\"\"\n", |
| 732 | + " if input_datetime is None:\n", |
| 733 | + " return None\n", |
| 734 | + " \n", |
| 735 | + " if isinstance(input_datetime, datetime):\n", |
| 736 | + " # If input is already a datetime object, return it as is\n", |
| 737 | + " return input_datetime\n", |
| 738 | + "\n", |
| 739 | + " if isinstance(input_datetime, str):\n", |
| 740 | + " try:\n", |
| 741 | + " # Attempt to parse the input string into a datetime\n", |
| 742 | + " return datetime.fromisoformat(input_datetime)\n", |
| 743 | + " except ValueError:\n", |
| 744 | + " raise ValueError(\"Invalid datetime string format\")\n", |
| 745 | + "\n", |
| 746 | + " raise ValueError(\"Input must be a datetime object or string\")\n", |
| 747 | + "\n", |
716 | 748 | "\n", |
717 | 749 | " def search_query(\n", |
718 | 750 | " self, \n", |
|
739 | 771 | " distance = \"-1.0\"\n", |
740 | 772 | " order_by_clause = \"\"\n", |
741 | 773 | "\n", |
| 774 | + " if self.infer_filters:\n", |
| 775 | + " if uuid_time_filter is None and isinstance(filter, dict):\n", |
| 776 | + " if \"__start_date\" in filter or \"__end_date\" in filter:\n", |
| 777 | + " start_date = self._parse_datetime(filter.get(\"__start_date\"))\n", |
| 778 | + " end_date = self._parse_datetime(filter.get(\"__end_date\"))\n", |
| 779 | + " \n", |
| 780 | + " uuid_time_filter = UUIDTimeRange(start_date, end_date)\n", |
| 781 | + " \n", |
| 782 | + " if start_date is not None:\n", |
| 783 | + " del filter[\"__start_date\"]\n", |
| 784 | + " if end_date is not None:\n", |
| 785 | + " del filter[\"__end_date\"]\n", |
| 786 | + "\n", |
| 787 | + "\n", |
742 | 788 | " where_clauses = []\n", |
743 | 789 | " if filter is not None:\n", |
744 | 790 | " (where_filter, params) = self._where_clause_for_filter(params, filter)\n", |
|
836 | 882 | " distance_type: str = 'cosine',\n", |
837 | 883 | " id_type='UUID',\n", |
838 | 884 | " time_partition_interval: Optional[timedelta] = None,\n", |
839 | | - " max_db_connections: Optional[int] = None\n", |
| 885 | + " max_db_connections: Optional[int] = None,\n", |
| 886 | + " infer_filters: bool = True,\n", |
840 | 887 | " ) -> None:\n", |
841 | 888 | " \"\"\"\n", |
842 | 889 | " Initializes a async client for storing vector data.\n", |
|
855 | 902 | " The type of the id column. Can be either 'UUID' or 'TEXT'.\n", |
856 | 903 | " \"\"\"\n", |
857 | 904 | " self.builder = QueryBuilder(\n", |
858 | | - " table_name, num_dimensions, distance_type, id_type, time_partition_interval)\n", |
| 905 | + " table_name, num_dimensions, distance_type, id_type, time_partition_interval, infer_filters)\n", |
859 | 906 | " self.service_url = service_url\n", |
860 | 907 | " self.pool = None\n", |
861 | 908 | " self.max_db_connections = max_db_connections\n", |
|
1444 | 1491 | "assert not await vec.table_is_empty()\n", |
1445 | 1492 | "rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7), specific_datetime+timedelta(days=7)))\n", |
1446 | 1493 | "assert len(rec) == 1\n", |
| 1494 | + "rec = await vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": specific_datetime-timedelta(days=7), \"__end_date\": specific_datetime+timedelta(days=7)})\n", |
| 1495 | + "assert len(rec) == 1\n", |
| 1496 | + "rec = await vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": str(specific_datetime-timedelta(days=7)), \"__end_date\": str(specific_datetime+timedelta(days=7))})\n", |
| 1497 | + "assert len(rec) == 1\n", |
| 1498 | + "rec = await vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": str(specific_datetime-timedelta(days=7))})\n", |
| 1499 | + "assert len(rec) == 2\n", |
| 1500 | + "rec = await vec.search([1.0, 2.0], limit=4, filter={\"__end_date\": str(specific_datetime+timedelta(days=7))})\n", |
| 1501 | + "assert len(rec) == 1\n", |
1447 | 1502 | "rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7), specific_datetime-timedelta(days=2)))\n", |
1448 | 1503 | "assert len(rec) == 0\n", |
| 1504 | + "rec = await vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": specific_datetime-timedelta(days=7), \"__end_date\": specific_datetime-timedelta(days=2)})\n", |
| 1505 | + "assert len(rec) == 0\n", |
| 1506 | + "rec = await vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": str(specific_datetime-timedelta(days=7)), \"__end_date\": str(specific_datetime-timedelta(days=2))})\n", |
| 1507 | + "assert len(rec) == 0\n", |
1449 | 1508 | "rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7)))\n", |
1450 | 1509 | "assert len(rec) == 2\n", |
1451 | 1510 | "rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(start_date=specific_datetime, time_delta=timedelta(days=7)))\n", |
|
1500 | 1559 | " distance_type: str = 'cosine',\n", |
1501 | 1560 | " id_type='UUID',\n", |
1502 | 1561 | " time_partition_interval: Optional[timedelta] = None,\n", |
1503 | | - " max_db_connections: Optional[int] = None\n", |
| 1562 | + " max_db_connections: Optional[int] = None,\n", |
| 1563 | + " infer_filters: bool = True,\n", |
1504 | 1564 | " ) -> None:\n", |
1505 | 1565 | " \"\"\"\n", |
1506 | 1566 | " Initializes a sync client for storing vector data.\n", |
|
1519 | 1579 | " The type of the primary id column. Can be either 'UUID' or 'TEXT'.\n", |
1520 | 1580 | " \"\"\"\n", |
1521 | 1581 | " self.builder = QueryBuilder(\n", |
1522 | | - " table_name, num_dimensions, distance_type, id_type, time_partition_interval)\n", |
| 1582 | + " table_name, num_dimensions, distance_type, id_type, time_partition_interval, infer_filters)\n", |
1523 | 1583 | " self.service_url = service_url\n", |
1524 | 1584 | " self.pool = None\n", |
1525 | 1585 | " self.max_db_connections = max_db_connections\n", |
|
2147 | 2207 | "assert not vec.table_is_empty()\n", |
2148 | 2208 | "rec = vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7), specific_datetime+timedelta(days=7)))\n", |
2149 | 2209 | "assert len(rec) == 1\n", |
| 2210 | + "rec = vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": specific_datetime-timedelta(days=7), \"__end_date\": specific_datetime+timedelta(days=7)})\n", |
| 2211 | + "assert len(rec) == 1\n", |
2150 | 2212 | "rec = vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7), specific_datetime-timedelta(days=2)))\n", |
2151 | 2213 | "assert len(rec) == 0\n", |
2152 | 2214 | "rec = vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7)))\n", |
|
0 commit comments