Skip to content

Commit 3d3c24c

Browse files
authored
Update auth_tool.py
1 parent 401bc72 commit 3d3c24c

File tree

1 file changed

+135
-8
lines changed

1 file changed

+135
-8
lines changed

main/api/mcp/tools/auth_tool.py

Lines changed: 135 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import os
1313
import secrets
1414
from datetime import datetime, timedelta
15+
import pyotp
1516

1617
logger = logging.getLogger("mcp.auth")
1718
logger.setLevel(logging.INFO)
@@ -27,6 +28,7 @@ class AuthTokenInput(BaseModel):
2728
code: str
2829
redirect_uri: str
2930
code_verifier: str
31+
totp_code: Optional[str] = None
3032

3133
class AuthTokenOutput(BaseModel):
3234
access_token: str
@@ -44,12 +46,27 @@ class AuthRevokeOutput(BaseModel):
4446
class AuthRefreshInput(BaseModel):
4547
user_id: str
4648
refresh_token: str
49+
totp_code: Optional[str] = None
4750

4851
class AuthRefreshOutput(BaseModel):
4952
access_token: str
5053
refresh_token: str
5154
session_id: str
5255

56+
class AuthEnable2FAInput(BaseModel):
57+
user_id: str
58+
59+
class AuthEnable2FAOutput(BaseModel):
60+
secret: str
61+
qr_code_url: str
62+
63+
class AuthVerify2FAInput(BaseModel):
64+
user_id: str
65+
totp_code: str
66+
67+
class AuthVerify2FAOutput(BaseModel):
68+
status: str
69+
5370
class AuthTool:
5471
def __init__(self, db: DatabaseConfig):
5572
self.db = db
@@ -72,6 +89,12 @@ async def execute(self, input: Dict[str, Any]) -> Any:
7289
elif method == "refreshToken":
7390
refresh_input = AuthRefreshInput(**input)
7491
return await self.refresh_token(refresh_input)
92+
elif method == "enable2FA":
93+
enable_2fa_input = AuthEnable2FAInput(**input)
94+
return await self.enable_2fa(enable_2fa_input)
95+
elif method == "verify2FA":
96+
verify_2fa_input = AuthVerify2FAInput(**input)
97+
return await self.verify_2fa(verify_2fa_input)
7598
else:
7699
raise ValidationError(f"Unknown method: {method}")
77100
except Exception as e:
@@ -103,6 +126,11 @@ async def generate_api_credentials(self, input: AuthGenerateInput) -> AuthGenera
103126
user_id=input.user_id,
104127
details={"api_key": api_key}
105128
)
129+
await self.security_handler.log_user_action(
130+
user_id=input.user_id,
131+
action="generate_api_credentials",
132+
details={"api_key": api_key}
133+
)
106134
logger.info(f"Generated API credentials for {input.user_id}")
107135
return AuthGenerateOutput(api_key=api_key, api_secret=api_secret)
108136
except Exception as e:
@@ -114,6 +142,77 @@ async def generate_api_credentials(self, input: AuthGenerateInput) -> AuthGenera
114142
)
115143
raise HTTPException(400, str(e))
116144

