Skip to content

Commit 8774a83

Browse files
committed
Mixins: initial version of the materials
1 parent b9404c8 commit 8774a83

File tree

4 files changed

+381
-0
lines changed

4 files changed

+381
-0
lines changed

python-mixins/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# TODO
2+
3+
This folder contains the source code that accompanies the Real Python tutorial, [TODO](TODO).

python-mixins/orm/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from orm.core import ActiveRecord
2+
from orm.mixins import SQLMixin, TimestampMixin
3+
4+
__all__ = [
5+
"ActiveRecord",
6+
"SQLMixin",
7+
"TimestampMixin",
8+
]

python-mixins/orm/core.py

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
import atexit
2+
import inspect
3+
import os
4+
import re
5+
import sqlite3
6+
import textwrap
7+
import typing
8+
from dataclasses import MISSING, Field, dataclass, field, fields
9+
from functools import cached_property
10+
from typing import Any, ClassVar, Iterator, Type, TypeVar
11+
12+
DATABASE_FILE = os.getenv("DATABASE", ":memory:")
13+
PRIMARY_KEY_COLUMN = "pk"
14+
SQL_COLUMN_TYPES = {
15+
bool: "NUMERIC",
16+
bytes: "BLOB",
17+
float: "REAL",
18+
int: "INTEGER",
19+
str: "TEXT",
20+
}
21+
22+
_T = TypeVar("_T", bound="ActiveRecord")
23+
24+
25+
class DataClassMeta(type):
26+
def __new__(mcs, name, bases, attrs, **kwargs):
27+
cls = super().__new__(mcs, name, bases, attrs)
28+
return dataclass(**kwargs)(cls)
29+
30+
31+
class ActiveRecordMeta(DataClassMeta):
32+
connection = sqlite3.connect(DATABASE_FILE)
33+
connection.execute("PRAGMA foreign_keys = ON")
34+
connection.row_factory = sqlite3.Row
35+
connection.autocommit = True
36+
atexit.register(connection.close)
37+
38+
def __new__(mcs, name, bases, attrs):
39+
cls = super().__new__(mcs, name, bases, attrs)
40+
if len(bases) > 0:
41+
cls.__table__ = SQLTable(cls)
42+
cls.__table__.create()
43+
return cls
44+
45+
46+
class ActiveRecord(metaclass=ActiveRecordMeta):
47+
__table__: ClassVar["SQLTable"]
48+
pk: int | None = field(kw_only=True, default=None)
49+
50+
@classmethod
51+
def find_all(cls: type[_T]) -> Iterator[_T]:
52+
return recursive_fetch(cls, cls.__table__.select_all())
53+
54+
@classmethod
55+
def find_by(cls: type[_T], **parameters) -> Iterator[_T]:
56+
if not parameters:
57+
raise ValueError("missing query conditions")
58+
return recursive_fetch(
59+
cls, cls.__table__.select_where(**parameters)
60+
)
61+
62+
@classmethod
63+
def find(cls: type[_T], *, pk: int) -> _T:
64+
try:
65+
return next(cls.find_by(pk=pk))
66+
except StopIteration as ex:
67+
raise ValueError(
68+
f"{cls.__name__} with pk={pk} not found"
69+
) from ex
70+
71+
def save(self) -> None:
72+
if self.pk is None:
73+
cursor = self.__table__.insert(self)
74+
self.pk = cursor.lastrowid
75+
else:
76+
self.__table__.update(self)
77+
78+
def delete(self) -> None:
79+
self.__table__.delete(self)
80+
self.pk = None
81+
82+
def __setattr__(self, name: str, value: Any) -> None:
83+
if name == PRIMARY_KEY_COLUMN:
84+
if frame := inspect.currentframe():
85+
if calling_frame := frame.f_back:
86+
local_self = calling_frame.f_locals.get("self", None)
87+
if local_self is not self:
88+
raise AttributeError(
89+
"primary key is managed by the database"
90+
)
91+
super().__setattr__(name, value)
92+
93+
94+
class SQLTable:
95+
def __init__(self, cls: type[_T]) -> None:
96+
self.cls = cls
97+
self.sql = SQLQueryGenerator(self)
98+
99+
@cached_property
100+
def name(self) -> str:
101+
return f"{snake_case(self.cls.__name__).rstrip("s")}s"
102+
103+
@cached_property
104+
def columns(self) -> list["SQLColumn"]:
105+
return [
106+
SQLColumn(class_field)
107+
for class_field in fields(self.cls)
108+
if class_field.name != PRIMARY_KEY_COLUMN
109+
]
110+
111+
@cached_property
112+
def foreign_keys(self) -> dict[str, type[_T]]:
113+
return {
114+
column.name: column.foreign_table.cls
115+
for column in self.columns
116+
if column.foreign_table
117+
}
118+
119+
def create(self) -> sqlite3.Cursor:
120+
return self.cls.connection.execute(self.sql.create.statement)
121+
122+
def insert(self, record: _T) -> sqlite3.Cursor:
123+
if record.pk is not None:
124+
raise ValueError("record has a primary key")
125+
query = self.sql.insert(record)
126+
return self.cls.connection.execute(query.statement, query.values)
127+
128+
def update(self, record: _T) -> sqlite3.Cursor:
129+
if record.pk is None:
130+
raise ValueError("record has no primary key")
131+
query = self.sql.update(record)
132+
return self.cls.connection.execute(query.statement, query.values)
133+
134+
def delete(self, record: _T) -> sqlite3.Cursor:
135+
if record.pk is None:
136+
raise ValueError("record hasn't been saved to database")
137+
query = self.sql.delete(record)
138+
return self.cls.connection.execute(query.statement, query.values)
139+
140+
def select_all(self) -> sqlite3.Cursor:
141+
return self.cls.connection.execute(self.sql.select_all.statement)
142+
143+
def select_where(self, **parameters) -> sqlite3.Cursor:
144+
query = self.sql.select_where(**parameters)
145+
return self.cls.connection.execute(query.statement, query.values)
146+
147+
148+
class SQLColumn:
149+
name: str
150+
type: str
151+
default: Any
152+
foreign_table: SQLTable | None
153+
154+
def __init__(self, class_field: Field) -> None:
155+
field_type = primary_type(class_field.type)
156+
if issubclass(field_type, ActiveRecord):
157+
self.name = f"{class_field.name}_{PRIMARY_KEY_COLUMN}"
158+
self.type = SQL_COLUMN_TYPES.get(int, "TEXT")
159+
self.foreign_table = SQLTable(field_type)
160+
else:
161+
self.name = class_field.name
162+
self.type = SQL_COLUMN_TYPES.get(field_type, "TEXT")
163+
self.foreign_table = None
164+
if class_field.default_factory is not MISSING:
165+
self.default = class_field.default_factory()
166+
elif class_field.default is not MISSING:
167+
self.default = class_field.default
168+
else:
169+
self.default = MISSING
170+
171+
@cached_property
172+
def definition(self) -> str:
173+
sql = f"{self.name} {self.type}"
174+
if self.foreign_table:
175+
sql += (
176+
" REFERENCES "
177+
f"{self.foreign_table.name}({PRIMARY_KEY_COLUMN})"
178+
)
179+
if self.default is MISSING:
180+
sql += " NOT NULL"
181+
elif self.default is not None:
182+
sql += f" DEFAULT {self.default!r}"
183+
return sql
184+
185+
186+
class SQLQuery:
187+
def __init__(
188+
self, statement: str, parameters: dict[str, Any] | None = None
189+
) -> None:
190+
self.statement = statement
191+
self.parameters = parameters or {}
192+
self.values = rename(self.parameters)
193+
194+
def __str__(self) -> str:
195+
result = self.statement
196+
for key, value in self.values.items():
197+
if value is None:
198+
result = result.replace(f"=:{key}", " IS NULL")
199+
else:
200+
result = result.replace(f":{key}", repr(value))
201+
return result + ";"
202+
203+
204+
class SQLQueryGenerator:
205+
def __init__(self, table: "SQLTable") -> None:
206+
self.table = table
207+
208+
@cached_property
209+
def create(self) -> SQLQuery:
210+
column_definitions = [
211+
f"{PRIMARY_KEY_COLUMN} INTEGER PRIMARY KEY AUTOINCREMENT",
212+
*(column.definition for column in self.table.columns),
213+
]
214+
return SQLQuery(
215+
textwrap.dedent(
216+
f"""\
217+
CREATE TABLE IF NOT EXISTS {self.table.name}(
218+
{",\n ".join(column_definitions)}
219+
)"""
220+
)
221+
)
222+
223+
def insert(self, record: _T) -> SQLQuery:
224+
column_names = ", ".join(
225+
column.name for column in self.table.columns
226+
)
227+
placeholders = ", ".join(
228+
f":{column.name}" for column in self.table.columns
229+
)
230+
return SQLQuery(
231+
(
232+
f"INSERT INTO {self.table.name}({column_names}) "
233+
f"VALUES ({placeholders})"
234+
),
235+
dict(vars(record)),
236+
)
237+
238+
def update(self, record: _T) -> SQLQuery:
239+
placeholders = ", ".join(
240+
f"{column.name}=:{column.name}"
241+
for column in self.table.columns
242+
)
243+
return SQLQuery(
244+
(
245+
f"UPDATE {self.table.name} "
246+
f"SET {placeholders} "
247+
f"WHERE {PRIMARY_KEY_COLUMN}=:{PRIMARY_KEY_COLUMN}"
248+
),
249+
dict(vars(record)),
250+
)
251+
252+
def delete(self, record: _T) -> SQLQuery:
253+
return SQLQuery(
254+
(
255+
f"DELETE FROM {self.table.name} "
256+
f"WHERE {PRIMARY_KEY_COLUMN}=:{PRIMARY_KEY_COLUMN}"
257+
),
258+
dict(vars(record)),
259+
)
260+
261+
@cached_property
262+
def select_all(self) -> SQLQuery:
263+
return SQLQuery(f"SELECT * FROM {self.table.name}")
264+
265+
def select_where(self, **parameters) -> SQLQuery:
266+
values = rename(parameters)
267+
conditions = " AND ".join(f"{param}=:{param}" for param in values)
268+
return SQLQuery(
269+
f"SELECT * FROM {self.table.name} WHERE {conditions}",
270+
values,
271+
)
272+
273+
274+
def recursive_fetch(cls: type[_T], cursor: sqlite3.Cursor) -> Iterator[_T]:
275+
for row in cursor.fetchall():
276+
attrs = {}
277+
for column_name in row.keys():
278+
if fk := cls.__table__.foreign_keys.get(column_name):
279+
name = column_name.removesuffix(f"_{PRIMARY_KEY_COLUMN}")
280+
attrs[name] = fk.find(pk=row[column_name])
281+
else:
282+
attrs[column_name] = row[column_name]
283+
yield cls(**attrs)
284+
285+
286+
def rename(parameters: dict[str, Any]) -> dict[str, Any]:
287+
values = {}
288+
for key, value in parameters.items():
289+
if isinstance(value, ActiveRecord):
290+
values[f"{key}_{PRIMARY_KEY_COLUMN}"] = value.pk
291+
else:
292+
values[key] = value
293+
return values
294+
295+
296+
def primary_type(type_hint: Type) -> type:
297+
if args := typing.get_args(type_hint):
298+
match [arg for arg in args if arg is not type(None)]:
299+
case [x] if x is not None:
300+
return x
301+
case []:
302+
raise TypeError("no primary types found")
303+
case _:
304+
raise TypeError("cannot have multiple primary types")
305+
306+
if type_hint is None:
307+
raise TypeError("type cannot be None")
308+
309+
return type_hint
310+
311+
312+
def snake_case(text: str) -> str:
313+
return re.sub(r"([a-z])([A-Z])", r"\1_\2", text).lower()

