Skip to content

Commit 46e3aef

Browse files
authored
Optimize field conversion to database format (#1840)
* Remove unnecessary code from to_db_value
* Remove TO_DB_OVERRIDE
* Add sqlite converters
* Cache timezone related settings
* Get rid of column_map
1 parent 8b5ac14 commit 46e3aef

File tree

9 files changed

+52
-132
lines changed

9 files changed

+52
-132
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,18 @@ Changelog
66

77
.. rst-class:: emphasize-children
88

9-
0.23
9+
0.24
1010
====
1111

12-
0.23.1
12+
0.24.0 (unreleased)
1313
------
1414
Fixed
1515
^^^^^
1616
- Rename pypika to pypika_tortoise for fixing package name conflict (#1829)
1717
- Concurrent connection pool initialization (#1825)
18+
Changed
19+
^^^^^^^
20+
- Optimize field conversion to database format to speed up `create` and `bulk_create` (#1840)
1821

1922
0.23.0
2023
------

tests/fields/test_time.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ async def test_empty(self):
2626

2727

2828
class TestDatetimeFields(TestEmpty):
29+
async def asyncSetUp(self):
30+
await super().asyncSetUp()
31+
timezone._reset_timezone_cache()
32+
33+
async def asyncTearDown(self):
34+
await super().asyncTearDown()
35+
timezone._reset_timezone_cache()
36+
2937
def test_both_auto_bad(self):
3038
with self.assertRaisesRegex(
3139
ConfigurationError, "You can choose only 'auto_now' or 'auto_now_add'"

tortoise/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from tortoise.filters import get_m2m_filters
2929
from tortoise.log import logger
3030
from tortoise.models import Model, ModelMeta
31+
from tortoise.timezone import _reset_timezone_cache
3132
from tortoise.utils import generate_schema_for_client
3233

3334

@@ -614,6 +615,7 @@ async def _drop_databases(cls) -> None:
614615
def _init_timezone(cls, use_tz: bool, timezone: str) -> None:
615616
os.environ["USE_TZ"] = str(use_tz)
616617
os.environ["TIMEZONE"] = timezone
618+
_reset_timezone_cache()
617619

618620

619621
def run_async(coro: Coroutine) -> None:

tortoise/backends/base/executor.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import datetime
33
import decimal
44
from copy import copy
5-
from functools import partial
65
from typing import (
76
TYPE_CHECKING,
87
Any,
@@ -24,7 +23,6 @@
2423

2524
from tortoise.exceptions import OperationalError
2625
from tortoise.expressions import Expression, ResolveContext
27-
from tortoise.fields.base import Field
2826
from tortoise.fields.relational import (
2927
BackwardFKRelation,
3028
BackwardOneToOneRelation,
@@ -42,12 +40,11 @@
4240

4341
EXECUTOR_CACHE: Dict[
4442
Tuple[str, Optional[str], str],
45-
Tuple[list, str, list, str, Dict[str, Callable], str, Dict[str, str]],
43+
Tuple[list, str, list, str, str, Dict[str, str]],
4644
] = {}
4745

4846

4947
class BaseExecutor:
50-
TO_DB_OVERRIDE: Dict[Type[Field], Callable] = {}
5148
FILTER_FUNC_OVERRIDE: Dict[Callable, Callable] = {}
5249
EXPLAIN_PREFIX: str = "EXPLAIN"
5350
DB_NATIVE = {bytes, str, int, float, decimal.Decimal, datetime.datetime, datetime.date}
@@ -81,16 +78,6 @@ def __init__(
8178
self._prepare_insert_statement(columns_all, has_generated=False)
8279
)
8380

84-
self.column_map: Dict[str, Callable[[Any, Any], Any]] = {}
85-
for column in self.regular_columns_all:
86-
field_object = self.model._meta.fields_map[column]
87-
if field_object.__class__ in self.TO_DB_OVERRIDE:
88-
self.column_map[column] = partial(
89-
self.TO_DB_OVERRIDE[field_object.__class__], field_object
90-
)
91-
else:
92-
self.column_map[column] = field_object.to_db_value
93-
9481
table = self.model._meta.basetable
9582
basequery = cast(QueryBuilder, self.model._meta.basequery)
9683
self.delete_query = str(
@@ -103,7 +90,6 @@ def __init__(
10390
self.insert_query,
10491
self.regular_columns_all,
10592
self.insert_query_all,
106-
self.column_map,
10793
self.delete_query,
10894
self.update_cache,
10995
)
@@ -114,7 +100,6 @@ def __init__(
114100
self.insert_query,
115101
self.regular_columns_all,
116102
self.insert_query_all,
117-
self.column_map,
118103
self.delete_query,
119104
self.update_cache,
120105
) = EXECUTOR_CACHE[key]
@@ -194,15 +179,19 @@ def parameter(self, pos: int) -> Parameter:
194179
async def execute_insert(self, instance: "Model") -> None:
195180
if not instance._custom_generated_pk:
196181
values = [
197-
self.column_map[field_name](getattr(instance, field_name), instance)
182+
self.model._meta.fields_map[field_name].to_db_value(
183+
getattr(instance, field_name), instance
184+
)
198185
for field_name in self.regular_columns
199186
]
200187
insert_result = await self.db.execute_insert(self.insert_query, values)
201188
await self._process_insert_result(instance, insert_result)
202189

203190
else:
204191
values = [
205-
self.column_map[field_name](getattr(instance, field_name), instance)
192+
self.model._meta.fields_map[field_name].to_db_value(
193+
getattr(instance, field_name), instance
194+
)
206195
for field_name in self.regular_columns_all
207196
]
208197
await self.db.execute_insert(self.insert_query_all, values)
@@ -219,14 +208,18 @@ async def execute_bulk_insert(
219208
if instance._custom_generated_pk:
220209
values_lists_all.append(
221210
[
222-
self.column_map[field_name](getattr(instance, field_name), instance)
211+
self.model._meta.fields_map[field_name].to_db_value(
212+
getattr(instance, field_name), instance
213+
)
223214
for field_name in self.regular_columns_all
224215
]
225216
)
226217
else:
227218
values_lists.append(
228219
[
229-
self.column_map[field_name](getattr(instance, field_name), instance)
220+
self.model._meta.fields_map[field_name].to_db_value(
221+
getattr(instance, field_name), instance
222+
)
230223
for field_name in self.regular_columns
231224
]
232225
)
@@ -292,7 +285,7 @@ async def execute_update(
292285
if isinstance(instance_field, Expression):
293286
expressions[field] = instance_field
294287
else:
295-
value = self.column_map[field](instance_field, instance)
288+
value = self.model._meta.fields_map[field].to_db_value(instance_field, instance)
296289
values.append(value)
297290
values.append(self.model._meta.pk.to_db_value(instance.pk, instance))
298291
return (
tortoise/backends/mssql/executor.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,9 @@
1-
from typing import Any, Optional, Type, Union
1+
from typing import Any
22

3-
from tortoise import Model, fields
43
from tortoise.backends.odbc.executor import ODBCExecutor
54
from tortoise.exceptions import UnSupportedError
6-
from tortoise.fields import BooleanField
7-
8-
9-
def to_db_bool(
10-
self: BooleanField, value: Optional[Union[bool, int]], instance: Union[Type[Model], Model]
11-
) -> Optional[int]:
12-
self.validate(value)
13-
if value is None:
14-
return None
15-
return int(bool(value))
165

176

187
class MSSQLExecutor(ODBCExecutor):
19-
TO_DB_OVERRIDE = {
20-
fields.BooleanField: to_db_bool,
21-
}
22-
238
async def execute_explain(self, sql: str) -> Any:
249
raise UnSupportedError("MSSQL does not support explain")

tortoise/backends/sqlite/executor.py

Lines changed: 5 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,99 +1,24 @@
11
import datetime
22
import sqlite3
33
from decimal import Decimal
4-
from typing import Optional, Type, Union
54

6-
import pytz
7-
8-
from tortoise import Model, fields, timezone
5+
from tortoise import Model
96
from tortoise.backends.base.executor import BaseExecutor
107
from tortoise.contrib.sqlite.regex import (
118
insensitive_posix_sqlite_regexp,
129
posix_sqlite_regexp,
1310
)
14-
from tortoise.fields import (
15-
BigIntField,
16-
BooleanField,
17-
DatetimeField,
18-
DecimalField,
19-
IntField,
20-
SmallIntField,
21-
TimeField,
22-
)
11+
from tortoise.fields import BigIntField, IntField, SmallIntField
2312
from tortoise.filters import insensitive_posix_regex, posix_regex
2413

25-
26-
def to_db_bool(
27-
self: BooleanField, value: Optional[Union[bool, int]], instance: Union[Type[Model], Model]
28-
) -> Optional[int]:
29-
self.validate(value)
30-
if value is None:
31-
return None
32-
return int(bool(value))
33-
34-
35-
def to_db_decimal(
36-
self: DecimalField,
37-
value: Optional[Union[str, float, int, Decimal]],
38-
instance: Union[Type[Model], Model],
39-
) -> Optional[str]:
40-
self.validate(value)
41-
if value is None:
42-
return None
43-
return str(Decimal(value).quantize(self.quant).normalize())
44-
45-
46-
def to_db_datetime(
47-
self: DatetimeField, value: Optional[datetime.datetime], instance: Union[Type[Model], Model]
48-
) -> Optional[str]:
49-
self.validate(value)
50-
# Only do this if it is a Model instance, not class. Test for guaranteed instance var
51-
if hasattr(instance, "_saved_in_db") and (
52-
self.auto_now
53-
or (self.auto_now_add and getattr(instance, self.model_field_name, None) is None)
54-
):
55-
if timezone.get_use_tz():
56-
value = datetime.datetime.now(tz=pytz.utc)
57-
else:
58-
value = datetime.datetime.now(tz=timezone.get_default_timezone())
59-
setattr(instance, self.model_field_name, value)
60-
return value.isoformat(" ")
61-
if isinstance(value, datetime.datetime):
62-
return value.isoformat(" ")
63-
return None
64-
65-
66-
def to_db_time(
67-
self: TimeField, value: Optional[datetime.time], instance: Union[Type[Model], Model]
68-
) -> Optional[str]:
69-
self.validate(value)
70-
if hasattr(instance, "_saved_in_db") and (
71-
self.auto_now
72-
or (self.auto_now_add and getattr(instance, self.model_field_name, None) is None)
73-
):
74-
if timezone.get_use_tz():
75-
value = datetime.datetime.now(tz=pytz.utc).time()
76-
else:
77-
value = datetime.datetime.now(tz=timezone.get_default_timezone()).time()
78-
setattr(instance, self.model_field_name, value)
79-
return value.isoformat()
80-
if isinstance(value, datetime.time):
81-
return value.isoformat()
82-
return None
83-
84-
85-
# Converts Decimal to string for sqlite in cases where it's hard to know the
14+
# Conversion for the cases where it's hard to know the
8615
# related field, e.g. in raw queries, math or annotations.
8716
sqlite3.register_adapter(Decimal, str)
17+
sqlite3.register_adapter(datetime.date, lambda val: val.isoformat())
18+
sqlite3.register_adapter(datetime.datetime, lambda val: val.isoformat(" "))
8819

8920

9021
class SqliteExecutor(BaseExecutor):
91-
TO_DB_OVERRIDE = {
92-
fields.BooleanField: to_db_bool,
93-
fields.DecimalField: to_db_decimal,
94-
fields.DatetimeField: to_db_datetime,
95-
fields.TimeField: to_db_time,
96-
}
9722
EXPLAIN_PREFIX = "EXPLAIN QUERY PLAN"
9823
DB_NATIVE = {bytes, str, int, float}
9924
FILTER_FUNC_OVERRIDE = {

tortoise/fields/base.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -259,11 +259,6 @@ def to_db_value(self, value: Any, instance: "Union[Type[Model], Model]") -> Any:
259259
if value is not None and not isinstance(value, self.field_type):
260260
value = self.field_type(value) # pylint: disable=E1102
261261

262-
if self.__class__ in self.model._meta.db.executor_class.TO_DB_OVERRIDE:
263-
value = self.model._meta.db.executor_class.TO_DB_OVERRIDE[self.__class__](
264-
self, value, instance
265-
)
266-
267262
self.validate(value)
268263
return value
269264

tortoise/queryset.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1195,8 +1195,6 @@ def _make_query(self) -> None:
11951195
self.resolve_ordering(self.model, table, self._orderings, self._annotations)
11961196

11971197
self.resolve_filters()
1198-
# Need to get executor to get correct column_map
1199-
executor = self._db.executor_class(model=self.model, db=self._db)
12001198
for key, value in self.update_kwargs.items():
12011199
field_object = self.model._meta.fields_map.get(key)
12021200
if not field_object:
@@ -1207,7 +1205,7 @@ def _make_query(self) -> None:
12071205
self.model._validate_relation_type(key, value)
12081206
fk_field: str = field_object.source_field # type: ignore
12091207
db_field = self.model._meta.fields_map[fk_field].source_field
1210-
value = executor.column_map[fk_field](
1208+
value = self.model._meta.fields_map[fk_field].to_db_value(
12111209
getattr(value, field_object.to_field_instance.model_field_name),
12121210
None,
12131211
)
@@ -1227,7 +1225,7 @@ def _make_query(self) -> None:
12271225
)
12281226
).term
12291227
else:
1230-
value = executor.column_map[key](value, None)
1228+
value = self.model._meta.fields_map[key].to_db_value(value, None)
12311229

12321230
self.query = self.query.set(db_field, value)
12331231

@@ -1838,7 +1836,6 @@ def _make_queries(self) -> List[Tuple[str, List[Any]]]:
18381836
)
18391837

18401838
self.resolve_filters()
1841-
executor = self._db.executor_class(model=self.model, db=self._db)
18421839
pk_attr = self.model._meta.pk_attr
18431840
source_pk_attr = self.model._meta.fields_map[pk_attr].source_field or pk_attr
18441841
pk = Field(source_pk_attr)
@@ -1848,7 +1845,7 @@ def _make_queries(self) -> List[Tuple[str, List[Any]]]:
18481845
case = Case()
18491846
pk_list = []
18501847
for obj in objects_item:
1851-
pk_value = executor.column_map[pk_attr](obj.pk, None)
1848+
pk_value = self.model._meta.fields_map[pk_attr].to_db_value(obj.pk, None)
18521849
field_obj = obj._meta.fields_map[field]
18531850
field_value = field_obj.to_db_value(getattr(obj, field), obj)
18541851
case.when(
@@ -1945,14 +1942,15 @@ def _make_queries(self) -> Tuple[str, str]:
19451942
return self._executor.insert_query, self._executor.insert_query_all
19461943

19471944
async def _execute_many(self, insert_sql: str, insert_sql_all: str) -> None:
1945+
fields_map = self.model._meta.fields_map
19481946
for instance_chunk in chunk(self._objects, self._batch_size):
19491947
values_lists_all = []
19501948
values_lists = []
19511949
for instance in instance_chunk:
19521950
if instance._custom_generated_pk:
19531951
values_lists_all.append(
19541952
[
1955-
self._executor.column_map[field_name](
1953+
fields_map[field_name].to_db_value(
19561954
getattr(instance, field_name), instance
19571955
)
19581956
for field_name in self._executor.regular_columns_all
@@ -1961,7 +1959,7 @@ async def _execute_many(self, insert_sql: str, insert_sql_all: str) -> None:
19611959
else:
19621960
values_lists.append(
19631961
[
1964-
self._executor.column_map[field_name](
1962+
fields_map[field_name].to_db_value(
19651963
getattr(instance, field_name), instance
19661964
)
19671965
for field_name in self._executor.regular_columns

0 commit comments

Comments
 (0)