Skip to content

Commit 9605a4a

Browse files
fix(mcp-service): ensure Flask app context in auth hook and resolve Pydantic warnings (apache#36013)
1 parent c2baba5 commit 9605a4a

File tree

4 files changed

+137
-57
lines changed

4 files changed

+137
-57
lines changed

superset/mcp_service/auth.py

Lines changed: 108 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -46,23 +46,34 @@ def get_user_from_request() -> User:
4646
"""
4747
Get the current user for the MCP tool request.
4848
49-
TODO (future PR): Add JWT token extraction and validation.
50-
TODO (future PR): Add user impersonation support.
51-
TODO (future PR): Add fallback user configuration.
49+
Priority order:
50+
1. g.user if already set (by Preset workspace middleware)
51+
2. MCP_DEV_USERNAME from configuration (for development/testing)
5252
53-
For now, this returns the admin user for development.
53+
Returns:
54+
User object with roles and groups eagerly loaded
55+
56+
Raises:
57+
ValueError: If user cannot be authenticated or found
5458
"""
5559
from flask import current_app
5660
from sqlalchemy.orm import joinedload
5761

5862
from superset.extensions import db
5963

60-
# TODO: Extract from JWT token once authentication is implemented
61-
# For now, use MCP_DEV_USERNAME from configuration
64+
# First check if user is already set by Preset workspace middleware
65+
if hasattr(g, "user") and g.user:
66+
return g.user
67+
68+
# Fall back to configured username for development/single-user deployments
6269
username = current_app.config.get("MCP_DEV_USERNAME")
6370

6471
if not username:
65-
raise ValueError("Username not configured")
72+
raise ValueError(
73+
"No authenticated user found. "
74+
"Either pass a valid JWT bearer token or configure "
75+
"MCP_DEV_USERNAME for development."
76+
)
6677

6778
# Query user directly with eager loading to ensure fresh session-bound object
6879
# Do NOT use security_manager.find_user() as it may return cached/detached user
@@ -115,65 +126,109 @@ def has_dataset_access(dataset: "SqlaTable") -> bool:
115126
return False # Deny access on error
116127

117128

129+
def _setup_user_context() -> User:
130+
"""
131+
Set up user context for MCP tool execution.
132+
133+
Returns:
134+
User object with roles and groups loaded
135+
"""
136+
user = get_user_from_request()
137+
138+
# Validate user has necessary relationships loaded
139+
# (Force access to ensure they're loaded if lazy)
140+
user_roles = user.roles # noqa: F841
141+
if hasattr(user, "groups"):
142+
user_groups = user.groups # noqa: F841
143+
144+
g.user = user
145+
return user
146+
147+
148+
def _cleanup_session_on_error() -> None:
149+
"""Clean up database session after an exception."""
150+
from superset.extensions import db
151+
152+
# pylint: disable=consider-using-transaction
153+
try:
154+
db.session.rollback()
155+
db.session.remove()
156+
except Exception as e:
157+
logger.warning("Error cleaning up session after exception: %s", e)
158+
159+
160+
def _cleanup_session_finally() -> None:
161+
"""Clean up database session in finally block."""
162+
from superset.extensions import db
163+
164+
# Rollback active session (no exception occurred)
165+
# Do NOT call remove() on success to avoid detaching user
166+
try:
167+
if db.session.is_active:
168+
# pylint: disable=consider-using-transaction
169+
db.session.rollback()
170+
except Exception as e:
171+
logger.warning("Error in finally block: %s", e)
172+
173+
118174
def mcp_auth_hook(tool_func: F) -> F:
119175
"""
120176
Authentication and authorization decorator for MCP tools.
121177
122-
This is a minimal implementation that:
123-
1. Gets the current user
124-
2. Sets g.user for Flask context
178+
This decorator assumes Flask application context and g.user
179+
have already been set by WorkspaceContextMiddleware.
180+
181+
Supports both sync and async tool functions.
125182
126183
TODO (future PR): Add permission checking
127184
TODO (future PR): Add JWT scope validation
128185
TODO (future PR): Add comprehensive audit logging
129-
TODO (future PR): Add rate limiting integration
130186
"""
131187
import functools
188+
import inspect
132189

133-
@functools.wraps(tool_func)
134-
def wrapper(*args: Any, **kwargs: Any) -> Any:
135-
from superset.extensions import db
136-
137-
# Get user and set Flask context OUTSIDE try block
138-
user = get_user_from_request()
139-
140-
# Force load relationships NOW while session is definitely active
141-
_ = user.roles
142-
if hasattr(user, "groups"):
143-
_ = user.groups
144-
145-
g.user = user
146-
147-
try:
148-
# TODO: Add permission checks here in future PR
149-
# TODO: Add audit logging here in future PR
190+
is_async = inspect.iscoroutinefunction(tool_func)
150191

151-
logger.debug(
152-
"MCP tool call: user=%s, tool=%s", user.username, tool_func.__name__
153-
)
192+
if is_async:
154193

155-
result = tool_func(*args, **kwargs)
194+
@functools.wraps(tool_func)
195+
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
196+
user = _setup_user_context()
156197

157-
return result
158-
159-
except Exception:
160-
# On error, rollback and cleanup session
161-
# pylint: disable=consider-using-transaction
162198
try:
163-
db.session.rollback()
164-
db.session.remove()
165-
except Exception as e:
166-
logger.warning("Error cleaning up session after exception: %s", e)
167-
raise
168-
169-
finally:
170-
# Only rollback if session is still active (no exception occurred)
171-
# Do NOT call remove() on success to avoid detaching user
172-
try:
173-
if db.session.is_active:
174-
# pylint: disable=consider-using-transaction
175-
db.session.rollback()
176-
except Exception as e:
177-
logger.warning("Error in finally block: %s", e)
199+
logger.debug(
200+
"MCP tool call: user=%s, tool=%s",
201+
user.username,
202+
tool_func.__name__,
203+
)
204+
result = await tool_func(*args, **kwargs)
205+
return result
206+
except Exception:
207+
_cleanup_session_on_error()
208+
raise
209+
finally:
210+
_cleanup_session_finally()
211+
212+
return async_wrapper # type: ignore[return-value]
213+
214+
else:
215+
216+
@functools.wraps(tool_func)
217+
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
218+
user = _setup_user_context()
178219

179-
return wrapper # type: ignore[return-value]
220+
try:
221+
logger.debug(
222+
"MCP tool call: user=%s, tool=%s",
223+
user.username,
224+
tool_func.__name__,
225+
)
226+
result = tool_func(*args, **kwargs)
227+
return result
228+
except Exception:
229+
_cleanup_session_on_error()
230+
raise
231+
finally:
232+
_cleanup_session_finally()
233+
234+
return sync_wrapper # type: ignore[return-value]

superset/mcp_service/common/error_schemas.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,16 @@ class ValidationError(BaseModel):
4848
class DatasetContext(BaseModel):
4949
"""Dataset information for error context"""
5050

51+
model_config = {"populate_by_name": True}
52+
5153
id: int = Field(..., description="Dataset ID")
5254
table_name: str = Field(..., description="Table name")
53-
schema: str | None = Field(None, description="Schema name")
55+
schema_name: str | None = Field(
56+
None,
57+
alias="schema",
58+
serialization_alias="schema",
59+
description="Schema name",
60+
)
5461
database_name: str = Field(..., description="Database name")
5562
available_columns: List[Dict[str, Any]] = Field(
5663
default_factory=list, description="Available columns with metadata"

superset/mcp_service/server.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,26 @@ def run_server(
8282
factory_config = get_mcp_factory_config()
8383
mcp_instance = create_mcp_app(**factory_config)
8484
else:
85-
# Use default initialization
85+
# Use default initialization with auth from Flask config
8686
logging.info("Creating MCP app with default configuration...")
87-
mcp_instance = init_fastmcp_server()
87+
from superset.mcp_service.flask_singleton import get_flask_app
88+
89+
flask_app = get_flask_app()
90+
91+
# Get auth factory from config and create auth provider
92+
auth_provider = None
93+
auth_factory = flask_app.config.get("MCP_AUTH_FACTORY")
94+
if auth_factory:
95+
try:
96+
auth_provider = auth_factory(flask_app)
97+
logging.info(
98+
"Auth provider created: %s",
99+
type(auth_provider).__name__ if auth_provider else "None",
100+
)
101+
except Exception as e:
102+
logging.error("Failed to create auth provider: %s", e)
103+
104+
mcp_instance = init_fastmcp_server(auth=auth_provider)
88105

89106
env_key = f"FASTMCP_RUNNING_{port}"
90107
if not os.environ.get(env_key):

superset/mcp_service/system/tool/health_check.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import logging
2222
import platform
2323

24+
from fastmcp import Context
2425
from flask import current_app
2526

2627
from superset.mcp_service.app import mcp
@@ -33,7 +34,7 @@
3334

3435
@mcp.tool
3536
@mcp_auth_hook
36-
async def health_check() -> HealthCheckResponse:
37+
async def health_check(ctx: Context) -> HealthCheckResponse:
3738
"""
3839
Simple health check tool for testing the MCP service.
3940

0 commit comments

Comments
 (0)