|
1 | | -import json |
2 | 1 | import logging |
3 | | -from typing import Dict, Any |
4 | | -from fastapi import WebSocket |
5 | | -from config.config import DatabaseConfig |
| 2 | +from typing import Dict, Any, Optional |
| 3 | +from config.config import DatabaseConfig, SMTPConfig |
| 4 | +from email.mime.text import MIMEText |
| 5 | +import smtplib |
| 6 | +import asyncio |
| 7 | +from datetime import datetime |
| 8 | +from pydantic import BaseModel |
6 | 9 |
|
7 | 10 | logger = logging.getLogger("mcp.notifications") |
8 | 11 | logger.setLevel(logging.INFO) |
9 | 12 |
|
10 | | -class NotificationHandler: |
11 | | - def __init__(self): |
12 | | - self.active_connections: Dict[str, WebSocket] = {} |
13 | | - |
14 | | - async def connect(self, websocket: WebSocket, client_id: str): |
15 | | - await websocket.accept() |
16 | | - self.active_connections[client_id] = websocket |
17 | | - await self.notify(client_id, { |
18 | | - "method": "connection", |
19 | | - "params": {"status": "connected", "client_id": client_id} |
20 | | - }) |
21 | | - logger.info(f"WebSocket connected for client {client_id}") |
22 | | - |
23 | | - async def disconnect(self, client_id: str): |
24 | | - if client_id in self.active_connections: |
25 | | - await self.active_connections[client_id].close() |
26 | | - del self.active_connections[client_id] |
27 | | - logger.info(f"WebSocket disconnected for client {client_id}") |
| 13 | +class NotificationInput(BaseModel): |
| 14 | + user_id: str |
| 15 | + notification_type: str |
| 16 | + details: Dict[str, Any] |
28 | 17 |
|
29 | | - async def notify(self, client_id: str, message: Dict[str, Any]): |
30 | | - if client_id in self.active_connections: |
31 | | - try: |
32 | | - await self.active_connections[client_id].send_text(json.dumps(message)) |
33 | | - except Exception as e: |
34 | | - logger.error(f"Error sending notification to {client_id}: {str(e)}") |
35 | | - await self.disconnect(client_id) |
36 | | - |
37 | | - async def notify_auth(self, user_id: str, access_token: str): |
38 | | - await self.notify(user_id, { |
39 | | - "method": "auth.success", |
40 | | - "params": {"user_id": user_id, "access_token": access_token} |
41 | | - }) |
42 | | - logger.info(f"Notified auth success for {user_id}") |
43 | | - |
44 | | - async def notify_git_push(self, user_id: str, vial_id: str, commit_hash: str, balance: float): |
45 | | - await self.notify(user_id, { |
46 | | - "method": "vial_management.gitPush", |
47 | | - "params": {"vial_id": vial_id, "commit_hash": commit_hash, "balance": balance} |
48 | | - }) |
49 | | - logger.info(f"Notified Git push for {user_id}, vial {vial_id}") |
| 18 | +class NotificationHandler: |
| 19 | + def __init__(self, db: DatabaseConfig): |
| 20 | + self.db = db |
| 21 | + self.smtp_config = SMTPConfig() |
50 | 22 |
|
51 | | - async def notify_cash_out(self, user_id: str, transaction_id: str, amount: float, new_balance: float): |
52 | | - await self.notify(user_id, { |
53 | | - "method": "wallet.cashOut", |
54 | | - "params": {"transaction_id": transaction_id, "amount": amount, "new_balance": new_balance} |
55 | | - }) |
56 | | - logger.info(f"Notified cash-out for {user_id}, transaction {transaction_id}") |
| 23 | + async def send_notification(self, input: NotificationInput): |
| 24 | + try: |
| 25 | + # Store notification in database |
| 26 | + await self.db.query( |
| 27 | + """ |
| 28 | + INSERT INTO notifications (user_id, notification_type, details, created_at) |
| 29 | + VALUES ($1, $2, $3, $4) |
| 30 | + """, |
| 31 | + [ |
| 32 | + input.user_id, |
| 33 | + input.notification_type, |
| 34 | + json.dumps(input.details), |
| 35 | + datetime.utcnow() |
| 36 | + ] |
| 37 | + ) |
| 38 | + |
| 39 | + # Send email notification for critical events |
| 40 | + if input.notification_type in ["anomaly_detected", "data_erasure"]: |
| 41 | + subject = f"Vial MCP: {input.notification_type.replace('_', ' ').title()}" |
| 42 | + body = f"Notification for user {input.user_id}:\n\nType: {input.notification_type}\nDetails: {json.dumps(input.details, indent=2)}" |
| 43 | + |
| 44 | + msg = MIMEText(body) |
| 45 | + msg["Subject"] = subject |
| 46 | + msg["From"] = self.smtp_config.smtp_user |
| 47 | + msg["To"] = self.smtp_config.alert_email |
| 48 | + |
| 49 | + with smtplib.SMTP(self.smtp_config.smtp_server, self.smtp_config.smtp_port) as server: |
| 50 | + server.starttls() |
| 51 | + server.login(self.smtp_config.smtp_user, self.smtp_config.smtp_password) |
| 52 | + server.send_message(msg) |
| 53 | + |
| 54 | + logger.info(f"Sent notification email for {input.notification_type} to user {input.user_id}") |
| 55 | + |
| 56 | + logger.info(f"Stored notification for user {input.user_id}: {input.notification_type}") |
| 57 | + except Exception as e: |
| 58 | + logger.error(f"Error sending notification: {str(e)}") |
| 59 | + raise |
57 | 60 |
|
58 | | - async def notify_wallet_update(self, user_id: str, vial_id: str, balance: float): |
59 | | - await self.notify(user_id, { |
60 | | - "method": "wallet.update", |
61 | | - "params": {"vial_id": vial_id, "balance": balance} |
62 | | - }) |
63 | | - logger.info(f"Notified wallet update for {user_id}, vial {vial_id}") |
| 61 | + async def monitor_audit_logs(self): |
| 62 | + """Monitor audit logs for critical events and send notifications.""" |
| 63 | + try: |
| 64 | + while True: |
| 65 | + # Check for recent critical audit log entries |
| 66 | + critical_actions = await self.db.query( |
| 67 | + """ |
| 68 | + SELECT user_id, action, details FROM audit_logs |
| 69 | + WHERE action IN ('anomaly_detected', 'data_erasure') |
| 70 | + AND created_at > $1 |
| 71 | + AND notified IS NULL |
| 72 | + """, |
| 73 | + [datetime.utcnow() - timedelta(minutes=5)] |
| 74 | + ) |
| 75 | + |
| 76 | + for row in critical_actions.rows: |
| 77 | + await self.send_notification(NotificationInput( |
| 78 | + user_id=row["user_id"], |
| 79 | + notification_type=row["action"], |
| 80 | + details=json.loads(row["details"]) |
| 81 | + )) |
| 82 | + await self.db.query( |
| 83 | + "UPDATE audit_logs SET notified = TRUE WHERE user_id = $1 AND action = $2 AND created_at = $3", |
| 84 | + [row["user_id"], row["action"], row["created_at"]] |
| 85 | + ) |
| 86 | + |
| 87 | + await asyncio.sleep(60) # Check every minute |
| 88 | + except Exception as e: |
| 89 | + logger.error(f"Error monitoring audit logs: {str(e)}") |
0 commit comments