-
Notifications
You must be signed in to change notification settings - Fork 314
Add agentic functionalities to yeti #1267
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d12a62a
5a3aa42
011f1f5
8cdcd16
d55b076
491ffe2
0538715
84bc9b7
ee0e044
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,163 @@ | ||
| import asyncio | ||
| import json | ||
|
|
||
| import httpx | ||
| import websockets | ||
| from typing import Any, Dict, List | ||
| from fastapi import APIRouter, FastAPI, Request, WebSocket, WebSocketDisconnect, HTTPException | ||
| from fastapi.responses import StreamingResponse | ||
| from pydantic import BaseModel, Field | ||
|
|
||
| from core.config.config import yeti_config | ||
| from core.schemas import roles | ||
| from core.schemas.rbac import global_permission | ||
|
|
||
router = APIRouter()

# Agent Service base locations, read from the "agents" section of the config.
AGENT_HTTP_BASE = yeti_config.get("agents", "http_root")
AGENT_WEBSOCKET_BASE = yeti_config.get("agents", "websocket_root")

# Concrete endpoints derived from the bases above. The sessions endpoint keeps
# a literal "{user_id}" placeholder that callers fill in with str.format().
AGENT_STREAM_ENDPOINT = "{}/run_stream".format(AGENT_HTTP_BASE)
AGENT_LIST_SESSIONS_ENDPOINT = "{}/sessions/{{user_id}}".format(AGENT_HTTP_BASE)
AGENT_WEBSOCKET_ENDPOINT = "{}/ws/chat".format(AGENT_WEBSOCKET_BASE)

# Timeout (in seconds) applied to the async HTTP clients created in this module.
ASYNC_TIMEOUT = httpx.Timeout(60.0)
|
|
||
class ADKSession(BaseModel):
    """One agent session, as built from an item of the Agent Service response.

    Instances are created by unpacking each JSON item returned by the
    sessions endpoint, so the field names follow the keys of that payload.
    """

    # Required identifiers.
    id: str
    appName: str
    userId: str

    # Optional fields; each instance gets its own fresh container.
    state: Dict[str, Any] = Field(default_factory=dict)
    events: List[Dict[str, Any]] = Field(default_factory=list)
    lastUpdateTime: float = Field(default=0.0)
|
|
||
@router.get("/sessions")
@global_permission(roles.Permission.READ)
async def list_sessions_proxy(httpreq: Request) -> List[ADKSession]:
    """Return the calling user's agent sessions, fetched from the Agent Service.

    Args:
        httpreq: Incoming request; ``httpreq.state.username`` identifies the
            user whose sessions are listed.

    Returns:
        The items returned by the Agent Service, each validated as an
        ``ADKSession``.

    Raises:
        HTTPException: Carrying the upstream status code and body when the
            Agent Service answers with anything other than 200.
    """
    from urllib.parse import quote

    # Percent-encode the username so characters such as "/" or "?" cannot
    # change the path or query of the upstream URL.
    user_id = quote(str(httpreq.state.username), safe="")
    agent_url = AGENT_LIST_SESSIONS_ENDPOINT.format(user_id=user_id)

    async with httpx.AsyncClient(timeout=ASYNC_TIMEOUT) as client:
        response = await client.get(agent_url)

    if response.status_code != 200:
        raise HTTPException(status_code=response.status_code, detail=response.text)

    # Building the models validates that every item matches the expected schema.
    return [ADKSession(**item) for item in response.json()]
|
|
||
@router.post("/stream")
@global_permission(roles.Permission.READ)
async def chat_proxy(httpreq: Request, message: dict):
    """Relay one chat message to the Agent Service and stream the reply back.

    The username found on ``httpreq.state`` is sent upstream as ``user_id``
    together with the ``session_id`` and ``text`` entries of ``message``
    (``None`` when absent). Bytes received from the Agent Service are passed
    through unchanged as a ``text/event-stream`` response.

    No extra context is injected into the upstream payload at this point.
    """
    # Built before streaming starts, so the request data is read exactly once.
    upstream_body = dict(
        user_id=httpreq.state.username,
        session_id=message.get("session_id"),
        text=message.get("text"),
    )

    async def relay():
        # Client and upstream response stay open for as long as the caller
        # keeps consuming the generator.
        async with httpx.AsyncClient(timeout=ASYNC_TIMEOUT) as http:
            upstream_call = http.stream(
                "POST", AGENT_STREAM_ENDPOINT, json=upstream_body
            )
            async with upstream_call as upstream:
                async for piece in upstream.aiter_bytes():
                    yield piece

    return StreamingResponse(relay(), media_type="text/event-stream")
