Skip to content

Commit 4485c6f

Browse files
committed
simplify to use a custom jwt data, also verify k8s object exists
1 parent faf5c04 commit 4485c6f

File tree

6 files changed

+56
-39
lines changed

6 files changed

+56
-39
lines changed

backend/btrixcloud/auth.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from uuid import UUID, uuid4
55
import asyncio
66
from datetime import timedelta
7-
from typing import Optional, Tuple, List, Callable
7+
from typing import Optional, Tuple, List
88
from passlib import pwd
99
from passlib.context import CryptContext
1010

@@ -42,6 +42,7 @@
4242
PWD_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")
4343

4444
# Audiences
45+
CUSTOM_AUTH_AUD = "btrix:custom-auth"
4546
AUTH_AUD = "btrix:auth"
4647
RESET_AUD = "btrix:reset"
4748
VERIFY_AUD = "btrix:verify"
@@ -121,14 +122,22 @@ def create_access_token(user: User) -> str:
121122

122123

123124
# ============================================================================
124-
def create_internal_crawler_access_token(sub: str, role: str) -> str:
125+
def create_custom_jwt_token(sub: str, data: dict[str, str]) -> str:
125126
"""create jwt token for internal crawler access"""
126127
return generate_jwt(
127-
{"sub": sub, "internal_role": role, "aud": AUTH_AUD},
128+
{**data, "sub": sub, "aud": CUSTOM_AUTH_AUD},
128129
INTERNAL_JWT_TOKEN_LIFETIME,
129130
)
130131

131132

133+
# ============================================================================
134+
def get_custom_jwt_token(request: Request) -> dict[str, str]:
135+
"""return data from custom jwt token"""
136+
token = request.query_params.get("auth_bearer") or ""
137+
payload = decode_jwt(token, [CUSTOM_AUTH_AUD])
138+
return payload
139+
140+
132141
# ============================================================================
133142
def verify_password(plain_password: str, hashed_password: str) -> bool:
134143
"""verify password by hash"""
@@ -157,7 +166,7 @@ def generate_password() -> str:
157166

158167
# ============================================================================
159168
# pylint: disable=raise-missing-from
160-
def init_jwt_auth(user_manager) -> tuple[Callable, Callable, Callable, Callable]:
169+
def init_jwt_auth(user_manager):
161170
"""init jwt auth router + current_active_user dependency"""
162171
oauth2_scheme = OA2BearerOrQuery(tokenUrl="/api/auth/jwt/login", auto_error=False)
163172

