1212import os
1313import secrets
1414from datetime import datetime , timedelta
15+ import pyotp
1516
1617logger = logging .getLogger ("mcp.auth" )
1718logger .setLevel (logging .INFO )
@@ -27,6 +28,7 @@ class AuthTokenInput(BaseModel):
2728 code : str
2829 redirect_uri : str
2930 code_verifier : str
31+ totp_code : Optional [str ] = None
3032
3133class AuthTokenOutput (BaseModel ):
3234 access_token : str
@@ -44,12 +46,27 @@ class AuthRevokeOutput(BaseModel):
4446class AuthRefreshInput (BaseModel ):
4547 user_id : str
4648 refresh_token : str
49+ totp_code : Optional [str ] = None
4750
4851class AuthRefreshOutput (BaseModel ):
4952 access_token : str
5053 refresh_token : str
5154 session_id : str
5255
56+ class AuthEnable2FAInput (BaseModel ):
57+ user_id : str
58+
59+ class AuthEnable2FAOutput (BaseModel ):
60+ secret : str
61+ qr_code_url : str
62+
63+ class AuthVerify2FAInput (BaseModel ):
64+ user_id : str
65+ totp_code : str
66+
67+ class AuthVerify2FAOutput (BaseModel ):
68+ status : str
69+
5370class AuthTool :
5471 def __init__ (self , db : DatabaseConfig ):
5572 self .db = db
@@ -72,6 +89,12 @@ async def execute(self, input: Dict[str, Any]) -> Any:
7289 elif method == "refreshToken" :
7390 refresh_input = AuthRefreshInput (** input )
7491 return await self .refresh_token (refresh_input )
92+ elif method == "enable2FA" :
93+ enable_2fa_input = AuthEnable2FAInput (** input )
94+ return await self .enable_2fa (enable_2fa_input )
95+ elif method == "verify2FA" :
96+ verify_2fa_input = AuthVerify2FAInput (** input )
97+ return await self .verify_2fa (verify_2fa_input )
7598 else :
7699 raise ValidationError (f"Unknown method: { method } " )
77100 except Exception as e :
@@ -103,6 +126,11 @@ async def generate_api_credentials(self, input: AuthGenerateInput) -> AuthGenera
103126 user_id = input .user_id ,
104127 details = {"api_key" : api_key }
105128 )
129+ await self .security_handler .log_user_action (
130+ user_id = input .user_id ,
131+ action = "generate_api_credentials" ,
132+ details = {"api_key" : api_key }
133+ )
106134 logger .info (f"Generated API credentials for { input .user_id } " )
107135 return AuthGenerateOutput (api_key = api_key , api_secret = api_secret )
108136 except Exception as e :
@@ -114,6 +142,77 @@ async def generate_api_credentials(self, input: AuthGenerateInput) -> AuthGenera
114142 )
115143 raise HTTPException (400 , str (e ))
116144
145+ async def enable_2fa (self , input : AuthEnable2FAInput ) -> AuthEnable2FAOutput :
146+ try :
147+ user = await self .db .query ("SELECT user_id FROM users WHERE user_id = $1" , [input .user_id ])
148+ if not user .rows :
149+ raise ValidationError (f"User not found: { input .user_id } " )
150+
151+ totp_secret = pyotp .random_base32 ()
152+ totp = pyotp .TOTP (totp_secret )
153+ qr_code_url = totp .provisioning_uri (name = input .user_id , issuer_name = "Vial MCP" )
154+
155+ await self .db .query (
156+ "UPDATE users SET totp_secret = $1 WHERE user_id = $2" ,
157+ [totp_secret , input .user_id ]
158+ )
159+
160+ await self .security_handler .log_event (
161+ event_type = "2fa_enabled" ,
162+ user_id = input .user_id ,
163+ details = {"secret" : totp_secret [:8 ] + "..." }
164+ )
165+ await self .security_handler .log_user_action (
166+ user_id = input .user_id ,
167+ action = "enable_2fa" ,
168+ details = {"secret" : totp_secret [:8 ] + "..." }
169+ )
170+ logger .info (f"Enabled 2FA for user { input .user_id } " )
171+ return AuthEnable2FAOutput (secret = totp_secret , qr_code_url = qr_code_url )
172+ except Exception as e :
173+ logger .error (f"Enable 2FA error: { str (e )} " )
174+ await self .security_handler .log_event (
175+ event_type = "2fa_enable_error" ,
176+ user_id = input .user_id ,
177+ details = {"error" : str (e )}
178+ )
179+ raise HTTPException (400 , str (e ))
180+
181+ async def verify_2fa (self , input : AuthVerify2FAInput ) -> AuthVerify2FAOutput :
182+ try :
183+ user = await self .db .query ("SELECT totp_secret FROM users WHERE user_id = $1" , [input .user_id ])
184+ if not user .rows :
185+ raise ValidationError (f"User not found: { input .user_id } " )
186+
187+ totp_secret = user .rows [0 ]["totp_secret" ]
188+ if not totp_secret :
189+ raise ValidationError ("2FA not enabled for this user" )
190+
191+ totp = pyotp .TOTP (totp_secret )
192+ if not totp .verify (input .totp_code ):
193+ raise ValidationError ("Invalid 2FA code" )
194+
195+ await self .security_handler .log_event (
196+ event_type = "2fa_verified" ,
197+ user_id = input .user_id ,
198+ details = {}
199+ )
200+ await self .security_handler .log_user_action (
201+ user_id = input .user_id ,
202+ action = "verify_2fa" ,
203+ details = {}
204+ )
205+ logger .info (f"Verified 2FA for user { input .user_id } " )
206+ return AuthVerify2FAOutput (status = "verified" )
207+ except Exception as e :
208+ logger .error (f"Verify 2FA error: { str (e )} " )
209+ await self .security_handler .log_event (
210+ event_type = "2fa_verify_error" ,
211+ user_id = input .user_id ,
212+ details = {"error" : str (e )}
213+ )
214+ raise HTTPException (400 , str (e ))
215+
117216 async def exchange_token (self , input : AuthTokenInput ) -> AuthTokenOutput :
118217 try :
119218 if input .redirect_uri not in self .redirect_uri_allowlist :
@@ -147,12 +246,11 @@ async def exchange_token(self, input: AuthTokenInput) -> AuthTokenOutput:
147246 user_data = user_response .json ()
148247 user_id = str (user_data ["id" ])
149248
150- # Validate token audience
151249 if user_data .get ("aud" ) != self .api_config .github_client_id :
152250 raise ValidationError ("Invalid token audience" )
153251
154252 existing_user = await self .db .query (
155- "SELECT user_id FROM users WHERE user_id = $1" ,
253+ "SELECT user_id, totp_secret FROM users WHERE user_id = $1" ,
156254 [user_id ]
157255 )
158256 if not existing_user .rows :
@@ -164,8 +262,14 @@ async def exchange_token(self, input: AuthTokenInput) -> AuthTokenOutput:
164262 from tools .wallet import WalletTool
165263 wallet_tool = WalletTool (self .db )
166264 await wallet_tool .initialize_new_wallet (user_id , wallet_address , str (uuid .uuid4 ()), str (uuid .uuid4 ()))
265+ else :
266+ if existing_user .rows [0 ]["totp_secret" ] and not input .totp_code :
267+ raise ValidationError ("2FA code required" )
268+ if existing_user .rows [0 ]["totp_secret" ]:
269+ totp = pyotp .TOTP (existing_user .rows [0 ]["totp_secret" ])
270+ if not totp .verify (input .totp_code ):
271+ raise ValidationError ("Invalid 2FA code" )
167272
168- # Create secure session
169273 session_id = f"{ user_id } :{ secrets .token_urlsafe (32 )} "
170274 expires_at = datetime .utcnow () + timedelta (minutes = 15 )
171275 await self .db .query (
@@ -183,6 +287,11 @@ async def exchange_token(self, input: AuthTokenInput) -> AuthTokenOutput:
183287 user_id = user_id ,
184288 details = {"access_token" : access_token [:8 ] + "..." , "session_id" : session_id }
185289 )
290+ await self .security_handler .log_user_action (
291+ user_id = user_id ,
292+ action = "auth_exchange_token" ,
293+ details = {"access_token" : access_token [:8 ] + "..." , "session_id" : session_id }
294+ )
186295 logger .info (f"Exchanged OAuth token for user { user_id } " )
187296 return AuthTokenOutput (access_token = access_token , refresh_token = refresh_token , user_id = user_id , session_id = session_id )
188297 except Exception as e :
@@ -203,13 +312,11 @@ async def revoke_token(self, input: AuthRevokeInput) -> AuthRevokeOutput:
203312 if not user .rows :
204313 raise ValidationError ("Invalid user or token" )
205314
206- # Revoke access token
207315 await self .db .query (
208316 "UPDATE users SET access_token = NULL, refresh_token = NULL WHERE user_id = $1" ,
209317 [input .user_id ]
210318 )
211319
212- # Terminate session
213320 await self .db .query (
214321 "DELETE FROM sessions WHERE user_id = $1" ,
215322 [input .user_id ]
@@ -220,6 +327,11 @@ async def revoke_token(self, input: AuthRevokeInput) -> AuthRevokeOutput:
220327 user_id = input .user_id ,
221328 details = {"access_token" : input .access_token [:8 ] + "..." }
222329 )
330+ await self .security_handler .log_user_action (
331+ user_id = input .user_id ,
332+ action = "revoke_token" ,
333+ details = {"access_token" : input .access_token [:8 ] + "..." }
334+ )
223335 logger .info (f"Revoked token for user { input .user_id } " )
224336 return AuthRevokeOutput (status = "revoked" )
225337 except Exception as e :
@@ -234,12 +346,19 @@ async def revoke_token(self, input: AuthRevokeInput) -> AuthRevokeOutput:
234346 async def refresh_token (self , input : AuthRefreshInput ) -> AuthRefreshOutput :
235347 try :
236348 user = await self .db .query (
237- "SELECT user_id, refresh_token FROM users WHERE user_id = $1 AND refresh_token = $2" ,
349+ "SELECT user_id, refresh_token, totp_secret FROM users WHERE user_id = $1 AND refresh_token = $2" ,
238350 [input .user_id , input .refresh_token ]
239351 )
240352 if not user .rows :
241353 raise ValidationError ("Invalid user or refresh token" )
242354
355+ if user .rows [0 ]["totp_secret" ] and not input .totp_code :
356+ raise ValidationError ("2FA code required" )
357+ if user .rows [0 ]["totp_secret" ]:
358+ totp = pyotp .TOTP (user .rows [0 ]["totp_secret" ])
359+ if not totp .verify (input .totp_code ):
360+ raise ValidationError ("Invalid 2FA code" )
361+
243362 async with httpx .AsyncClient () as client :
244363 response = await client .post (
245364 "https://github.com/login/oauth/access_token" ,
@@ -260,7 +379,6 @@ async def refresh_token(self, input: AuthRefreshInput) -> AuthRefreshOutput:
260379 new_access_token = data ["access_token" ]
261380 new_refresh_token = data .get ("refresh_token" , str (uuid .uuid4 ()))
262381
263- # Create new session
264382 session_id = f"{ input .user_id } :{ secrets .token_urlsafe (32 )} "
265383 expires_at = datetime .utcnow () + timedelta (minutes = 15 )
266384 await self .db .query (
@@ -278,6 +396,11 @@ async def refresh_token(self, input: AuthRefreshInput) -> AuthRefreshOutput:
278396 user_id = input .user_id ,
279397 details = {"access_token" : new_access_token [:8 ] + "..." , "session_id" : session_id }
280398 )
399+ await self .security_handler .log_user_action (
400+ user_id = input .user_id ,
401+ action = "refresh_token" ,
402+ details = {"access_token" : new_access_token [:8 ] + "..." , "session_id" : session_id }
403+ )
281404 logger .info (f"Refreshed token for user { input .user_id } " )
282405 return AuthRefreshOutput (access_token = new_access_token , refresh_token = new_refresh_token , session_id = session_id )
283406 except Exception as e :
@@ -291,7 +414,6 @@ async def refresh_token(self, input: AuthRefreshInput) -> AuthRefreshOutput:
291414
292415 async def verify_token (self , token : str , session_id : str ) -> Dict [str , Any ]:
293416 try :
294- # Verify session
295417 session = await self .db .query (
296418 "SELECT session_key, user_id, expires_at FROM sessions WHERE session_key = $1" ,
297419 [session_id ]
@@ -311,6 +433,11 @@ async def verify_token(self, token: str, session_id: str) -> Dict[str, Any]:
311433 user_id = user .rows [0 ]["user_id" ],
312434 details = {"session_id" : session_id }
313435 )
436+ await self .security_handler .log_user_action (
437+ user_id = user .rows [0 ]["user_id" ],
438+ action = "verify_token" ,
439+ details = {"session_id" : session_id }
440+ )
314441 return {"user_id" : user .rows [0 ]["user_id" ]}
315442 except Exception as e :
316443 logger .error (f"Verify token error: { str (e )} " )
0 commit comments