|
296 | 296 | "source": [
|
297 | 297 | "#| export\n",
|
298 | 298 | "class UUIDTimeRange:\n",
|
299 |
| - " def __init__(self, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, time_delta: Optional[timedelta] = None, start_inclusive=True, end_inclusive=False):\n", |
| 299 | + " \n", |
| 300 | + " @staticmethod\n", |
| 301 | + " def _parse_datetime(input_datetime: Union[datetime, str]):\n", |
| 302 | + " \"\"\"\n", |
| 303 | + " Parse a datetime object or string representation of a datetime.\n", |
| 304 | + "\n", |
| 305 | + " Args:\n", |
| 306 | + " input_datetime (datetime or str): Input datetime or string.\n", |
| 307 | + "\n", |
| 308 | + " Returns:\n", |
| 309 | + " datetime: Parsed datetime object.\n", |
| 310 | + "\n", |
| 311 | + " Raises:\n", |
| 312 | + " ValueError: If the input cannot be parsed as a datetime.\n", |
| 313 | + " \"\"\"\n", |
| 314 | + " if input_datetime is None or input_datetime == \"None\":\n", |
| 315 | + " return None\n", |
| 316 | + " \n", |
| 317 | + " if isinstance(input_datetime, datetime):\n", |
| 318 | + " # If input is already a datetime object, return it as is\n", |
| 319 | + " return input_datetime\n", |
| 320 | + "\n", |
| 321 | + " if isinstance(input_datetime, str):\n", |
| 322 | + " try:\n", |
| 323 | + " # Attempt to parse the input string into a datetime\n", |
| 324 | + " return datetime.fromisoformat(input_datetime)\n", |
| 325 | + " except ValueError:\n", |
| 326 | + " raise ValueError(\"Invalid datetime string format: {}\".format(input_datetime))\n", |
| 327 | + "\n", |
| 328 | + " raise ValueError(\"Input must be a datetime object or string\")\n", |
| 329 | + "\n", |
| 330 | + " def __init__(self, start_date: Optional[Union[datetime, str]] = None, end_date: Optional[Union[datetime, str]] = None, time_delta: Optional[timedelta] = None, start_inclusive=True, end_inclusive=False):\n", |
300 | 331 | " \"\"\"\n",
|
301 | 332 | " A UUIDTimeRange is a time range predicate on the UUID Version 1 timestamps. \n",
|
302 | 333 | " \n",
|
303 | 334 | " Note that naive datetime objects are interpreted as local time on the python client side and converted to UTC before being sent to the database.\n",
|
304 | 335 | " \"\"\"\n",
|
| 336 | + " start_date = UUIDTimeRange._parse_datetime(start_date)\n", |
| 337 | + " end_date = UUIDTimeRange._parse_datetime(end_date)\n", |
| 338 | + "\n", |
305 | 339 | " if start_date is not None and end_date is not None:\n",
|
306 | 340 | " if start_date > end_date:\n",
|
307 | 341 | " raise Exception(\"start_date must be before end_date\")\n",
|
|
726 | 760 | " raise ValueError(\"Unknown filter type: {filter_type}\".format(filter_type=type(filter)))\n",
|
727 | 761 | "\n",
|
728 | 762 | " return (where, params)\n",
|
729 |
| - " \n", |
730 |
| - " def _parse_datetime(self, input_datetime):\n", |
731 |
| - " \"\"\"\n", |
732 |
| - " Parse a datetime object or string representation of a datetime.\n", |
733 |
| - "\n", |
734 |
| - " Args:\n", |
735 |
| - " input_datetime (datetime or str): Input datetime or string.\n", |
736 |
| - "\n", |
737 |
| - " Returns:\n", |
738 |
| - " datetime: Parsed datetime object.\n", |
739 |
| - "\n", |
740 |
| - " Raises:\n", |
741 |
| - " ValueError: If the input cannot be parsed as a datetime.\n", |
742 |
| - " \"\"\"\n", |
743 |
| - " if input_datetime is None:\n", |
744 |
| - " return None\n", |
745 |
| - " \n", |
746 |
| - " if isinstance(input_datetime, datetime):\n", |
747 |
| - " # If input is already a datetime object, return it as is\n", |
748 |
| - " return input_datetime\n", |
749 |
| - "\n", |
750 |
| - " if isinstance(input_datetime, str):\n", |
751 |
| - " try:\n", |
752 |
| - " # Attempt to parse the input string into a datetime\n", |
753 |
| - " return datetime.fromisoformat(input_datetime)\n", |
754 |
| - " except ValueError:\n", |
755 |
| - " raise ValueError(\"Invalid datetime string format\")\n", |
756 |
| - "\n", |
757 |
| - " raise ValueError(\"Input must be a datetime object or string\")\n", |
758 |
| - "\n", |
759 | 763 | "\n",
|
760 | 764 | " def search_query(\n",
|
761 | 765 | " self, \n",
|
|
785 | 789 | " if self.infer_filters:\n",
|
786 | 790 | " if uuid_time_filter is None and isinstance(filter, dict):\n",
|
787 | 791 | " if \"__start_date\" in filter or \"__end_date\" in filter:\n",
|
788 |
| - " start_date = self._parse_datetime(filter.get(\"__start_date\"))\n", |
789 |
| - " end_date = self._parse_datetime(filter.get(\"__end_date\"))\n", |
| 792 | + " start_date = UUIDTimeRange._parse_datetime(filter.get(\"__start_date\"))\n", |
| 793 | + " end_date = UUIDTimeRange._parse_datetime(filter.get(\"__end_date\"))\n", |
790 | 794 | " \n",
|
791 | 795 | " uuid_time_filter = UUIDTimeRange(start_date, end_date)\n",
|
792 | 796 | " \n",
|
|
1506 | 1510 | " #using uuid_time_filter\n",
|
1507 | 1511 | " rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(start_date, end_date))\n",
|
1508 | 1512 | " assert len(rec) == expected\n",
|
| 1513 | + " rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(str(start_date), str(end_date)))\n", |
| 1514 | + " assert len(rec) == expected\n", |
1509 | 1515 | " \n",
|
1510 | 1516 | " #using filters\n",
|
1511 | 1517 | " filter = {}\n",
|
|
2248 | 2254 | " #using uuid_time_filter\n",
|
2249 | 2255 | " rec = vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(start_date, end_date))\n",
|
2250 | 2256 | " assert len(rec) == expected\n",
|
| 2257 | + " rec = vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(str(start_date), str(end_date)))\n", |
| 2258 | + " assert len(rec) == expected\n", |
2251 | 2259 | " \n",
|
2252 | 2260 | " #using filters\n",
|
2253 | 2261 | " filter = {}\n",
|
|
0 commit comments