@@ -167,8 +176,6 @@ async def get_current_user(
167176
try:
168177
payload = decode_jwt(token, AUTH_ALLOW_AUD)
169178
uid: Optional[str] = payload.get("sub") or payload.get("user_id")
170-
# insure not an internal token
171-
assert not payload.get("internal_role")
172179
user = await user_manager.get_by_id(UUID(uid))
173180
assert user
174181
return user
@@ -194,17 +201,6 @@ async def shared_secret_or_superuser(
194201

195202
return user
196203

197-
def get_custom_access(role: str) -> Callable[[str], str]:
198-
def get_access_dep(token: str = Depends(oauth2_scheme)) -> str:
199-
payload = decode_jwt(token, AUTH_ALLOW_AUD)
200-
sub = payload.get("sub")
201-
if not sub or payload.get("internal_role") != role:
202-
raise HTTPException(status_code=401, detail="invalid_credentials")
203-
204-
return sub
205-
206-
return get_access_dep
207-
208204
current_active_user = get_current_user
209205

210206
auth_jwt_router = APIRouter()
@@ -284,9 +280,4 @@ async def refresh_jwt(user=Depends(current_active_user)):
284280
user_info = await user_manager.get_user_info_with_orgs(user)
285281
return get_bearer_response(user, user_info)
286282

287-
return (
288-
auth_jwt_router,
289-
current_active_user,
290-
shared_secret_or_superuser,
291-
get_custom_access,
292-
)
283+
return auth_jwt_router, current_active_user, shared_secret_or_superuser

backend/btrixcloud/colls.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363
case_insensitive_collation,
6464
)
6565

66+
from .auth import get_custom_jwt_token
67+
6668
from .crawlmanager import CrawlManager
6769

6870
if TYPE_CHECKING:
@@ -1274,7 +1276,6 @@ def init_collections_api(
12741276
crawl_manager: CrawlManager,
12751277
event_webhook_ops: EventWebhookOps,
12761278
user_dep,
1277-
coll_dep,
12781279
) -> CollectionOps:
12791280
"""init collections api"""
12801281
# pylint: disable=invalid-name, unused-argument, too-many-arguments
@@ -1288,10 +1289,20 @@ def init_collections_api(
12881289
org_viewer_dep = orgs.org_viewer_dep
12891290
org_public = orgs.org_public
12901291

1291-
def coll_access_dep(coll_id: UUID, coll_access_id=Depends(coll_dep)) -> UUID:
1292-
if coll_access_id == str(coll_id):
1293-
return coll_id
1292+
async def coll_access_dep(
1293+
coll_id: UUID, token_data: dict[str, str] = Depends(get_custom_jwt_token)
1294+
) -> UUID:
1295+
# first, check subject match collection id and type is collection
1296+
if token_data.get("sub_type") == "coll" and token_data.get("sub") == str(
1297+
coll_id
1298+
):
1299+
# second, check that the k8s object access is scoped to exists
1300+
if await crawl_manager.validate_k8s_obj_exists(
1301+
token_data.get("scope_type", ""), token_data.get("scope", "")
1302+
):
1303+
return coll_id
12941304

1305+
# otherwise, deny access
12951306
raise HTTPException(status_code=403, detail="access_denied")
12961307

12971308
@app.post(

backend/btrixcloud/crawlmanager.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from .utils import dt_now, date_to_str, scale_from_browser_windows
1212
from .k8sapi import K8sAPI, ApiException
13-
from .auth import create_internal_crawler_access_token
13+
from .auth import create_custom_jwt_token
1414

1515
from .models import (
1616
StorageRef,
@@ -245,7 +245,9 @@ async def run_index_import_job(
245245
)
246246

247247
if job_type in ("purge", "import"):
248-
auth_bearer = create_internal_crawler_access_token(coll_id, "coll")
248+
auth_bearer = create_custom_jwt_token(
249+
coll_id, {"sub_type": "coll", "scope_type": "job", "scope": name}
250+
)
249251
import_source_url = (
250252
f"{BACKEND_ORIGIN}/api/orgs/{oid}/collections/{coll_id}"
251253
+ f"/internal/replay.json?auth_bearer={auth_bearer}"
@@ -282,6 +284,13 @@ async def run_index_import_job(
282284

283285
return name
284286

287+
async def validate_k8s_obj_exists(self, obj_type: str, name: str) -> bool:
288+
"""return true/false if specified k8s object exists"""
289+
if obj_type == "job":
290+
return await self.has_job(name)
291+
292+
return False
293+
285294
async def delete_dedupe_index_resources(self, oid: str, coll_id: str) -> None:
286295
"""Delete dedupe index-related jobs and index itself"""
287296
await self._delete_jobs(f"role=index-import-job,oid={oid},coll={coll_id}")

backend/btrixcloud/k8sapi.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,17 @@ async def unsuspend_k8s_job(self, name) -> dict:
348348
traceback.print_exc()
349349
return {"error": str(exc)}
350350

351+
async def has_job(self, name) -> bool:
352+
"""return true/false if job exists"""
353+
try:
354+
await self.batch_api.read_namespaced_job(
355+
name=name, namespace=self.namespace
356+
)
357+
return True
358+
# pylint: disable=bare-except
359+
except:
360+
return False
361+
351362
async def print_pod_logs(self, pod_names, lines=100):
352363
"""print pod logs"""
353364
for pod in pod_names:

backend/btrixcloud/main.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,7 @@ def main() -> None:
172172

173173
user_manager = init_user_manager(mdb, email, invites)
174174

175-
current_active_user, shared_secret_or_superuser, custom_access = init_users_api(
176-
app, user_manager
177-
)
175+
current_active_user, shared_secret_or_superuser = init_users_api(app, user_manager)
178176

179177
org_ops = init_orgs_api(
180178
app,
@@ -247,7 +245,6 @@ def main() -> None:
247245
crawl_manager,
248246
event_webhook_ops,
249247
current_active_user,
250-
custom_access("coll"),
251248
)
252249

253250
base_crawl_init = (

backend/btrixcloud/users.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -593,13 +593,11 @@ def init_user_manager(mdb, emailsender, invites):
593593

594594
# ============================================================================
595595
# pylint: disable=too-many-locals, raise-missing-from
596-
def init_users_api(
597-
app, user_manager: UserManager
598-
) -> tuple[Callable, Callable, Callable]:
596+
def init_users_api(app, user_manager: UserManager):
599597
"""init fastapi_users"""
600598

601-
auth_jwt_router, current_active_user, shared_secret_or_superuser, custom_access = (
602-
init_jwt_auth(user_manager)
599+
auth_jwt_router, current_active_user, shared_secret_or_superuser = init_jwt_auth(
600+
user_manager
603601
)
604602

605603
app.include_router(
@@ -620,7 +618,7 @@ def init_users_api(
620618
tags=["users"],
621619
)
622620

623-
return current_active_user, shared_secret_or_superuser, custom_access
621+
return current_active_user, shared_secret_or_superuser
624622

625623

626624
# ============================================================================

0 commit comments

Comments
 (0)