|
|
||
|
|
||
@router.websocket("/api/v2/chat_proxy")
@global_permission(roles.Permission.READ)
async def chat_proxy_endpoint(httpreq: Request, client_ws: WebSocket):
    """Tunnel a frontend WebSocket to the Agent Service WebSocket.

    1. Resolves the authenticated username and rejects the connection when
       none is available.
    2. Accepts the frontend connection.
    3. Connects to the Agent Service.
    4. Overwrites ``user_id`` on every frontend message and forwards messages
       in both directions until either side goes away.

    Args:
        httpreq: Connection object carrying ``state.username``.
        client_ws: WebSocket opened by the frontend.
    """
    # NOTE(review): a collaborator asked during review whether the
    # authentication described here is sufficient from a security point of
    # view. Token-based handshake auth (e.g. a "token" query parameter) is not
    # implemented; identity comes from ``httpreq.state.username`` like the
    # HTTP endpoints of this module. TODO confirm that the permission
    # decorator populates ``httpreq`` for WebSocket routes.
    import logging

    username = getattr(getattr(httpreq, "state", None), "username", None)
    if not username:
        # 1008 = policy violation: refuse to tunnel anonymous connections.
        await client_ws.close(code=1008)
        return

    await client_ws.accept()

    async def forward_to_agent(agent_ws):
        """Frontend -> Agent: stamp the verified user id on each message."""
        try:
            while True:
                message_payload = json.loads(await client_ws.receive_text())
                # SECURITY: never trust a user_id supplied by the browser.
                message_payload["user_id"] = username
                await agent_ws.send(json.dumps(message_payload))
        except WebSocketDisconnect:
            # Frontend went away; ending this task tears the tunnel down.
            pass
        except Exception as e:
            logging.error("Error forwarding to agent: %s", e)

    async def forward_to_client(agent_ws):
        """Agent -> Frontend: pass messages through untouched."""
        try:
            async for message in agent_ws:
                if isinstance(message, bytes):
                    await client_ws.send_bytes(message)
                else:
                    await client_ws.send_text(message)
        except Exception as e:
            logging.error("Error forwarding to client: %s", e)

    try:
        async with websockets.connect(AGENT_WEBSOCKET_ENDPOINT) as agent_ws:
            pumps = [
                asyncio.create_task(forward_to_agent(agent_ws)),
                asyncio.create_task(forward_to_client(agent_ws)),
            ]
            try:
                # Stop as soon as one direction ends; otherwise the other
                # task would keep the remaining socket open indefinitely.
                await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
            finally:
                for pump in pumps:
                    pump.cancel()
                await asyncio.gather(*pumps, return_exceptions=True)
    except Exception as e:
        logging.error("Proxy Connection Error: %s", e)
    finally:
        # Make sure the frontend socket does not outlive the tunnel.
        try:
            await client_ws.close()
        except Exception:
            # Closing fails when the frontend already disconnected.
            logging.debug("Client WebSocket was already closed", exc_info=True)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| import logging | ||
| from datetime import timedelta | ||
| import json | ||
|
|
||
| import httpx | ||
|
|
||
| from core import taskmanager | ||
| from core.config.config import yeti_config | ||
| from core.schemas import observable, task | ||
| from core.schemas.entities import investigation | ||
|
|
||
# Agent Service HTTP root, read from the "agents" section of the config.
AGENT_HTTP_BASE = yeti_config.get("agents", "http_root")
# Streaming endpoint, with the agent selected through the query string.
AGENT_STREAM_ENDPOINT = "{}/run_stream?agent_name=ioc_analyzer".format(
    AGENT_HTTP_BASE
)

