|
1 | 1 | """AI services.""" |
| 2 | +from __future__ import annotations |
2 | 3 |
|
3 | | -import logging |
4 | | -from typing import Generator |
| 4 | +import json |
| 5 | +from typing import Any, Dict, Generator |
5 | 6 |
|
| 7 | +import httpx |
6 | 8 | from django.conf import settings |
7 | 9 | from django.core.exceptions import ImproperlyConfigured |
8 | 10 |
|
9 | 11 | from core import enums |
10 | 12 |
|
| 13 | +import logging |
| 14 | + |
11 | 15 | if settings.LANGFUSE_PUBLIC_KEY: |
12 | 16 | from langfuse.openai import OpenAI |
13 | 17 | else: |
|
16 | 20 | log = logging.getLogger(__name__) |
17 | 21 |
|
18 | 22 |
|
# System prompt prepended (by AIService._harden_payload) when proxying BlockNote
# tool-calling requests. It pins the model to a single tool, a strict JSON input
# shape, and string/array-of-strings HTML block payloads. The text below is sent
# verbatim to the provider — do not reword it without re-testing tool-call output.
BLOCKNOTE_TOOL_STRICT_PROMPT = """You are editing a BlockNote document via the tool applyDocumentOperations.

You MUST respond ONLY by calling applyDocumentOperations.
The tool input MUST be valid JSON:
{ "operations": [ ... ] }

Each operation MUST include "type" and it MUST be one of:
- "update" (requires: id, block)
- "add" (requires: referenceId, position, blocks)
- "delete" (requires: id)

VALID SHAPES (FOLLOW EXACTLY):

Update:
{ "type":"update", "id":"<id$>", "block":"<p>...</p>" }
IMPORTANT: "block" MUST be a STRING containing a SINGLE valid HTML element.

Add:
{ "type":"add", "referenceId":"<id$>", "position":"before|after", "blocks":["<p>...</p>"] }
IMPORTANT: "blocks" MUST be an ARRAY OF STRINGS.
Each item MUST be a STRING containing a SINGLE valid HTML element.

Delete:
{ "type":"delete", "id":"<id$>" }

IDs ALWAYS end with "$". Use ids EXACTLY as provided.

Return ONLY the JSON tool input. No prose, no markdown.
"""
| 52 | + |
| 53 | + |
| 54 | +def _drop_nones(obj: Any) -> Any: |
| 55 | + if isinstance(obj, dict): |
| 56 | + return {k: _drop_nones(v) for k, v in obj.items() if v is not None} |
| 57 | + if isinstance(obj, list): |
| 58 | + return [_drop_nones(v) for v in obj] |
| 59 | + return obj |
| 60 | + |
| 61 | + |
19 | 62 | AI_ACTIONS = { |
20 | 63 | "prompt": ( |
21 | 64 | "Answer the prompt using markdown formatting for structure and emphasis. " |
@@ -72,7 +115,8 @@ def __init__(self): |
72 | 115 | or settings.AI_MODEL is None |
73 | 116 | ): |
74 | 117 | raise ImproperlyConfigured("AI configuration not set") |
75 | | - self.client = OpenAI(base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY) |
| 118 | + self.api_key = settings.AI_API_KEY |
| 119 | + self.client = OpenAI(base_url=settings.AI_BASE_URL, api_key=self.api_key) |
76 | 120 |
|
77 | 121 | def call_ai_api(self, system_content, text): |
78 | 122 | """Helper method to call the OpenAI API and process the response.""" |
@@ -102,18 +146,95 @@ def translate(self, text, language): |
102 | 146 | system_content = AI_TRANSLATE.format(language=language_display) |
103 | 147 | return self.call_ai_api(system_content, text) |
104 | 148 |
|
105 | | - def proxy(self, data: dict, stream: bool = False) -> Generator[str, None, None]: |
106 | | - """Proxy AI API requests to the configured AI provider.""" |
107 | | - data["stream"] = stream |
108 | | - try: |
109 | | - return self.client.chat.completions.create(**data) |
110 | | - except OpenAIError as e: |
111 | | - raise RuntimeError(f"Failed to proxy AI request: {e}") from e |
112 | | - |
113 | | - def stream(self, data: dict) -> Generator[str, None, None]: |
114 | | - """Stream AI API requests to the configured AI provider.""" |
115 | | - stream = self.proxy(data, stream=True) |
116 | | - for chunk in stream: |
117 | | - yield f"data: {chunk.model_dump_json()}\n\n" |
118 | 149 |
|
119 | | - yield "data: [DONE]\n\n" |
| 150 | + def _filtered_headers(self, incoming_headers: Dict[str, str]) -> Dict[str, str]: |
| 151 | + hop_by_hop = {"host", "connection", "content-length", "accept-encoding"} |
| 152 | + out: Dict[str, str] = {} |
| 153 | + for k, v in incoming_headers.items(): |
| 154 | + lk = k.lower() |
| 155 | + if lk in hop_by_hop: |
| 156 | + continue |
| 157 | + if lk == "authorization": |
| 158 | + # Client auth is for Django only, not upstream |
| 159 | + continue |
| 160 | + out[k] = v |
| 161 | + |
| 162 | + out["Authorization"] = f"Bearer {self.api_key}" |
| 163 | + return out |
| 164 | + |
| 165 | + def _normalize_tools(self, tools: list) -> list: |
| 166 | + normalized = [] |
| 167 | + for tool in tools: |
| 168 | + if isinstance(tool, dict) and tool.get("type") == "function": |
| 169 | + fn = tool.get("function") or {} |
| 170 | + if isinstance(fn, dict) and not fn.get("description"): |
| 171 | + fn["description"] = f"Tool {fn.get('name', 'unknown')}." |
| 172 | + tool["function"] = fn |
| 173 | + normalized.append(_drop_nones(tool)) |
| 174 | + return normalized |
| 175 | + |
| 176 | + def _harden_payload(self, payload: Dict[str, Any]) -> Dict[str, Any]: |
| 177 | + payload = dict(payload) |
| 178 | + |
| 179 | + # Enforce server model (important with Albert routing) |
| 180 | + if getattr(settings, "AI_MODEL", None): |
| 181 | + payload["model"] = settings.AI_MODEL |
| 182 | + |
| 183 | + # Compliance |
| 184 | + payload["temperature"] = 0 |
| 185 | + |
| 186 | + # Tools normalization |
| 187 | + if isinstance(payload.get("tools"), list): |
| 188 | + payload["tools"] = self._normalize_tools(payload["tools"]) |
| 189 | + |
| 190 | + # Force tool call if tools exist |
| 191 | + if payload.get("tools"): |
| 192 | + payload["tool_choice"] = {"type": "function", "function": {"name": "applyDocumentOperations"}} |
| 193 | + |
| 194 | + # Convert non-standard "required" |
| 195 | + if payload.get("tool_choice") == "required": |
| 196 | + payload["tool_choice"] = {"type": "function", "function": {"name": "applyDocumentOperations"}} |
| 197 | + |
| 198 | + # Inject strict system prompt once |
| 199 | + msgs = payload.get("messages") |
| 200 | + if isinstance(msgs, list): |
| 201 | + need = True |
| 202 | + if msgs and isinstance(msgs[0], dict) and msgs[0].get("role") == "system": |
| 203 | + c = msgs[0].get("content") or "" |
| 204 | + if isinstance(c, str) and "applyDocumentOperations" in c and "blocks" in c: |
| 205 | + need = False |
| 206 | + if need: |
| 207 | + payload["messages"] = [{"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}] + msgs |
| 208 | + |
| 209 | + return _drop_nones(payload) |
| 210 | + |
| 211 | + def _maybe_harden_json_body(self, body: bytes, headers: Dict[str, str]) -> bytes: |
| 212 | + ct = (headers.get("Content-Type") or headers.get("content-type") or "").lower() |
| 213 | + if "application/json" not in ct: |
| 214 | + return body |
| 215 | + try: |
| 216 | + payload = json.loads(body.decode("utf-8")) |
| 217 | + except Exception: |
| 218 | + return body |
| 219 | + if isinstance(payload, dict): |
| 220 | + payload = self._harden_payload(payload) |
| 221 | + return json.dumps(payload, ensure_ascii=False).encode("utf-8") |
| 222 | + return body |
| 223 | + |
| 224 | + def stream_proxy( |
| 225 | + self, |
| 226 | + *, |
| 227 | + url: str, |
| 228 | + method: str, |
| 229 | + headers: Dict[str, str], |
| 230 | + body: bytes, |
| 231 | + ) -> Generator[bytes, None, None]: |
| 232 | + req_headers = self._filtered_headers(dict(headers)) |
| 233 | + req_body = self._maybe_harden_json_body(body, req_headers) |
| 234 | + |
| 235 | + timeout = httpx.Timeout(connect=10.0, read=300.0, write=60.0, pool=10.0) |
| 236 | + with httpx.Client(timeout=timeout, follow_redirects=False) as client: |
| 237 | + with client.stream(method.upper(), url, headers=req_headers, content=req_body) as r: |
| 238 | + for chunk in r.iter_bytes(): |
| 239 | + if chunk: |
| 240 | + yield chunk |
0 commit comments