|
12 | 12 | import asyncio |
13 | 13 | import logging |
14 | 14 | import time |
| 15 | +from datetime import datetime, timezone, timedelta |
15 | 16 |
|
16 | 17 | from fastapi import APIRouter, Request, HTTPException |
17 | 18 | from fastapi.responses import JSONResponse, StreamingResponse |
18 | 19 |
|
| 20 | +from config import settings |
| 21 | +from crud.mcp_approvals import create_approval_request, get_approval_status, get_approved_request, get_pending_request |
19 | 22 | from crud.mcp_tools import get_all_tools |
20 | 23 | from services.mcp_audit_service import log_tool_call |
21 | 24 | from services.mcp_guardrail_service import scan_tool_input |
|
34 | 37 | router = APIRouter() |
35 | 38 |
|
36 | 39 |
|
37 | | -def _jsonrpc_error(id, code: int, message: str) -> dict: |
38 | | - return {"jsonrpc": "2.0", "id": id, "error": {"code": code, "message": message}} |
| 40 | +def _jsonrpc_error(id, code: int, message: str, data: dict | None = None) -> dict: |
| 41 | + error = {"code": code, "message": message} |
| 42 | + if data: |
| 43 | + error["data"] = data |
| 44 | + return {"jsonrpc": "2.0", "id": id, "error": error} |
39 | 45 |
|
40 | 46 |
|
41 | 47 | def _jsonrpc_result(id, result: dict) -> dict: |
@@ -135,7 +141,28 @@ async def _audit(status: str, summary: str | None, is_error: bool): |
135 | 141 | await enforce_mcp_rate_limits(agent_key, tool_name) |
136 | 142 |
|
137 | 143 | if tool.get("requires_approval"): |
138 | | - return JSONResponse(content=_jsonrpc_error(msg_id, -32001, "Tool requires approval"), status_code=200) |
| 144 | + # Check if an approved request already exists for this agent+tool |
| 145 | + approved = await get_approved_request(org_id, agent_key["id"], tool_name) |
| 146 | + if not approved: |
| 147 | + # Reuse an existing pending request if one exists, otherwise create a new one |
| 148 | + pending = await get_pending_request(org_id, agent_key["id"], tool_name) |
| 149 | + if pending: |
| 150 | + approval = pending |
| 151 | + else: |
| 152 | + expires_at = datetime.now(timezone.utc) + timedelta(seconds=settings.mcp_approval_expiry_seconds) |
| 153 | + approval = await create_approval_request(org_id, { |
| 154 | + "agent_key_id": agent_key["id"], |
| 155 | + "tool_id": tool.get("id"), |
| 156 | + "tool_name": tool_name, |
| 157 | + "arguments": arguments, |
| 158 | + "expires_at": expires_at, |
| 159 | + }) |
| 160 | + await _audit("approval_required", f"Approval request {approval.get('id')} created", False) |
| 161 | + return JSONResponse(content=_jsonrpc_error(msg_id, -32001, "Tool requires approval", { |
| 162 | + "approval_id": approval.get("id"), |
| 163 | + "poll_endpoint": f"/v1/mcp/approvals/{approval.get('id')}/status", |
| 164 | + "expires_at": approval["expires_at"].isoformat() if hasattr(approval["expires_at"], "isoformat") else str(approval["expires_at"]), |
| 165 | + }), status_code=200) |
139 | 166 |
|
140 | 167 | scan_result = await scan_tool_input(org_id, tool_name, arguments) |
141 | 168 | if scan_result and scan_result.blocked: |
@@ -182,6 +209,25 @@ async def _audit(status: str, summary: str | None, is_error: bool): |
182 | 209 | return JSONResponse(content=_jsonrpc_error(msg_id, -32601, f"Method not found: {method}"), status_code=200) |
183 | 210 |
|
184 | 211 |
|
| 212 | +@router.get("/v1/mcp/approvals/{request_id}/status") |
| 213 | +async def mcp_approval_status(request: Request, request_id: int): |
| 214 | + """Poll approval status — authenticated via agent key.""" |
| 215 | + agent_key = await _extract_agent_key(request) |
| 216 | + org_id = agent_key["organization_id"] |
| 217 | + |
| 218 | + approval = await get_approval_status(org_id, request_id) |
| 219 | + if not approval: |
| 220 | + raise HTTPException(status_code=404, detail="Approval request not found") |
| 221 | + |
| 222 | + return { |
| 223 | + "approval_id": approval["id"], |
| 224 | + "status": approval["status"], |
| 225 | + "decided_at": approval["decided_at"].isoformat() if approval.get("decided_at") else None, |
| 226 | + "decision_reason": approval.get("decision_reason"), |
| 227 | + "expires_at": approval["expires_at"].isoformat() if approval.get("expires_at") else None, |
| 228 | + } |
| 229 | + |
| 230 | + |
185 | 231 | @router.get("/v1/mcp") |
186 | 232 | async def mcp_sse(request: Request): |
187 | 233 | """SSE endpoint for server-initiated messages. Keep-alive for v1.""" |
|
0 commit comments