Skip to content

Commit 3f964a4

Browse files
authored
Merge pull request #966 from sanders41/transactions
Add db transactions
2 parents aecf2b6 + 55394b5 commit 3f964a4

File tree

1 file changed

+21
-35
lines changed

1 file changed

+21
-35
lines changed

src/fastapi/service_files.rs

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,15 @@ async def create_user(*, pool: Pool, cache_client: Valkey, user: UserCreate) ->
144144
"""
145145
146146
async with pool.acquire() as conn:
147-
result = await conn.fetchrow(
148-
query,
149-
user.email,
150-
user.full_name,
151-
get_password_hash(user.password),
152-
user.is_active,
153-
user.is_superuser,
154-
)
147+
async with conn.transaction():
148+
result = await conn.fetchrow(
149+
query,
150+
user.email,
151+
user.full_name,
152+
get_password_hash(user.password),
153+
user.is_active,
154+
user.is_superuser,
155+
)
155156
156157
# failsafe: this shouldn't happen
157158
if not result: # pragma: no cover
@@ -166,17 +167,15 @@ async def create_user(*, pool: Pool, cache_client: Valkey, user: UserCreate) ->
166167
async def delete_user(*, pool: Pool, cache_client: Valkey, user_id: str) -> None:
167168
query = "DELETE FROM users WHERE id::text = $1"
168169
async with pool.acquire() as conn:
169-
async with asyncio.TaskGroup() as tg:
170-
db_task = tg.create_task(conn.execute(query, user_id))
171-
tg.create_task(
172-
user_cache_services.delete_cached_user(cache_client=cache_client, user_id=user_id)
173-
)
174-
175-
result = await db_task
170+
async with conn.transaction():
171+
result = await conn.execute(query, user_id)
176172
177173
if result == "DELETE 0": # pragma: no cover
178174
raise UserNotFoundError(f"User with id {{user_id}} not found")
179175
176+
logger.debug("Deleting cached user")
177+
await user_cache_services.delete_cached_user(cache_client=cache_client, user_id=user_id)
178+
180179
181180
async def get_users(*, pool: Pool, offset: int = 0, limit: int = 100) -> list[UserInDb] | None:
182181
query = """
@@ -350,18 +349,10 @@ async def update_user(
350349
"""
351350
352351
async with pool.acquire() as conn:
353-
async with asyncio.TaskGroup() as tg:
354-
db_task = tg.create_task(
355-
conn.fetchrow(query, get_password_hash(user_in.new_password), db_user.id)
356-
)
357-
tg.create_task(
358-
user_cache_services.delete_cached_user(
359-
cache_client=cache_client, user_id=db_user.id
360-
)
352+
async with conn.transaction():
353+
result = await conn.fetchrow(
354+
query, get_password_hash(user_in.new_password), db_user.id
361355
)
362-
363-
result = await db_task
364-
365356
else:
366357
user_data = user_in.model_dump(exclude_unset=True)
367358
if "password" in user_data:
@@ -382,19 +373,14 @@ async def update_user(
382373
"""
383374
384375
async with pool.acquire() as conn:
385-
async with asyncio.TaskGroup() as tg:
386-
db_task = tg.create_task(conn.fetchrow(query, db_user.id, *user_data.values()))
387-
tg.create_task(
388-
user_cache_services.delete_cached_user(
389-
cache_client=cache_client, user_id=db_user.id
390-
)
391-
)
392-
393-
result = await db_task
376+
async with conn.transaction():
377+
result = await conn.fetchrow(query, db_user.id, *user_data.values())
394378
395379
if not result or result == "UPDATE 0": # pragma: no cover
396380
raise DbUpdateError("Error updating user")
397381
382+
await user_cache_services.delete_cached_user(cache_client=cache_client, user_id=db_user.id)
383+
398384
return UserInDb(**dict(result))
399385
"#
400386
)

0 commit comments

Comments
 (0)