
Commit 593f878

Infer start and end date from filters

1 parent 694f862 · commit 593f878
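
What the change does: when `infer_filters` is enabled (the new default), `search()` builds a `UUIDTimeRange` automatically from `__start_date` / `__end_date` keys found in the metadata `filter` dict, accepting either `datetime` objects or ISO-8601 strings. A minimal usage sketch based on the test cells added in this commit; the table name, dimension count, service URL variable, and the `client.Sync` / `client.UUIDTimeRange` spellings are assumptions for illustration, not part of the diff:

    from datetime import datetime, timedelta
    from timescale_vector import client  # assumed import path

    # Assumed table/service settings for illustration only.
    vec = client.Sync(TIMESCALE_SERVICE_URL, "data_table", 2,
                      time_partition_interval=timedelta(hours=6))

    specific_datetime = datetime(2018, 8, 10, 15, 30, 0)  # illustrative timestamp

    # Existing API: pass an explicit UUIDTimeRange.
    rec = vec.search([1.0, 2.0], limit=4,
                     uuid_time_filter=client.UUIDTimeRange(
                         specific_datetime - timedelta(days=7),
                         specific_datetime + timedelta(days=7)))

    # New behavior: the same time range inferred from the filter dict;
    # values may be datetime objects or ISO-format strings.
    rec = vec.search([1.0, 2.0], limit=4,
                     filter={"__start_date": specific_datetime - timedelta(days=7),
                             "__end_date": str(specific_datetime + timedelta(days=7))})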

File tree

4 files changed: +145 -10 lines

- nbs/00_vector.ipynb
- nbs/tsv_python_getting_started_tutorial.ipynb
- timescale_vector/_modidx.py
- timescale_vector/client.py

nbs/00_vector.ipynb

Lines changed: 67 additions & 5 deletions
@@ -490,7 +490,8 @@
     "        num_dimensions: int,\n",
     "        distance_type: str,\n",
     "        id_type: str,\n",
-    "        time_partition_interval: Optional[timedelta]) -> None:\n",
+    "        time_partition_interval: Optional[timedelta],\n",
+    "        infer_filters: bool) -> None:\n",
     "        \"\"\"\n",
     "        Initializes a base Vector object to generate queries for vector clients.\n",
     "\n",
@@ -522,6 +523,7 @@
     "\n",
     "        self.id_type = id_type.lower()\n",
     "        self.time_partition_interval = time_partition_interval\n",
+    "        self.infer_filters = infer_filters\n",
     "\n",
     "    def _quote_ident(self, ident):\n",
     "        \"\"\"\n",
@@ -713,6 +715,36 @@
     "            raise ValueError(\"Unknown filter type: {filter_type}\".format(filter_type=type(filter)))\n",
     "\n",
     "        return (where, params)\n",
+    "    \n",
+    "    def _parse_datetime(self, input_datetime):\n",
+    "        \"\"\"\n",
+    "        Parse a datetime object or string representation of a datetime.\n",
+    "\n",
+    "        Args:\n",
+    "            input_datetime (datetime or str): Input datetime or string.\n",
+    "\n",
+    "        Returns:\n",
+    "            datetime: Parsed datetime object.\n",
+    "\n",
+    "        Raises:\n",
+    "            ValueError: If the input cannot be parsed as a datetime.\n",
+    "        \"\"\"\n",
+    "        if input_datetime is None:\n",
+    "            return None\n",
+    "        \n",
+    "        if isinstance(input_datetime, datetime):\n",
+    "            # If input is already a datetime object, return it as is\n",
+    "            return input_datetime\n",
+    "\n",
+    "        if isinstance(input_datetime, str):\n",
+    "            try:\n",
+    "                # Attempt to parse the input string into a datetime\n",
+    "                return datetime.fromisoformat(input_datetime)\n",
+    "            except ValueError:\n",
+    "                raise ValueError(\"Invalid datetime string format\")\n",
+    "\n",
+    "        raise ValueError(\"Input must be a datetime object or string\")\n",
+    "\n",
     "\n",
     "    def search_query(\n",
     "        self, \n",
@@ -739,6 +771,20 @@
     "            distance = \"-1.0\"\n",
     "            order_by_clause = \"\"\n",
     "\n",
+    "        if self.infer_filters:\n",
+    "            if uuid_time_filter is None and isinstance(filter, dict):\n",
+    "                if \"__start_date\" in filter or \"__end_date\" in filter:\n",
+    "                    start_date = self._parse_datetime(filter.get(\"__start_date\"))\n",
+    "                    end_date = self._parse_datetime(filter.get(\"__end_date\"))\n",
+    "                    \n",
+    "                    uuid_time_filter = UUIDTimeRange(start_date, end_date)\n",
+    "                    \n",
+    "                    if start_date is not None:\n",
+    "                        del filter[\"__start_date\"]\n",
+    "                    if end_date is not None:\n",
+    "                        del filter[\"__end_date\"]\n",
+    "\n",
+    "\n",
     "        where_clauses = []\n",
     "        if filter is not None:\n",
     "            (where_filter, params) = self._where_clause_for_filter(params, filter)\n",
@@ -836,7 +882,8 @@
     "        distance_type: str = 'cosine',\n",
     "        id_type='UUID',\n",
     "        time_partition_interval: Optional[timedelta] = None,\n",
-    "        max_db_connections: Optional[int] = None\n",
+    "        max_db_connections: Optional[int] = None,\n",
+    "        infer_filters: bool = True,\n",
     "        ) -> None:\n",
     "        \"\"\"\n",
     "        Initializes a async client for storing vector data.\n",
@@ -855,7 +902,7 @@
     "            The type of the id column. Can be either 'UUID' or 'TEXT'.\n",
     "        \"\"\"\n",
     "        self.builder = QueryBuilder(\n",
-    "            table_name, num_dimensions, distance_type, id_type, time_partition_interval)\n",
+    "            table_name, num_dimensions, distance_type, id_type, time_partition_interval, infer_filters)\n",
     "        self.service_url = service_url\n",
     "        self.pool = None\n",
     "        self.max_db_connections = max_db_connections\n",
@@ -1444,8 +1491,20 @@
 "assert not await vec.table_is_empty()\n",
 "rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7), specific_datetime+timedelta(days=7)))\n",
 "assert len(rec) == 1\n",
+"rec = await vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": specific_datetime-timedelta(days=7), \"__end_date\": specific_datetime+timedelta(days=7)})\n",
+"assert len(rec) == 1\n",
+"rec = await vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": str(specific_datetime-timedelta(days=7)), \"__end_date\": str(specific_datetime+timedelta(days=7))})\n",
+"assert len(rec) == 1\n",
+"rec = await vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": str(specific_datetime-timedelta(days=7))})\n",
+"assert len(rec) == 2\n",
+"rec = await vec.search([1.0, 2.0], limit=4, filter={\"__end_date\": str(specific_datetime+timedelta(days=7))})\n",
+"assert len(rec) == 1\n",
 "rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7), specific_datetime-timedelta(days=2)))\n",
 "assert len(rec) == 0\n",
+"rec = await vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": specific_datetime-timedelta(days=7), \"__end_date\": specific_datetime-timedelta(days=2)})\n",
+"assert len(rec) == 0\n",
+"rec = await vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": str(specific_datetime-timedelta(days=7)), \"__end_date\": str(specific_datetime-timedelta(days=2))})\n",
+"assert len(rec) == 0\n",
 "rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7)))\n",
 "assert len(rec) == 2\n",
 "rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(start_date=specific_datetime, time_delta=timedelta(days=7)))\n",
@@ -1500,7 +1559,8 @@
     "        distance_type: str = 'cosine',\n",
     "        id_type='UUID',\n",
     "        time_partition_interval: Optional[timedelta] = None,\n",
-    "        max_db_connections: Optional[int] = None\n",
+    "        max_db_connections: Optional[int] = None,\n",
+    "        infer_filters: bool = True,\n",
     "        ) -> None:\n",
     "        \"\"\"\n",
     "        Initializes a sync client for storing vector data.\n",
@@ -1519,7 +1579,7 @@
     "            The type of the primary id column. Can be either 'UUID' or 'TEXT'.\n",
     "        \"\"\"\n",
     "        self.builder = QueryBuilder(\n",
-    "            table_name, num_dimensions, distance_type, id_type, time_partition_interval)\n",
+    "            table_name, num_dimensions, distance_type, id_type, time_partition_interval, infer_filters)\n",
     "        self.service_url = service_url\n",
     "        self.pool = None\n",
     "        self.max_db_connections = max_db_connections\n",
@@ -2147,6 +2207,8 @@
 "assert not vec.table_is_empty()\n",
 "rec = vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7), specific_datetime+timedelta(days=7)))\n",
 "assert len(rec) == 1\n",
+"rec = vec.search([1.0, 2.0], limit=4, filter={\"__start_date\": specific_datetime-timedelta(days=7), \"__end_date\": specific_datetime+timedelta(days=7)})\n",
+"assert len(rec) == 1\n",
 "rec = vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7), specific_datetime-timedelta(days=2)))\n",
 "assert len(rec) == 0\n",
 "rec = vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(specific_datetime-timedelta(days=7)))\n",

nbs/tsv_python_getting_started_tutorial.ipynb

Lines changed: 23 additions & 0 deletions
@@ -159,6 +159,29 @@
     "Each partition will consist of data for the specified length of time. We'll use 7 days for simplicity, but you can pick whatever value make sense for your use case -- for example if you query recent vectors frequently you might want to use a smaller time delta like 1 day, or if you query vectors over a decade long time period then you might want to use a larger time delta like 6 months or 1 year."
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| hide\n",
+    "import asyncpg"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| hide\n",
+    "con = await asyncpg.connect(TIMESCALE_SERVICE_URL)\n",
+    "await con.execute(\"DROP TABLE IF EXISTS commit_history;\")\n",
+    "await con.execute(\"DROP EXTENSION IF EXISTS vector CASCADE\")\n",
+    "await con.close()"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,

timescale_vector/_modidx.py

Lines changed: 2 additions & 0 deletions
@@ -78,6 +78,8 @@
                                                                       'timescale_vector/client.py'),
               'timescale_vector.client.QueryBuilder._get_embedding_index_name': ( 'vector.html#querybuilder._get_embedding_index_name',
                                                                                   'timescale_vector/client.py'),
+              'timescale_vector.client.QueryBuilder._parse_datetime': ( 'vector.html#querybuilder._parse_datetime',
+                                                                        'timescale_vector/client.py'),
               'timescale_vector.client.QueryBuilder._quote_ident': ( 'vector.html#querybuilder._quote_ident',
                                                                      'timescale_vector/client.py'),
               'timescale_vector.client.QueryBuilder._where_clause_for_filter': ( 'vector.html#querybuilder._where_clause_for_filter',

timescale_vector/client.py

Lines changed: 53 additions & 5 deletions
@@ -379,7 +379,8 @@ def __init__(
         num_dimensions: int,
         distance_type: str,
         id_type: str,
-        time_partition_interval: Optional[timedelta]) -> None:
+        time_partition_interval: Optional[timedelta],
+        infer_filters: bool) -> None:
         """
         Initializes a base Vector object to generate queries for vector clients.

@@ -411,6 +412,7 @@ def __init__(

         self.id_type = id_type.lower()
         self.time_partition_interval = time_partition_interval
+        self.infer_filters = infer_filters

     def _quote_ident(self, ident):
         """
@@ -602,6 +604,36 @@ def _where_clause_for_filter(self, params: List, filter: Optional[Union[Dict[str
             raise ValueError("Unknown filter type: {filter_type}".format(filter_type=type(filter)))

         return (where, params)
+
+    def _parse_datetime(self, input_datetime):
+        """
+        Parse a datetime object or string representation of a datetime.
+
+        Args:
+            input_datetime (datetime or str): Input datetime or string.
+
+        Returns:
+            datetime: Parsed datetime object.
+
+        Raises:
+            ValueError: If the input cannot be parsed as a datetime.
+        """
+        if input_datetime is None:
+            return None
+
+        if isinstance(input_datetime, datetime):
+            # If input is already a datetime object, return it as is
+            return input_datetime
+
+        if isinstance(input_datetime, str):
+            try:
+                # Attempt to parse the input string into a datetime
+                return datetime.fromisoformat(input_datetime)
+            except ValueError:
+                raise ValueError("Invalid datetime string format")
+
+        raise ValueError("Input must be a datetime object or string")
+

     def search_query(
         self,
@@ -628,6 +660,20 @@ def search_query(
             distance = "-1.0"
             order_by_clause = ""

+        if self.infer_filters:
+            if uuid_time_filter is None and isinstance(filter, dict):
+                if "__start_date" in filter or "__end_date" in filter:
+                    start_date = self._parse_datetime(filter.get("__start_date"))
+                    end_date = self._parse_datetime(filter.get("__end_date"))
+
+                    uuid_time_filter = UUIDTimeRange(start_date, end_date)
+
+                    if start_date is not None:
+                        del filter["__start_date"]
+                    if end_date is not None:
+                        del filter["__end_date"]
+
+
         where_clauses = []
         if filter is not None:
             (where_filter, params) = self._where_clause_for_filter(params, filter)
@@ -671,7 +717,8 @@ def __init__(
         distance_type: str = 'cosine',
         id_type='UUID',
         time_partition_interval: Optional[timedelta] = None,
-        max_db_connections: Optional[int] = None
+        max_db_connections: Optional[int] = None,
+        infer_filters: bool = True,
         ) -> None:
         """
         Initializes a async client for storing vector data.
@@ -690,7 +737,7 @@ def __init__(
             The type of the id column. Can be either 'UUID' or 'TEXT'.
         """
         self.builder = QueryBuilder(
-            table_name, num_dimensions, distance_type, id_type, time_partition_interval)
+            table_name, num_dimensions, distance_type, id_type, time_partition_interval, infer_filters)
         self.service_url = service_url
         self.pool = None
         self.max_db_connections = max_db_connections
@@ -933,7 +980,8 @@ def __init__(
         distance_type: str = 'cosine',
         id_type='UUID',
         time_partition_interval: Optional[timedelta] = None,
-        max_db_connections: Optional[int] = None
+        max_db_connections: Optional[int] = None,
+        infer_filters: bool = True,
         ) -> None:
         """
         Initializes a sync client for storing vector data.
@@ -952,7 +1000,7 @@ def __init__(
             The type of the primary id column. Can be either 'UUID' or 'TEXT'.
         """
         self.builder = QueryBuilder(
-            table_name, num_dimensions, distance_type, id_type, time_partition_interval)
+            table_name, num_dimensions, distance_type, id_type, time_partition_interval, infer_filters)
         self.service_url = service_url
         self.pool = None
         self.max_db_connections = max_db_connections
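
Because the new `_parse_datetime` delegates string parsing to `datetime.fromisoformat`, only ISO-8601 style strings are accepted; anything else raises `ValueError`. A standalone, illustrative re-statement of the helper (not part of the commit) showing which inputs pass:

    from datetime import datetime

    def parse_datetime(value):
        """Illustrative mirror of QueryBuilder._parse_datetime."""
        if value is None:
            return None
        if isinstance(value, datetime):
            return value  # already a datetime; returned unchanged
        if isinstance(value, str):
            try:
                return datetime.fromisoformat(value)
            except ValueError:
                raise ValueError("Invalid datetime string format")
        raise ValueError("Input must be a datetime object or string")

    print(parse_datetime("2023-08-10 15:30:00"))        # ok: space separator
    print(parse_datetime("2023-08-10T15:30:00+00:00"))  # ok: offset-aware
    # parse_datetime("08/10/2023") raises ValueError: fromisoformat
    # does not understand non-ISO formats.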