# Tag that marks the URL observables this task should process.
FILTER_TAG = "extract_investigation"
|
|
||
|
|
||
class UrlExtractInvestigation(task.AnalyticsTask):
    """Create investigations from URL observables tagged with ``FILTER_TAG``.

    Each tagged URL is sent to the agent streaming endpoint; the last
    non-"thought" text part of the streamed reply is parsed as a JSON report
    and stored as an ``Investigation`` linked to the source URL and to the
    observables created from the report's IOCs.
    """

    # NOTE(review): reviewers found the class name ambiguous. The task
    # extracts an investigation (summary, IOCs) from the article or report
    # behind a URL; it does not extract URLs from a report. A better name was
    # left open in the review discussion.

    _defaults = {
        "name": "UrlExtractInvestigation",
        "description": f"Extracts investigation details (summaries, IOCs, etc.) from URLs tagged with '{FILTER_TAG}' using LLMs",
        "frequency": timedelta(hours=1),
    }

    def run(self):
        """Process every URL observable currently tagged with ``FILTER_TAG``."""
        urls, _ = observable.Observable.filter(
            query_args={"tags.name": FILTER_TAG, "type": "url"}
        )

        # One client is shared by all requests of this run.
        with httpx.Client(timeout=120.0) as client:
            for url_obs in urls:
                self.process_url(client, AGENT_STREAM_ENDPOINT, url_obs)

    @staticmethod
    def _iter_sse_data(lines):
        """Yield the ``data`` payload of each server-sent event in *lines*.

        Consecutive ``data:`` lines belong to one event and are joined with
        newlines; a blank line ends the event. Other lines are ignored.
        """
        data_lines = []
        for raw_line in lines:
            line = raw_line.rstrip("\r\n")
            if not line:
                if data_lines:
                    yield "\n".join(data_lines)
                    data_lines = []
                continue
            if line.startswith("data:"):
                # Drop the field name and the single optional space after it.
                data_lines.append(line[5:].removeprefix(" "))
        if data_lines:
            yield "\n".join(data_lines)

    def process_url(
        self, client: httpx.Client, endpoint: str, url_obs: observable.Observable
    ):
        """Ask the agent to analyze one URL and store the resulting report.

        Failures are logged and swallowed so one bad URL does not stop the
        run; the tag is only expired after the report was stored.

        Args:
            client: HTTP client used for the streaming request.
            endpoint: Agent streaming endpoint to POST to.
            url_obs: URL observable to analyze.
        """
        payload = {
            "user_id": "analytics_task",
            "session_id": f"extract_investigation_{url_obs.id}",
            "text": f"Analyze {url_obs.value} as per your instructions.",
        }

        last_response = ""
        try:
            with client.stream("POST", endpoint, json=payload) as response:
                response.raise_for_status()
                # Parse line by line: network chunks do not line up with
                # event boundaries, so slicing raw chunks is unreliable.
                for data in self._iter_sse_data(response.iter_lines()):
                    parsed_event = json.loads(data)
                    content = parsed_event.get("content") or {}
                    for part in content.get("parts") or []:
                        if "text" in part and not part.get("thought", False):
                            last_response = part["text"]

            parsed_report = json.loads(last_response)
            self.process_report(parsed_report, source=url_obs)

            # Expire the tag so the URL is not picked up again.
            url_obs.expire_tag(FILTER_TAG)

        except httpx.HTTPError:
            logging.exception(
                "HTTP Error processing URL %s with Agent", url_obs.value
            )
            logging.debug(last_response)
        except Exception:
            logging.exception("Error processing URL %s with Agent", url_obs.value)
            logging.debug(last_response)

    def process_report(self, report, source: observable.Url):
        """Persist *report* as an investigation linked to *source* and its IOCs.

        Args:
            report: Parsed agent report with ``title`` and ``summary`` keys and
                an optional ``iocs`` list of ``{"value", "description"}`` items.
            source: URL observable the report was generated from.

        Raises:
            KeyError: If ``title`` or ``summary`` is missing, or an IOC lacks
                ``value`` or ``description``.
        """
        # A report without IOCs is still stored as an investigation.
        iocs = report.get("iocs") or []

        report_entity = investigation.Investigation(
            name=report["title"],
            description=report["summary"],
            reference=source.value,
        ).save()

        report_entity.link_to(source, "related_to", "source_url")

        for ioc in iocs:
            obs = observable.save(value=ioc["value"])
            obs.add_context(
                source=self.name, context={"description": ioc["description"]}
            )
            report_entity.link_to(obs, "contains", ioc["description"])

        logging.info(
            "Created investigation: %s with %d IOCs", report_entity.id, len(iocs)
        )


taskmanager.TaskManager.register_task(UrlExtractInvestigation)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Enforce wss? See comment in sample conf.