python-mixins/orm/mixins.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import logging
2+
import time
3+
from dataclasses import field
4+
5+
from orm.core import ActiveRecordMeta
6+
7+
CREATED_AT_COLUMN = "created_at"
8+
UPDATED_AT_COLUMN = "updated_at"
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class SQLMixin:
14+
@classmethod
15+
def find_all(cls):
16+
logger.debug(cls.__table__.sql.select_all)
17+
return super().find_all()
18+
19+
@classmethod
20+
def find_by(cls, **parameters):
21+
logger.debug(cls.__table__.sql.select_where(**parameters))
22+
return super().find_by(**parameters)
23+
24+
def save(self):
25+
if self.pk is None:
26+
logger.debug(self.__table__.sql.insert(self))
27+
else:
28+
logger.debug(self.__table__.sql.update(self))
29+
super().save()
30+
31+
def delete(self):
32+
logger.debug(self.__table__.sql.delete(self))
33+
return super().delete()
34+
35+
36+
class TimestampMixinMeta(ActiveRecordMeta):
37+
def __call__(cls, *args, **kwargs):
38+
created_at = kwargs.pop(CREATED_AT_COLUMN, None)
39+
updated_at = kwargs.pop(UPDATED_AT_COLUMN, None)
40+
instance = super().__call__(*args, **kwargs)
41+
instance.created_at = created_at
42+
instance.updated_at = updated_at
43+
return instance
44+
45+
46+
class TimestampMixin(metaclass=TimestampMixinMeta):
47+
created_at: int = field(init=False, repr=False)
48+
updated_at: int = field(init=False, repr=False)
49+
50+
def save(self) -> None:
51+
current_time = int(time.time())
52+
if self.pk is None:
53+
self.created_at = current_time
54+
self.updated_at = current_time
55+
else:
56+
self.updated_at = current_time
57+
super().save()

0 commit comments

Comments
 (0)