Skip to content

Commit 3ddcad2

Browse files
committed
(trying to) reduce sql connections
1 parent 67c5971 commit 3ddcad2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2250
-2169
lines changed

backend/src/core/auth/auth_endpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ async def oidc_callback(
192192
)
193193

194194
try:
195-
user = await oauth_service.authenticate_oidc(request)
195+
user = await oauth_service.authenticate_oidc(db=db, request=request)
196196
except Exception as e:
197197
raise HTTPException(status_code=401, detail=str(e))
198198

backend/src/core/auth/oauth_service.py

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
from authlib.integrations.starlette_client import OAuth, OAuthError
55
from fastapi import Request
66
from loguru import logger
7+
from sqlalchemy.orm import Session
78

89
from common.singleton_meta import SingletonMeta
910
from config import conf
1011
from core.user.user_crud import crud_user
1112
from core.user.user_dto import UserCreate
1213
from core.user.user_orm import UserORM
13-
from repos.db.sql_repo import SQLRepo
1414
from repos.mail_repo import MailRepo
1515

1616

@@ -41,7 +41,7 @@ def __new__(cls, *args, **kwargs):
4141

4242
return super(OAuthService, cls).__new__(cls)
4343

44-
async def authenticate_oidc(self, request: Request) -> UserORM:
44+
async def authenticate_oidc(self, db: Session, request: Request) -> UserORM:
4545
try:
4646
token = await self.authentik.authorize_access_token(request)
4747
except OAuthError as error:
@@ -52,35 +52,34 @@ async def authenticate_oidc(self, request: Request) -> UserORM:
5252
userinfo = token.get("userinfo")
5353
print(f"Userinfo: {userinfo}")
5454

