diff --git a/src/fastapi/service_files.rs b/src/fastapi/service_files.rs index ea61667f..138bb2ea 100644 --- a/src/fastapi/service_files.rs +++ b/src/fastapi/service_files.rs @@ -144,14 +144,15 @@ async def create_user(*, pool: Pool, cache_client: Valkey, user: UserCreate) -> """ async with pool.acquire() as conn: - result = await conn.fetchrow( - query, - user.email, - user.full_name, - get_password_hash(user.password), - user.is_active, - user.is_superuser, - ) + async with conn.transaction(): + result = await conn.fetchrow( + query, + user.email, + user.full_name, + get_password_hash(user.password), + user.is_active, + user.is_superuser, + ) # failsafe: this shouldn't happen if not result: # pragma: no cover @@ -166,17 +167,15 @@ async def create_user(*, pool: Pool, cache_client: Valkey, user: UserCreate) -> async def delete_user(*, pool: Pool, cache_client: Valkey, user_id: str) -> None: query = "DELETE FROM users WHERE id::text = $1" async with pool.acquire() as conn: - async with asyncio.TaskGroup() as tg: - db_task = tg.create_task(conn.execute(query, user_id)) - tg.create_task( - user_cache_services.delete_cached_user(cache_client=cache_client, user_id=user_id) - ) - - result = await db_task + async with conn.transaction(): + result = await conn.execute(query, user_id) if result == "DELETE 0": # pragma: no cover raise UserNotFoundError(f"User with id {{user_id}} not found") + logger.debug("Deleting cached user") + await user_cache_services.delete_cached_user(cache_client=cache_client, user_id=user_id) + async def get_users(*, pool: Pool, offset: int = 0, limit: int = 100) -> list[UserInDb] | None: query = """ @@ -350,18 +349,10 @@ async def update_user( """ async with pool.acquire() as conn: - async with asyncio.TaskGroup() as tg: - db_task = tg.create_task( - conn.fetchrow(query, get_password_hash(user_in.new_password), db_user.id) - ) - tg.create_task( - user_cache_services.delete_cached_user( - cache_client=cache_client, user_id=db_user.id - ) + async with conn.transaction(): + result = await conn.fetchrow( + query, get_password_hash(user_in.new_password), db_user.id ) - - result = await db_task - else: user_data = user_in.model_dump(exclude_unset=True) if "password" in user_data: @@ -382,19 +373,14 @@ async def update_user( """ async with pool.acquire() as conn: - async with asyncio.TaskGroup() as tg: - db_task = tg.create_task(conn.fetchrow(query, db_user.id, *user_data.values())) - tg.create_task( - user_cache_services.delete_cached_user( - cache_client=cache_client, user_id=db_user.id - ) - ) - - result = await db_task + async with conn.transaction(): + result = await conn.fetchrow(query, db_user.id, *user_data.values()) if not result or result == "UPDATE 0": # pragma: no cover raise DbUpdateError("Error updating user") + await user_cache_services.delete_cached_user(cache_client=cache_client, user_id=db_user.id) + return UserInDb(**dict(result)) "# )