145+
async def enable_2fa(self, input: AuthEnable2FAInput) -> AuthEnable2FAOutput:
146+
try:
147+
user = await self.db.query("SELECT user_id FROM users WHERE user_id = $1", [input.user_id])
148+
if not user.rows:
149+
raise ValidationError(f"User not found: {input.user_id}")
150+
151+
totp_secret = pyotp.random_base32()
152+
totp = pyotp.TOTP(totp_secret)
153+
qr_code_url = totp.provisioning_uri(name=input.user_id, issuer_name="Vial MCP")
154+
155+
await self.db.query(
156+
"UPDATE users SET totp_secret = $1 WHERE user_id = $2",
157+
[totp_secret, input.user_id]
158+
)
159+
160+
await self.security_handler.log_event(
161+
event_type="2fa_enabled",
162+
user_id=input.user_id,
163+
details={"secret": totp_secret[:8] + "..."}
164+
)
165+
await self.security_handler.log_user_action(
166+
user_id=input.user_id,
167+
action="enable_2fa",
168+
details={"secret": totp_secret[:8] + "..."}
169+
)
170+
logger.info(f"Enabled 2FA for user {input.user_id}")
171+
return AuthEnable2FAOutput(secret=totp_secret, qr_code_url=qr_code_url)
172+
except Exception as e:
173+
logger.error(f"Enable 2FA error: {str(e)}")
174+
await self.security_handler.log_event(
175+
event_type="2fa_enable_error",
176+
user_id=input.user_id,
177+
details={"error": str(e)}
178+
)
179+
raise HTTPException(400, str(e))
180+
181+
async def verify_2fa(self, input: AuthVerify2FAInput) -> AuthVerify2FAOutput:
182+
try:
183+
user = await self.db.query("SELECT totp_secret FROM users WHERE user_id = $1", [input.user_id])
184+
if not user.rows:
185+
raise ValidationError(f"User not found: {input.user_id}")
186+
187+
totp_secret = user.rows[0]["totp_secret"]
188+
if not totp_secret:
189+
raise ValidationError("2FA not enabled for this user")
190+
191+
totp = pyotp.TOTP(totp_secret)
192+
if not totp.verify(input.totp_code):
193+
raise ValidationError("Invalid 2FA code")
194+
195+
await self.security_handler.log_event(
196+
event_type="2fa_verified",
197+
user_id=input.user_id,
198+
details={}
199+
)
200+
await self.security_handler.log_user_action(
201+
user_id=input.user_id,
202+
action="verify_2fa",
203+
details={}
204+
)
205+
logger.info(f"Verified 2FA for user {input.user_id}")
206+
return AuthVerify2FAOutput(status="verified")
207+
except Exception as e:
208+
logger.error(f"Verify 2FA error: {str(e)}")
209+
await self.security_handler.log_event(
210+
event_type="2fa_verify_error",
211+
user_id=input.user_id,
212+
details={"error": str(e)}
213+
)
214+
raise HTTPException(400, str(e))
215+
117216
async def exchange_token(self, input: AuthTokenInput) -> AuthTokenOutput:
118217
try:
119218
if input.redirect_uri not in self.redirect_uri_allowlist:
@@ -147,12 +246,11 @@ async def exchange_token(self, input: AuthTokenInput) -> AuthTokenOutput:
147246
user_data = user_response.json()
148247
user_id = str(user_data["id"])
149248

150-
# Validate token audience
151249
if user_data.get("aud") != self.api_config.github_client_id:
152250
raise ValidationError("Invalid token audience")
153251

154252
existing_user = await self.db.query(
155-
"SELECT user_id FROM users WHERE user_id = $1",
253+
"SELECT user_id, totp_secret FROM users WHERE user_id = $1",
156254
[user_id]
157255
)
158256
if not existing_user.rows:
@@ -164,8 +262,14 @@ async def exchange_token(self, input: AuthTokenInput) -> AuthTokenOutput:
164262
from tools.wallet import WalletTool
165263
wallet_tool = WalletTool(self.db)
166264
await wallet_tool.initialize_new_wallet(user_id, wallet_address, str(uuid.uuid4()), str(uuid.uuid4()))
265+
else:
266+
if existing_user.rows[0]["totp_secret"] and not input.totp_code:
267+
raise ValidationError("2FA code required")
268+
if existing_user.rows[0]["totp_secret"]:
269+
totp = pyotp.TOTP(existing_user.rows[0]["totp_secret"])
270+
if not totp.verify(input.totp_code):
271+
raise ValidationError("Invalid 2FA code")
167272

168-
# Create secure session
169273
session_id = f"{user_id}:{secrets.token_urlsafe(32)}"
170274
expires_at = datetime.utcnow() + timedelta(minutes=15)
171275
await self.db.query(
@@ -183,6 +287,11 @@ async def exchange_token(self, input: AuthTokenInput) -> AuthTokenOutput:
183287
user_id=user_id,
184288
details={"access_token": access_token[:8] + "...", "session_id": session_id}
185289
)
290+
await self.security_handler.log_user_action(
291+
user_id=user_id,
292+
action="auth_exchange_token",
293+
details={"access_token": access_token[:8] + "...", "session_id": session_id}
294+
)
186295
logger.info(f"Exchanged OAuth token for user {user_id}")
187296
return AuthTokenOutput(access_token=access_token, refresh_token=refresh_token, user_id=user_id, session_id=session_id)
188297
except Exception as e:
@@ -203,13 +312,11 @@ async def revoke_token(self, input: AuthRevokeInput) -> AuthRevokeOutput:
203312
if not user.rows:
204313
raise ValidationError("Invalid user or token")
205314

