diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 51dc6f7e8..0bad99c0c 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,15 +6,18 @@ Changelog .. rst-class:: emphasize-children -0.23 +0.24 ==== -0.23.1 +0.24.0 (unreleased) ------ Fixed ^^^^^ - Rename pypika to pypika_tortoise for fixing package name conflict (#1829) - Concurrent connection pool initialization (#1825) +Changed +^^^^^^^ +- Optimize field conversion to database format to speed up `create` and `bulk_create` (#1840) 0.23.0 ------ diff --git a/tests/fields/test_time.py b/tests/fields/test_time.py index e963e8e01..94f11cf3a 100644 --- a/tests/fields/test_time.py +++ b/tests/fields/test_time.py @@ -26,6 +26,14 @@ async def test_empty(self): class TestDatetimeFields(TestEmpty): + async def asyncSetUp(self): + await super().asyncSetUp() + timezone._reset_timezone_cache() + + async def asyncTearDown(self): + await super().asyncTearDown() + timezone._reset_timezone_cache() + def test_both_auto_bad(self): with self.assertRaisesRegex( ConfigurationError, "You can choose only 'auto_now' or 'auto_now_add'" diff --git a/tortoise/__init__.py b/tortoise/__init__.py index 8f4896a77..a234ebe0b 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -28,6 +28,7 @@ from tortoise.filters import get_m2m_filters from tortoise.log import logger from tortoise.models import Model, ModelMeta +from tortoise.timezone import _reset_timezone_cache from tortoise.utils import generate_schema_for_client @@ -614,6 +615,7 @@ async def _drop_databases(cls) -> None: def _init_timezone(cls, use_tz: bool, timezone: str) -> None: os.environ["USE_TZ"] = str(use_tz) os.environ["TIMEZONE"] = timezone + _reset_timezone_cache() def run_async(coro: Coroutine) -> None: diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index a0648eb00..08ecd9464 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -2,7 +2,6 @@ import datetime import decimal from copy import copy -from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -24,7 +23,6 @@ from tortoise.exceptions import OperationalError from tortoise.expressions import Expression, ResolveContext -from tortoise.fields.base import Field from tortoise.fields.relational import ( BackwardFKRelation, BackwardOneToOneRelation, @@ -42,12 +40,11 @@ EXECUTOR_CACHE: Dict[ Tuple[str, Optional[str], str], - Tuple[list, str, list, str, Dict[str, Callable], str, Dict[str, str]], + Tuple[list, str, list, str, str, Dict[str, str]], ] = {} class BaseExecutor: - TO_DB_OVERRIDE: Dict[Type[Field], Callable] = {} FILTER_FUNC_OVERRIDE: Dict[Callable, Callable] = {} EXPLAIN_PREFIX: str = "EXPLAIN" DB_NATIVE = {bytes, str, int, float, decimal.Decimal, datetime.datetime, datetime.date} @@ -81,16 +78,6 @@ def __init__( self._prepare_insert_statement(columns_all, has_generated=False) ) - self.column_map: Dict[str, Callable[[Any, Any], Any]] = {} - for column in self.regular_columns_all: - field_object = self.model._meta.fields_map[column] - if field_object.__class__ in self.TO_DB_OVERRIDE: - self.column_map[column] = partial( - self.TO_DB_OVERRIDE[field_object.__class__], field_object - ) - else: - self.column_map[column] = field_object.to_db_value - table = self.model._meta.basetable basequery = cast(QueryBuilder, self.model._meta.basequery) self.delete_query = str( @@ -103,7 +90,6 @@ def __init__( self.insert_query, self.regular_columns_all, self.insert_query_all, - self.column_map, self.delete_query, self.update_cache, ) @@ -114,7 +100,6 @@ def __init__( self.insert_query, self.regular_columns_all, self.insert_query_all, - self.column_map, self.delete_query, self.update_cache, ) = EXECUTOR_CACHE[key] @@ -194,7 +179,9 @@ def parameter(self, pos: int) -> Parameter: async def execute_insert(self, instance: "Model") -> None: if not instance._custom_generated_pk: values = [ - self.column_map[field_name](getattr(instance, field_name), instance) + self.model._meta.fields_map[field_name].to_db_value( + getattr(instance, field_name), instance + ) for field_name in self.regular_columns ] insert_result = await self.db.execute_insert(self.insert_query, values) @@ -202,7 +189,9 @@ async def execute_insert(self, instance: "Model") -> None: else: values = [ - self.column_map[field_name](getattr(instance, field_name), instance) + self.model._meta.fields_map[field_name].to_db_value( + getattr(instance, field_name), instance + ) for field_name in self.regular_columns_all ] await self.db.execute_insert(self.insert_query_all, values) @@ -219,14 +208,18 @@ async def execute_bulk_insert( if instance._custom_generated_pk: values_lists_all.append( [ - self.column_map[field_name](getattr(instance, field_name), instance) + self.model._meta.fields_map[field_name].to_db_value( + getattr(instance, field_name), instance + ) for field_name in self.regular_columns_all ] ) else: values_lists.append( [ - self.column_map[field_name](getattr(instance, field_name), instance) + self.model._meta.fields_map[field_name].to_db_value( + getattr(instance, field_name), instance + ) for field_name in self.regular_columns ] ) @@ -292,7 +285,7 @@ async def execute_update( if isinstance(instance_field, Expression): expressions[field] = instance_field else: - value = self.column_map[field](instance_field, instance) + value = self.model._meta.fields_map[field].to_db_value(instance_field, instance) values.append(value) values.append(self.model._meta.pk.to_db_value(instance.pk, instance)) return ( diff --git a/tortoise/backends/mssql/executor.py b/tortoise/backends/mssql/executor.py index 9d41171a6..e2b9fbcb3 100644 --- a/tortoise/backends/mssql/executor.py +++ b/tortoise/backends/mssql/executor.py @@ -1,24 +1,9 @@ -from typing import Any, Optional, Type, Union +from typing import Any -from tortoise import Model, fields from tortoise.backends.odbc.executor import ODBCExecutor from tortoise.exceptions import UnSupportedError -from tortoise.fields import BooleanField - - -def to_db_bool( - self: BooleanField, value: Optional[Union[bool, int]], instance: Union[Type[Model], Model] -) -> Optional[int]: - self.validate(value) - if value is None: - return None - return int(bool(value)) class MSSQLExecutor(ODBCExecutor): - TO_DB_OVERRIDE = { - fields.BooleanField: to_db_bool, - } - async def execute_explain(self, sql: str) -> Any: raise UnSupportedError("MSSQL does not support explain") diff --git a/tortoise/backends/sqlite/executor.py b/tortoise/backends/sqlite/executor.py index a46d05960..1efd6fef9 100644 --- a/tortoise/backends/sqlite/executor.py +++ b/tortoise/backends/sqlite/executor.py @@ -1,99 +1,24 @@ import datetime import sqlite3 from decimal import Decimal -from typing import Optional, Type, Union -import pytz - -from tortoise import Model, fields, timezone +from tortoise import Model from tortoise.backends.base.executor import BaseExecutor from tortoise.contrib.sqlite.regex import ( insensitive_posix_sqlite_regexp, posix_sqlite_regexp, ) -from tortoise.fields import ( - BigIntField, - BooleanField, - DatetimeField, - DecimalField, - IntField, - SmallIntField, - TimeField, -) +from tortoise.fields import BigIntField, IntField, SmallIntField from tortoise.filters import insensitive_posix_regex, posix_regex - -def to_db_bool( - self: BooleanField, value: Optional[Union[bool, int]], instance: Union[Type[Model], Model] -) -> Optional[int]: - self.validate(value) - if value is None: - return None - return int(bool(value)) - - -def to_db_decimal( - self: DecimalField, - value: Optional[Union[str, float, int, Decimal]], - instance: Union[Type[Model], Model], -) -> Optional[str]: - self.validate(value) - if value is None: - return None - return str(Decimal(value).quantize(self.quant).normalize()) - - -def to_db_datetime( - self: DatetimeField, value: Optional[datetime.datetime], instance: Union[Type[Model], Model] -) -> Optional[str]: - self.validate(value) - # Only do this if it is a Model instance, not class. Test for guaranteed instance var - if hasattr(instance, "_saved_in_db") and ( - self.auto_now - or (self.auto_now_add and getattr(instance, self.model_field_name, None) is None) - ): - if timezone.get_use_tz(): - value = datetime.datetime.now(tz=pytz.utc) - else: - value = datetime.datetime.now(tz=timezone.get_default_timezone()) - setattr(instance, self.model_field_name, value) - return value.isoformat(" ") - if isinstance(value, datetime.datetime): - return value.isoformat(" ") - return None - - -def to_db_time( - self: TimeField, value: Optional[datetime.time], instance: Union[Type[Model], Model] -) -> Optional[str]: - self.validate(value) - if hasattr(instance, "_saved_in_db") and ( - self.auto_now - or (self.auto_now_add and getattr(instance, self.model_field_name, None) is None) - ): - if timezone.get_use_tz(): - value = datetime.datetime.now(tz=pytz.utc).time() - else: - value = datetime.datetime.now(tz=timezone.get_default_timezone()).time() - setattr(instance, self.model_field_name, value) - return value.isoformat() - if isinstance(value, datetime.time): - return value.isoformat() - return None - - -# Converts Decimal to string for sqlite in cases where it's hard to know the +# Conversion for the cases where it's hard to know the # related field, e.g. in raw queries, math or annotations. sqlite3.register_adapter(Decimal, str) +sqlite3.register_adapter(datetime.date, lambda val: val.isoformat()) +sqlite3.register_adapter(datetime.datetime, lambda val: val.isoformat(" ")) class SqliteExecutor(BaseExecutor): - TO_DB_OVERRIDE = { - fields.BooleanField: to_db_bool, - fields.DecimalField: to_db_decimal, - fields.DatetimeField: to_db_datetime, - fields.TimeField: to_db_time, - } EXPLAIN_PREFIX = "EXPLAIN QUERY PLAN" DB_NATIVE = {bytes, str, int, float} FILTER_FUNC_OVERRIDE = { diff --git a/tortoise/fields/base.py b/tortoise/fields/base.py index 0b7c13242..355bb0abb 100644 --- a/tortoise/fields/base.py +++ b/tortoise/fields/base.py @@ -259,11 +259,6 @@ def to_db_value(self, value: Any, instance: "Union[Type[Model], Model]") -> Any: if value is not None and not isinstance(value, self.field_type): value = self.field_type(value) # pylint: disable=E1102 - if self.__class__ in self.model._meta.db.executor_class.TO_DB_OVERRIDE: - value = self.model._meta.db.executor_class.TO_DB_OVERRIDE[self.__class__]( - self, value, instance - ) - self.validate(value) return value diff --git a/tortoise/queryset.py b/tortoise/queryset.py index dfeee7343..076d2f6ac 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1195,8 +1195,6 @@ def _make_query(self) -> None: self.resolve_ordering(self.model, table, self._orderings, self._annotations) self.resolve_filters() - # Need to get executor to get correct column_map - executor = self._db.executor_class(model=self.model, db=self._db) for key, value in self.update_kwargs.items(): field_object = self.model._meta.fields_map.get(key) if not field_object: @@ -1207,7 +1205,7 @@ def _make_query(self) -> None: self.model._validate_relation_type(key, value) fk_field: str = field_object.source_field # type: ignore db_field = self.model._meta.fields_map[fk_field].source_field - value = executor.column_map[fk_field]( + value = self.model._meta.fields_map[fk_field].to_db_value( getattr(value, field_object.to_field_instance.model_field_name), None, ) @@ -1227,7 +1225,7 @@ def _make_query(self) -> None: ) ).term else: - value = executor.column_map[key](value, None) + value = self.model._meta.fields_map[key].to_db_value(value, None) self.query = self.query.set(db_field, value) @@ -1838,7 +1836,6 @@ def _make_queries(self) -> List[Tuple[str, List[Any]]]: ) self.resolve_filters() - executor = self._db.executor_class(model=self.model, db=self._db) pk_attr = self.model._meta.pk_attr source_pk_attr = self.model._meta.fields_map[pk_attr].source_field or pk_attr pk = Field(source_pk_attr) @@ -1848,7 +1845,7 @@ def _make_queries(self) -> List[Tuple[str, List[Any]]]: case = Case() pk_list = [] for obj in objects_item: - pk_value = executor.column_map[pk_attr](obj.pk, None) + pk_value = self.model._meta.fields_map[pk_attr].to_db_value(obj.pk, None) field_obj = obj._meta.fields_map[field] field_value = field_obj.to_db_value(getattr(obj, field), obj) case.when( @@ -1945,6 +1942,7 @@ def _make_queries(self) -> Tuple[str, str]: return self._executor.insert_query, self._executor.insert_query_all async def _execute_many(self, insert_sql: str, insert_sql_all: str) -> None: + fields_map = self.model._meta.fields_map for instance_chunk in chunk(self._objects, self._batch_size): values_lists_all = [] values_lists = [] @@ -1952,7 +1950,7 @@ async def _execute_many(self, insert_sql: str, insert_sql_all: str) -> None: if instance._custom_generated_pk: values_lists_all.append( [ - self._executor.column_map[field_name]( + fields_map[field_name].to_db_value( getattr(instance, field_name), instance ) for field_name in self._executor.regular_columns_all @@ -1961,7 +1959,7 @@ async def _execute_many(self, insert_sql: str, insert_sql_all: str) -> None: else: values_lists.append( [ - self._executor.column_map[field_name]( + fields_map[field_name].to_db_value( getattr(instance, field_name), instance ) for field_name in self._executor.regular_columns diff --git a/tortoise/timezone.py b/tortoise/timezone.py index 8cccefbeb..3081ad21b 100644 --- a/tortoise/timezone.py +++ b/tortoise/timezone.py @@ -1,10 +1,12 @@ import os from datetime import datetime, time, tzinfo +from functools import lru_cache from typing import Optional, Union import pytz +@lru_cache(maxsize=None) def get_use_tz() -> bool: """ Get use_tz from env set in Tortoise config. @@ -12,6 +14,7 @@ def get_use_tz() -> bool: return os.environ.get("USE_TZ") == "True" +@lru_cache(maxsize=None) def get_timezone() -> str: """ Get timezone from env set in Tortoise config. @@ -29,6 +32,7 @@ def now() -> datetime: return datetime.now(get_default_timezone()) +@lru_cache(maxsize=None) def get_default_timezone() -> tzinfo: """ Return the default time zone as a tzinfo instance. @@ -38,6 +42,13 @@ def get_default_timezone() -> tzinfo: return pytz.timezone(get_timezone()) +def _reset_timezone_cache() -> None: + """Reset timezone cache. For internal use only.""" + get_default_timezone.cache_clear() + get_use_tz.cache_clear() + get_timezone.cache_clear() + + def localtime(value: Optional[datetime] = None, timezone: Optional[str] = None) -> datetime: """ Convert an aware datetime.datetime to local time.