55-
with SQLRepo().db_session() as db:
56-
try:
57-
# Warning: Security concern
58-
user = crud_user.read_by_email(db=db, email=userinfo["email"])
59-
return user
60-
except Exception as e:
61-
logger.info(f"User not found, creating new user: {e}")
62-
# Create user if not exists
63-
user = crud_user.create(
64-
db=db,
65-
create_dto=UserCreate(
66-
email=userinfo["email"],
67-
first_name=userinfo.get("given_name", "Unknown"),
68-
last_name=userinfo.get("family_name", "Unknown"),
69-
# Set a random password since we'll only use OIDC
70-
password="".join(
71-
random.choices(
72-
string.ascii_letters + string.digits,
73-
k=32,
74-
)
75-
),
55+
try:
56+
# Warning: Security concern
57+
user = crud_user.read_by_email(db=db, email=userinfo["email"])
58+
return user
59+
except Exception as e:
60+
logger.info(f"User not found, creating new user: {e}")
61+
# Create user if not exists
62+
user = crud_user.create(
63+
db=db,
64+
create_dto=UserCreate(
65+
email=userinfo["email"],
66+
first_name=userinfo.get("given_name", "Unknown"),
67+
last_name=userinfo.get("family_name", "Unknown"),
68+
# Set a random password since we'll only use OIDC
69+
password="".join(
70+
random.choices(
71+
string.ascii_letters + string.digits,
72+
k=32,
73+
)
7674
),
77-
)
78-
await MailRepo().send_welcome_mail(
79-
email=user.email,
80-
first_name=user.first_name,
81-
last_name=user.last_name,
82-
)
83-
return user
75+
),
76+
)
77+
await MailRepo().send_welcome_mail(
78+
email=user.email,
79+
first_name=user.first_name,
80+
last_name=user.last_name,
81+
)
82+
return user
8483
except Exception as e:
8584
logger.error(f"Error processing OIDC authentication: {e}")
8685
raise Exception("Authentication failed")

backend/src/core/project/project_crud.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class CRUDProject(CRUDBase[ProjectORM, ProjectCreate, ProjectUpdate]):
2626
### CREATE OPERATIONS ###
2727

2828
def create(
29-
self, db: Session, *, create_dto: ProjectCreate, creating_user: UserORM
29+
self, db: Session, *, create_dto: ProjectCreate, creating_user_id: int
3030
) -> ProjectORM:
3131
# 1) create the project
3232
dto_obj_data = jsonable_encoder(create_dto)
@@ -44,13 +44,13 @@ def create(
4444
self.associate_user(db=db, proj_id=project_id, user_id=ASSISTANT_TRAINED_ID)
4545

4646
# 3) associate the user that created the project
47-
if creating_user.id not in [
47+
if creating_user_id not in [
4848
SYSTEM_USER_ID,
4949
ASSISTANT_ZEROSHOT_ID,
5050
ASSISTANT_FEWSHOT_ID,
5151
ASSISTANT_TRAINED_ID,
5252
]:
53-
self.associate_user(db=db, proj_id=project_id, user_id=creating_user.id)
53+
self.associate_user(db=db, proj_id=project_id, user_id=creating_user_id)
5454

5555
# 4) create system codes
5656
crud_code.create_system_codes_for_project(db=db, proj_id=project_id)

backend/src/core/project/project_endpoint.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ def create_new_project(
3434
proj: ProjectCreate,
3535
current_user: UserORM = Depends(get_current_user),
3636
) -> ProjectRead:
37-
db_obj = crud_project.create(db=db, create_dto=proj, creating_user=current_user)
37+
db_obj = crud_project.create(
38+
db=db, create_dto=proj, creating_user_id=current_user.id
39+
)
3840
return ProjectRead.model_validate(db_obj)
3941

4042

backend/src/modules/analysis/analysis_endpoint.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
)
3131
def code_frequencies(
3232
*,
33+
db: Session = Depends(get_db_session),
3334
project_id: int,
3435
code_ids: list[int],
3536
user_ids: list[int],
@@ -39,7 +40,11 @@ def code_frequencies(
3940
authz_user.assert_in_project(project_id)
4041

4142
return find_code_frequencies(
42-
project_id=project_id, code_ids=code_ids, user_ids=user_ids, doctypes=doctypes
43+
db=db,
44+
project_id=project_id,
45+
code_ids=code_ids,
46+
user_ids=user_ids,
47+
doctypes=doctypes,
4348
)
4449

4550

@@ -50,6 +55,7 @@ def code_frequencies(
5055
)
5156
def code_occurrences(
5257
*,
58+
db: Session = Depends(get_db_session),
5359
project_id: int,
5460
user_ids: list[int],
5561
code_id: int,
@@ -58,7 +64,7 @@ def code_occurrences(
5864
authz_user.assert_in_project(project_id)
5965

6066
return find_code_occurrences(
61-
project_id=project_id, user_ids=user_ids, code_id=code_id
67+
db=db, project_id=project_id, user_ids=user_ids, code_id=code_id
6268
)
6369

6470

@@ -69,13 +75,15 @@ def code_occurrences(
6975
)
7076
def count_sdocs_with_date_metadata(
7177
*,
78+
db: Session = Depends(get_db_session),
7279
project_id: int,
7380
date_metadata_id: int,
7481
authz_user: AuthzUser = Depends(),
7582
) -> tuple[int, int]:
7683
authz_user.assert_in_project(project_id)
7784

7885
return compute_num_sdocs_with_date_metadata(
86+
db=db,
7987
project_id=project_id,
8088
date_metadata_id=date_metadata_id,
8189
)
@@ -97,5 +105,5 @@ def sample_sdocs_by_tags(
97105
) -> list[SampledSdocsResults]:
98106
authz_user.assert_in_project(project_id)
99107
return document_sampler_by_tags(
100-
project_id=project_id, tag_ids=tag_groups, n=n, frac=frac
108+
db=db, project_id=project_id, tag_ids=tag_groups, n=n, frac=frac
101109
)

0 commit comments

Comments
 (0)