Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/docker-hub.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
push:
branches:
- 'main'
- 'feat/blocknote-ai'
tags:
- 'v*'
pull_request:
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to
- ✨(frontend) Add stat for Crisp #1824
- ✨(auth) add silent login #1690
- 🔧(project) add DJANGO_EMAIL_URL_APP environment variable #1825
- ✨(frontend) integrate new Blocknote AI feature #1016

### Changed

Expand Down
1 change: 1 addition & 0 deletions docs/env.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ These are the environment variables you can set for the `impress-backend` contai
| AI_ALLOW_REACH_FROM | Users that can use AI must be this level. options are "public", "authenticated", "restricted" | authenticated |
| AI_API_KEY | AI key to be used for AI Base url | |
| AI_BASE_URL | OpenAI compatible AI base url | |
| AI_BOT | Information to give to the frontend about the AI bot | { "name": "Docs AI", "color": "#8bc6ff" } |
| AI_FEATURE_ENABLED | Enable AI options | false |
| AI_MODEL | AI Model to use | |
| ALLOW_LOGOUT_GET_METHOD | Allow get logout method | true |
Expand Down
35 changes: 35 additions & 0 deletions src/backend/core/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,41 @@ def validate_text(self, value):
return value


class AIProxySerializer(serializers.Serializer):
    """Validate the payload a client wants forwarded to the AI provider."""

    messages = serializers.ListField(
        required=True,
        child=serializers.DictField(
            child=serializers.CharField(required=True),
        ),
        allow_empty=False,
    )
    model = serializers.CharField(required=True)

    def validate_messages(self, messages):
        """Ensure every message is a mapping carrying 'role' and 'content' keys."""
        well_formed = all(
            isinstance(item, dict) and "role" in item and "content" in item
            for item in messages
        )
        if not well_formed:
            raise serializers.ValidationError(
                "Each message must have 'role' and 'content' fields"
            )
        return messages

    def validate_model(self, value):
        """Reject any model other than the one configured in settings.AI_MODEL."""
        if value == settings.AI_MODEL:
            return value
        raise serializers.ValidationError(f"{value} is not a valid model")