206-
# Revoke access token
207315
await self.db.query(
208316
"UPDATE users SET access_token = NULL, refresh_token = NULL WHERE user_id = $1",
209317
[input.user_id]
210318
)
211319

212-
# Terminate session
213320
await self.db.query(
214321
"DELETE FROM sessions WHERE user_id = $1",
215322
[input.user_id]
@@ -220,6 +327,11 @@ async def revoke_token(self, input: AuthRevokeInput) -> AuthRevokeOutput:
220327
user_id=input.user_id,
221328
details={"access_token": input.access_token[:8] + "..."}
222329
)
330+
await self.security_handler.log_user_action(
331+
user_id=input.user_id,
332+
action="revoke_token",
333+
details={"access_token": input.access_token[:8] + "..."}
334+
)
223335
logger.info(f"Revoked token for user {input.user_id}")
224336
return AuthRevokeOutput(status="revoked")
225337
except Exception as e:
@@ -234,12 +346,19 @@ async def revoke_token(self, input: AuthRevokeInput) -> AuthRevokeOutput:
234346
async def refresh_token(self, input: AuthRefreshInput) -> AuthRefreshOutput:
235347
try:
236348
user = await self.db.query(
237-
"SELECT user_id, refresh_token FROM users WHERE user_id = $1 AND refresh_token = $2",
349+
"SELECT user_id, refresh_token, totp_secret FROM users WHERE user_id = $1 AND refresh_token = $2",
238350
[input.user_id, input.refresh_token]
239351
)
240352
if not user.rows:
241353
raise ValidationError("Invalid user or refresh token")
242354

355+
if user.rows[0]["totp_secret"] and not input.totp_code:
356+
raise ValidationError("2FA code required")
357+
if user.rows[0]["totp_secret"]:
358+
totp = pyotp.TOTP(user.rows[0]["totp_secret"])
359+
if not totp.verify(input.totp_code):
360+
raise ValidationError("Invalid 2FA code")
361+
243362
async with httpx.AsyncClient() as client:
244363
response = await client.post(
245364
"https://github.com/login/oauth/access_token",
@@ -260,7 +379,6 @@ async def refresh_token(self, input: AuthRefreshInput) -> AuthRefreshOutput:
260379
new_access_token = data["access_token"]
261380
new_refresh_token = data.get("refresh_token", str(uuid.uuid4()))
262381

263-
# Create new session
264382
session_id = f"{input.user_id}:{secrets.token_urlsafe(32)}"
265383
expires_at = datetime.utcnow() + timedelta(minutes=15)
266384
await self.db.query(
@@ -278,6 +396,11 @@ async def refresh_token(self, input: AuthRefreshInput) -> AuthRefreshOutput:
278396
user_id=input.user_id,
279397
details={"access_token": new_access_token[:8] + "...", "session_id": session_id}
280398
)
399+
await self.security_handler.log_user_action(
400+
user_id=input.user_id,
401+
action="refresh_token",
402+
details={"access_token": new_access_token[:8] + "...", "session_id": session_id}
403+
)
281404
logger.info(f"Refreshed token for user {input.user_id}")
282405
return AuthRefreshOutput(access_token=new_access_token, refresh_token=new_refresh_token, session_id=session_id)
283406
except Exception as e:
@@ -291,7 +414,6 @@ async def refresh_token(self, input: AuthRefreshInput) -> AuthRefreshOutput:
291414

292415
async def verify_token(self, token: str, session_id: str) -> Dict[str, Any]:
293416
try:
294-
# Verify session
295417
session = await self.db.query(
296418
"SELECT session_key, user_id, expires_at FROM sessions WHERE session_key = $1",
297419
[session_id]
@@ -311,6 +433,11 @@ async def verify_token(self, token: str, session_id: str) -> Dict[str, Any]:
311433
user_id=user.rows[0]["user_id"],
312434
details={"session_id": session_id}
313435
)
436+
await self.security_handler.log_user_action(
437+
user_id=user.rows[0]["user_id"],
438+
action="verify_token",
439+
details={"session_id": session_id}
440+
)
314441
return {"user_id": user.rows[0]["user_id"]}
315442
except Exception as e:
316443
logger.error(f"Verify token error: {str(e)}")

0 commit comments

Comments
 (0)