Skip to content

Commit 73c98dc

Browse files
committed
Let Flask handle graphql authorization
Since we need the request body and FastAPI requires a bit of hackery to read the response body and pass it on to Flask, we just let Flask handle this. The bug stemmed from FastAPI calling the authorization rule which assumed a Flask context. Signed-off-by: mprahl <mprahl@users.noreply.github.com>
1 parent 901c6eb commit 73c98dc

File tree

4 files changed

+137
-99
lines changed

4 files changed

+137
-99
lines changed

kubernetes-workspace-provider/src/kubernetes_workspace_provider/auth.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,6 +1002,7 @@ def _authorize_request(
10021002
authorizer: KubernetesAuthorizer,
10031003
config_values: KubernetesAuthConfig,
10041004
workspace: str | None,
1005+
graphql_payload: dict[str, object] | None = None,
10051006
) -> _AuthorizationResult:
10061007
"""
10071008
Resolve the caller identity and ensure the MLflow request is permitted.
@@ -1041,7 +1042,7 @@ def _authorize_request(
10411042
if isinstance(workspace, str):
10421043
workspace_name = workspace.strip() or None
10431044

1044-
rules = _find_authorization_rules(path, method)
1045+
rules = _find_authorization_rules(path, method, graphql_payload=graphql_payload)
10451046
if rules is None or len(rules) == 0:
10461047
_logger.warning(
10471048
"No Kubernetes authorization rule matched request %s %s; returning 404.",
@@ -1550,7 +1551,9 @@ def _validate_fastapi_route_authorization(fastapi_app: FastAPI) -> None:
15501551
)
15511552

15521553

1553-
def _find_authorization_rules(request_path: str, method: str) -> list[AuthorizationRule] | None:
1554+
def _find_authorization_rules(
1555+
request_path: str, method: str, graphql_payload: dict[str, object] | None = None
1556+
) -> list[AuthorizationRule] | None:
15541557
"""Find authorization rules for a request.
15551558
15561559
For most endpoints, returns a single-element list. For GraphQL endpoints,
@@ -1569,10 +1572,7 @@ def _find_authorization_rules(request_path: str, method: str) -> list[Authorizat
15691572
# send operationName="GetRun" but include model registry fields in the
15701573
# query, bypassing authorization checks for those resources.
15711574
if canonical_path.endswith("/graphql"):
1572-
try:
1573-
payload = request.get_json(silent=True) or {}
1574-
except Exception:
1575-
payload = {}
1575+
payload = graphql_payload or {}
15761576

15771577
query_string = payload.get("query", "")
15781578
if not query_string:
@@ -1656,6 +1656,14 @@ async def dispatch(self, request: Request, call_next):
16561656
workspace_context.set_server_request_workspace(workspace_name)
16571657
workspace_set = True
16581658

1659+
if canonical_path.endswith("/graphql"):
1660+
# Let Flask authorize GraphQL to avoid consuming/rebuffering the ASGI body.
1661+
try:
1662+
return await call_next(request)
1663+
finally:
1664+
if workspace_set:
1665+
workspace_context.clear_server_request_workspace()
1666+
16591667
try:
16601668
auth_result = _authorize_request(
16611669
authorization_header=request.headers.get("Authorization"),
@@ -1759,6 +1767,15 @@ def _k8s_auth_before_request():
17591767
if _is_unprotected_path(canonical_path):
17601768
return None
17611769

1770+
graphql_payload: dict[str, object] | None = None
1771+
if canonical_path.endswith("/graphql"):
1772+
try:
1773+
payload = request.get_json(silent=True) or {}
1774+
if isinstance(payload, dict):
1775+
graphql_payload = payload
1776+
except Exception:
1777+
graphql_payload = None
1778+
17621779
try:
17631780
auth_result = _authorize_request(
17641781
authorization_header=request.headers.get("Authorization"),
@@ -1770,6 +1787,7 @@ def _k8s_auth_before_request():
17701787
authorizer=authorizer,
17711788
config_values=config_values,
17721789
workspace=workspace_context.get_request_workspace(),
1790+
graphql_payload=graphql_payload,
17731791
)
17741792
except MlflowException as exc:
17751793
response = Response(mimetype="application/json")

kubernetes-workspace-provider/tests/test_auth.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ def test_workspace_scope_string_is_normalized(monkeypatch):
593593

594594
monkeypatch.setattr(
595595
"kubernetes_workspace_provider.auth._find_authorization_rules",
596-
lambda path, method: [AuthorizationRule("list", resource="experiments")],
596+
lambda path, method, **kwargs: [AuthorizationRule("list", resource="experiments")],
597597
)
598598
monkeypatch.setattr(
599599
"kubernetes_workspace_provider.auth._parse_jwt_subject",
@@ -629,7 +629,7 @@ def test_workspace_listing_allows_missing_context(monkeypatch):
629629
)
630630
monkeypatch.setattr(
631631
"kubernetes_workspace_provider.auth._find_authorization_rules",
632-
lambda path, method: [rule],
632+
lambda path, method, **kwargs: [rule],
633633
)
634634
monkeypatch.setattr(
635635
"kubernetes_workspace_provider.auth._parse_jwt_subject",
@@ -659,7 +659,7 @@ def test_unmapped_endpoint_returns_not_found(monkeypatch):
659659
authorizer.is_allowed.return_value = True
660660
monkeypatch.setattr(
661661
"kubernetes_workspace_provider.auth._find_authorization_rules",
662-
lambda path, method: None,
662+
lambda path, method, **kwargs: None,
663663
)
664664

665665
config = KubernetesAuthConfig()
@@ -687,7 +687,7 @@ def test_subject_access_review_mode_uses_remote_headers(monkeypatch):
687687
rule = AuthorizationRule("list", resource=RESOURCE_EXPERIMENTS)
688688
monkeypatch.setattr(
689689
"kubernetes_workspace_provider.auth._find_authorization_rules",
690-
lambda path, method: [rule],
690+
lambda path, method, **kwargs: [rule],
691691
)
692692

693693
config = KubernetesAuthConfig(authorization_mode=AuthorizationMode.SUBJECT_ACCESS_REVIEW)
@@ -718,7 +718,7 @@ def test_subject_access_review_mode_requires_user_header(monkeypatch):
718718
rule = AuthorizationRule("get", resource=RESOURCE_EXPERIMENTS)
719719
monkeypatch.setattr(
720720
"kubernetes_workspace_provider.auth._find_authorization_rules",
721-
lambda path, method: [rule],
721+
lambda path, method, **kwargs: [rule],
722722
)
723723

724724
config = KubernetesAuthConfig(authorization_mode=AuthorizationMode.SUBJECT_ACCESS_REVIEW)
@@ -748,7 +748,7 @@ def test_gateway_endpoint_create_requires_model_definition_use(monkeypatch):
748748
rule = AuthorizationRule("create", resource=RESOURCE_GATEWAY_ENDPOINTS)
749749
monkeypatch.setattr(
750750
"kubernetes_workspace_provider.auth._find_authorization_rules",
751-
lambda path, method: [rule],
751+
lambda path, method, **kwargs: [rule],
752752
)
753753
monkeypatch.setattr(
754754
"kubernetes_workspace_provider.auth._parse_jwt_subject",
@@ -794,7 +794,7 @@ def test_gateway_endpoint_update_requires_model_definition_use(monkeypatch):
794794
rule = AuthorizationRule("update", resource=RESOURCE_GATEWAY_ENDPOINTS)
795795
monkeypatch.setattr(
796796
"kubernetes_workspace_provider.auth._find_authorization_rules",
797-
lambda path, method: [rule],
797+
lambda path, method, **kwargs: [rule],
798798
)
799799
monkeypatch.setattr(
800800
"kubernetes_workspace_provider.auth._parse_jwt_subject",
@@ -839,7 +839,7 @@ def test_gateway_model_definition_create_requires_secret_use(monkeypatch):
839839
rule = AuthorizationRule("create", resource=RESOURCE_GATEWAY_MODEL_DEFINITIONS)
840840
monkeypatch.setattr(
841841
"kubernetes_workspace_provider.auth._find_authorization_rules",
842-
lambda path, method: [rule],
842+
lambda path, method, **kwargs: [rule],
843843
)
844844
monkeypatch.setattr(
845845
"kubernetes_workspace_provider.auth._parse_jwt_subject",
@@ -884,7 +884,7 @@ def test_gateway_model_definition_update_requires_secret_use(monkeypatch):
884884
rule = AuthorizationRule("update", resource=RESOURCE_GATEWAY_MODEL_DEFINITIONS)
885885
monkeypatch.setattr(
886886
"kubernetes_workspace_provider.auth._find_authorization_rules",
887-
lambda path, method: [rule],
887+
lambda path, method, **kwargs: [rule],
888888
)
889889
monkeypatch.setattr(
890890
"kubernetes_workspace_provider.auth._parse_jwt_subject",
@@ -928,7 +928,7 @@ def test_workspace_scope_falls_back_to_view_args(monkeypatch):
928928
rule = AuthorizationRule(None, requires_workspace=False, workspace_access_check=True)
929929
monkeypatch.setattr(
930930
"kubernetes_workspace_provider.auth._find_authorization_rules",
931-
lambda path, method: [rule],
931+
lambda path, method, **kwargs: [rule],
932932
)
933933
monkeypatch.setattr(
934934
"kubernetes_workspace_provider.auth._parse_jwt_subject",
@@ -962,7 +962,7 @@ def test_workspace_create_requests_are_denied(monkeypatch):
962962
rule = AuthorizationRule("create", deny=True, requires_workspace=False)
963963
monkeypatch.setattr(
964964
"kubernetes_workspace_provider.auth._find_authorization_rules",
965-
lambda path, method: [rule],
965+
lambda path, method, **kwargs: [rule],
966966
)
967967
monkeypatch.setattr(
968968
"kubernetes_workspace_provider.auth._parse_jwt_subject",

kubernetes-workspace-provider/tests/test_auth_fastapi.py

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
These tests ensure OTEL and job APIs enforce workspace-aware authentication.
44
"""
55

6+
from types import SimpleNamespace
67
from unittest.mock import Mock, patch
78

89
import pytest
910
from fastapi import FastAPI
1011
from fastapi.middleware.wsgi import WSGIMiddleware
1112
from fastapi.testclient import TestClient
1213
from flask import Flask
14+
from flask import request as flask_request
1315
from kubernetes_workspace_provider.auth import (
1416
DEFAULT_REMOTE_GROUPS_HEADER,
1517
DEFAULT_REMOTE_GROUPS_SEPARATOR,
@@ -26,7 +28,6 @@
2628
from starlette.requests import Request
2729
from starlette.responses import JSONResponse
2830

29-
from mlflow.environment_variables import MLFLOW_ENABLE_WORKSPACES
3031
from mlflow.exceptions import MlflowException
3132
from mlflow.tracing.utils.otlp import OTLP_TRACES_PATH
3233
from mlflow.utils import workspace_context
@@ -490,6 +491,88 @@ def test_job_api_endpoints_prefer_forwarded_token_on_invalid_authorization(
490491
assert subresource is None
491492

492493

494+
def test_graphql_flask_authorizes_when_fastapi_defers(
495+
mock_authorizer, mock_config, monkeypatch
496+
) -> None:
497+
from kubernetes_workspace_provider.auth import create_app
498+
499+
flask_app = Flask(__name__)
500+
501+
@flask_app.post("/graphql")
502+
def graphql_endpoint():
503+
payload = flask_request.get_json(silent=True) or {}
504+
return {"query": payload.get("query")}
505+
506+
monkeypatch.setenv("MLFLOW_K8S_AUTH_CACHE_TTL_SECONDS", "300")
507+
fake_k8s_config = SimpleNamespace(
508+
host="https://cluster.local",
509+
ssl_ca_cert=None,
510+
verify_ssl=True,
511+
proxy=None,
512+
no_proxy=None,
513+
proxy_headers=None,
514+
safe_chars_for_path_param=None,
515+
connection_pool_maxsize=10,
516+
)
517+
with (
518+
patch(
519+
"kubernetes_workspace_provider.auth.KubernetesAuthorizer.is_allowed",
520+
return_value=True,
521+
) as flask_is_allowed,
522+
patch(
523+
"kubernetes_workspace_provider.auth._load_kubernetes_configuration",
524+
return_value=fake_k8s_config,
525+
),
526+
):
527+
create_app(flask_app)
528+
529+
fastapi_app = FastAPI()
530+
fastapi_app.mount("/", WSGIMiddleware(flask_app))
531+
532+
class _WorkspaceContextMiddleware(BaseHTTPMiddleware):
533+
async def dispatch(self, request: Request, call_next):
534+
workspace_header = request.headers.get(WORKSPACE_HEADER_NAME)
535+
if not workspace_header:
536+
return JSONResponse(
537+
status_code=400,
538+
content={"error": {"message": f"Missing {WORKSPACE_HEADER_NAME} header"}},
539+
)
540+
workspace_context.set_server_request_workspace(workspace_header)
541+
try:
542+
return await call_next(request)
543+
finally:
544+
workspace_context.clear_server_request_workspace()
545+
546+
fastapi_app.add_middleware(
547+
KubernetesAuthMiddleware,
548+
authorizer=mock_authorizer,
549+
config_values=mock_config,
550+
)
551+
fastapi_app.add_middleware(_WorkspaceContextMiddleware)
552+
553+
client = TestClient(fastapi_app)
554+
query = '{ mlflowGetExperiment(input: { experimentId: "123" }) { experiment { name } } }'
555+
with patch(
556+
"kubernetes_workspace_provider.auth._parse_jwt_subject",
557+
return_value="test-user",
558+
):
559+
response = client.post(
560+
"/graphql",
561+
headers={
562+
"Authorization": "Bearer valid-token",
563+
WORKSPACE_HEADER_NAME: "team-a",
564+
},
565+
json={"query": query},
566+
)
567+
568+
assert response.status_code == 200
569+
assert response.json()["query"] == query
570+
# FastAPI middleware should NOT have called its authorizer
571+
mock_authorizer.is_allowed.assert_not_called()
572+
# Flask's before_request handler SHOULD have called the authorizer
573+
flask_is_allowed.assert_called_once()
574+
575+
493576
def test_job_api_missing_workspace_context_returns_error(
494577
mock_authorizer, mock_config, monkeypatch
495578
) -> None:
@@ -499,10 +582,9 @@ def test_job_api_missing_workspace_context_returns_error(
499582
authorizer=mock_authorizer,
500583
config_values=mock_config,
501584
)
502-
monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
503585
monkeypatch.setattr(
504-
"mlflow.server.workspace_helpers.get_default_workspace_optional",
505-
lambda _store: (None, False),
586+
"kubernetes_workspace_provider.auth.resolve_workspace_from_header",
587+
lambda _header: None,
506588
)
507589

508590
@app.get("/ajax-api/3.0/jobs/123")

0 commit comments

Comments
 (0)