class MoveDocumentSerializer(serializers.Serializer):
"""
Serializer for validating input data to move a document within the tree structure.
Expand Down
44 changes: 44 additions & 0 deletions src/backend/core/api/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ class DocumentViewSet(
Returns: JSON response with the translated text.
Throttled by: AIDocumentRateThrottle, AIUserRateThrottle.

12. **AI Proxy**: Proxy an AI request to an external AI service.
Example: POST /api/v1.0/documents/<resource_id>/ai-proxy

### Ordering: created_at, updated_at, is_favorite, title

Example:
Expand Down Expand Up @@ -1642,6 +1645,44 @@ def media_check(self, request, *args, **kwargs):

return drf.response.Response(body, status=drf.status.HTTP_200_OK)

@drf.decorators.action(
    detail=True,
    methods=["post"],
    name="Proxy AI requests to the AI provider",
    url_path="ai-proxy",
    # TODO(review): throttling is disabled — re-enable before exposing widely.
    # throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle],
)
def ai_proxy(self, request, *args, **kwargs):
    """
    POST /api/v1.0/documents/<resource_id>/ai-proxy
    Proxy AI requests to the configured AI provider.

    Streams the provider response as server-sent events when
    ``settings.AI_STREAM`` is enabled; otherwise buffers the whole
    upstream payload and returns it in a single JSON response.

    Raises ValidationError when the AI feature is disabled.
    """
    # Check permissions first: raises 403/404 before any upstream call.
    self.get_object()

    if not settings.AI_FEATURE_ENABLED:
        raise ValidationError("AI feature is not enabled.")

    ai_service = AIService()
    proxy_kwargs = {
        "url": settings.AI_BASE_URL.rstrip("/") + "/chat/completions",
        "method": "POST",
        "headers": {"Content-Type": "application/json"},
        "body": json.dumps(request.data, ensure_ascii=False).encode("utf-8"),
    }

    if settings.AI_STREAM:
        resp = StreamingHttpResponse(
            streaming_content=ai_service.stream_proxy(**proxy_kwargs),
            content_type="text/event-stream",
            status=200,
        )
        # Disable proxy buffering so SSE chunks reach the client immediately.
        resp["X-Accel-Buffering"] = "no"
        resp["Cache-Control"] = "no-cache"
        return resp

    # Fix: the original fell through and implicitly returned None when
    # streaming was disabled. Buffer the full upstream response instead.
    raw = b"".join(ai_service.stream_proxy(**proxy_kwargs))
    try:
        payload = json.loads(raw.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError):
        # Upstream sent something that is not JSON; pass it through verbatim.
        payload = {"raw": raw.decode("utf-8", errors="replace")}
    return drf.response.Response(payload, status=drf.status.HTTP_200_OK)

@drf.decorators.action(
detail=True,
methods=["post"],
Expand Down Expand Up @@ -2337,7 +2378,10 @@ def get(self, request):
Return a dictionary of public settings.
"""
array_settings = [
"AI_BOT",
"AI_FEATURE_ENABLED",
"AI_MODEL",
"AI_STREAM",
"COLLABORATION_WS_URL",
"COLLABORATION_WS_NOT_CONNECTED_READY_ONLY",
"CONVERSION_FILE_EXTENSIONS_ALLOWED",
Expand Down
1 change: 1 addition & 0 deletions src/backend/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,7 @@ def get_abilities(self, user):
return {
"accesses_manage": is_owner_or_admin,
"accesses_view": has_access_role,
"ai_proxy": ai_access,
"ai_transform": ai_access,
"ai_translate": ai_access,
"attachment_upload": can_update,
Expand Down
160 changes: 158 additions & 2 deletions src/backend/core/services/ai_services.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,63 @@
"""AI services."""

from __future__ import annotations

import json
import logging
from typing import Any, Dict, Generator

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured

import httpx

from core import enums

if settings.LANGFUSE_PUBLIC_KEY:
from langfuse.openai import OpenAI
else:
from openai import OpenAI
from openai import OpenAI, OpenAIError

log = logging.getLogger(__name__)


BLOCKNOTE_TOOL_STRICT_PROMPT = """You are editing a BlockNote document via the tool applyDocumentOperations.

You MUST respond ONLY by calling applyDocumentOperations.
The tool input MUST be valid JSON:
{ "operations": [ ... ] }

Each operation MUST include "type" and it MUST be one of:
- "update" (requires: id, block)
- "add" (requires: referenceId, position, blocks)
- "delete" (requires: id)

VALID SHAPES (FOLLOW EXACTLY):

Update:
{ "type":"update", "id":"<id$>", "block":"<p>...</p>" }
IMPORTANT: "block" MUST be a STRING containing a SINGLE valid HTML element.

Add:
{ "type":"add", "referenceId":"<id$>", "position":"before|after", "blocks":["<p>...</p>"] }
IMPORTANT: "blocks" MUST be an ARRAY OF STRINGS.
Each item MUST be a STRING containing a SINGLE valid HTML element.

Delete:
{ "type":"delete", "id":"<id$>" }

IDs ALWAYS end with "$". Use ids EXACTLY as provided.

Return ONLY the JSON tool input. No prose, no markdown.
"""


def _drop_nones(obj: Any) -> Any:
    """Recursively remove ``None``-valued keys from dicts.

    Lists are walked element by element; scalars pass through unchanged.
    """
    if isinstance(obj, list):
        return [_drop_nones(item) for item in obj]
    if not isinstance(obj, dict):
        return obj
    return {key: _drop_nones(value) for key, value in obj.items() if value is not None}


AI_ACTIONS = {
Expand Down Expand Up @@ -67,7 +116,8 @@ def __init__(self):
or settings.AI_MODEL is None
):
raise ImproperlyConfigured("AI configuration not set")
self.client = OpenAI(base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY)
self.api_key = settings.AI_API_KEY
self.client = OpenAI(base_url=settings.AI_BASE_URL, api_key=self.api_key)

def call_ai_api(self, system_content, text):
"""Helper method to call the OpenAI API and process the response."""
Expand Down Expand Up @@ -96,3 +146,109 @@ def translate(self, text, language):
language_display = enums.ALL_LANGUAGES.get(language, language)
system_content = AI_TRANSLATE.format(language=language_display)
return self.call_ai_api(system_content, text)

def _filtered_headers(self, incoming_headers: Dict[str, str]) -> Dict[str, str]:
    """Build the header set for the upstream AI request.

    Drops hop-by-hop headers (RFC 9110 §7.6.1) and the client's
    ``Authorization`` header, then authenticates with the provider
    API key. The original list omitted several hop-by-hop headers
    (``Transfer-Encoding``, ``TE``, ``Trailer``, ``Keep-Alive``,
    ``Upgrade``, ``Proxy-*``); forwarding those to an httpx call that
    re-frames the body can corrupt the upstream request.
    """
    excluded = {
        "host",
        "connection",
        "content-length",
        "accept-encoding",
        # Full RFC 9110 hop-by-hop set:
        "keep-alive",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
        "proxy-authenticate",
        "proxy-authorization",
        # Client auth is for Django only, never forwarded upstream.
        "authorization",
    }
    out: Dict[str, str] = {
        name: value
        for name, value in incoming_headers.items()
        if name.lower() not in excluded
    }
    out["Authorization"] = f"Bearer {self.api_key}"
    return out

def _normalize_tools(self, tools: list) -> list:
    """Return a cleaned copy of OpenAI tool definitions.

    Ensures each ``function`` tool carries a description (presumably
    some providers reject tools without one — inherited from the
    original) and strips ``None`` values via ``_drop_nones``.

    Fix: the original mutated the incoming dicts in place
    (``fn["description"] = ...`` and ``tool["function"] = fn``),
    silently altering the caller's payload; this version works on
    shallow copies so ``request.data`` is never modified.
    """
    normalized = []
    for tool in tools:
        if isinstance(tool, dict) and tool.get("type") == "function":
            tool = dict(tool)  # shallow copy: never mutate the caller's dict
            fn = tool.get("function") or {}
            if isinstance(fn, dict):
                fn = dict(fn)
                if not fn.get("description"):
                    fn["description"] = f"Tool {fn.get('name', 'unknown')}."
            tool["function"] = fn
        normalized.append(_drop_nones(tool))
    return normalized

def _harden_payload(self, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize a chat-completions payload before proxying it upstream.

    Overrides the client's model with the server-configured one, pins
    temperature to 0, normalizes tool definitions, forces a call to
    ``applyDocumentOperations`` whenever tools are present, and injects
    the strict BlockNote system prompt exactly once. Returns a new dict;
    the caller's top-level dict is not mutated (shallow copy below).
    """
    # Shallow copy so top-level key assignments don't touch the caller's dict.
    payload = dict(payload)

    # Enforce server model (important with Albert routing): the client's
    # requested model value is never trusted.
    if getattr(settings, "AI_MODEL", None):
        payload["model"] = settings.AI_MODEL

    # Compliance: deterministic output for tool-call generation.
    payload["temperature"] = 0

    # Tools normalization (adds missing descriptions, drops None values).
    if isinstance(payload.get("tools"), list):
        payload["tools"] = self._normalize_tools(payload["tools"])

    # Force tool call if tools exist — downstream only understands
    # applyDocumentOperations responses.
    if payload.get("tools"):
        payload["tool_choice"] = {
            "type": "function",
            "function": {"name": "applyDocumentOperations"},
        }

    # Convert non-standard "required" tool_choice into the explicit form.
    # NOTE(review): when "tools" is truthy the branch above has already
    # overwritten tool_choice, so this only fires for payloads sending
    # tool_choice="required" without tools — confirm that's intended.
    if payload.get("tool_choice") == "required":
        payload["tool_choice"] = {
            "type": "function",
            "function": {"name": "applyDocumentOperations"},
        }

    # Inject strict system prompt once: skipped only when the first message
    # is already a system prompt mentioning the tool name and "blocks".
    msgs = payload.get("messages")
    if isinstance(msgs, list):
        need = True
        if msgs and isinstance(msgs[0], dict) and msgs[0].get("role") == "system":
            c = msgs[0].get("content") or ""
            if (
                isinstance(c, str)
                and "applyDocumentOperations" in c
                and "blocks" in c
            ):
                need = False
        if need:
            payload["messages"] = [
                {"role": "system", "content": BLOCKNOTE_TOOL_STRICT_PROMPT}
            ] + msgs

    # Final pass: strip every None value recursively.
    return _drop_nones(payload)

def _maybe_harden_json_body(self, body: bytes, headers: Dict[str, str]) -> bytes:
    """Harden JSON request bodies via ``_harden_payload``; pass anything else through.

    Non-JSON content types, undecodable JSON, and non-dict top-level
    values are all returned unchanged.
    """
    content_type = (
        headers.get("Content-Type") or headers.get("content-type") or ""
    ).lower()
    if "application/json" not in content_type:
        return body

    try:
        decoded = json.loads(body.decode("utf-8"))
    except json.JSONDecodeError:
        return body

    if not isinstance(decoded, dict):
        return body

    hardened = self._harden_payload(decoded)
    return json.dumps(hardened, ensure_ascii=False).encode("utf-8")

def stream_proxy(
    self,
    *,
    url: str,
    method: str,
    headers: Dict[str, str],
    body: bytes,
) -> Generator[bytes, None, None]:
    """Forward a request to the AI provider and yield raw response chunks.

    Headers are filtered and re-authenticated, JSON bodies are hardened,
    then the upstream response is streamed chunk by chunk (empty chunks
    skipped) so the caller can relay it to the client as it arrives.
    """
    upstream_headers = self._filtered_headers(dict(headers))
    upstream_body = self._maybe_harden_json_body(body, upstream_headers)

    # Long read timeout: model generation can take minutes.
    limits = httpx.Timeout(connect=10.0, read=300.0, write=60.0, pool=10.0)
    with httpx.Client(timeout=limits, follow_redirects=False) as client:
        with client.stream(
            method.upper(), url, headers=upstream_headers, content=upstream_body
        ) as upstream:
            yield from (chunk for chunk in upstream.iter_bytes() if chunk)
Loading
Loading