Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/docker-hub.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
push:
branches:
- 'main'
- 'refacto/blocknote-ai'
tags:
- 'v*'
pull_request:
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to
- (doc) add documentation to install with compose #855
- ✨(backend) allow to disable checking unsafe mimetype on attachment upload
- ✨Ask for access #1081
- ✨(frontend) integrate new Blocknote AI feature #1016

### Changed

Expand Down
119 changes: 60 additions & 59 deletions docs/env.md

Large diffs are not rendered by default.

50 changes: 27 additions & 23 deletions src/backend/core/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from rest_framework import exceptions, serializers

from core import enums, models, utils
from core.services.ai_services import AI_ACTIONS
from core.services.converter_services import (
ConversionError,
YdocConverter,
Expand Down Expand Up @@ -718,33 +717,38 @@ class VersionFilterSerializer(serializers.Serializer):
)


class AITransformSerializer(serializers.Serializer):
"""Serializer for AI transform requests."""

action = serializers.ChoiceField(choices=AI_ACTIONS, required=True)
text = serializers.CharField(required=True)

def validate_text(self, value):
"""Ensure the text field is not empty."""

if len(value.strip()) == 0:
raise serializers.ValidationError("Text field cannot be empty.")
return value
class AIProxySerializer(serializers.Serializer):
"""Serializer for AI proxy requests."""

messages = serializers.ListField(
required=True,
child=serializers.DictField(
child=serializers.CharField(required=True),
),
allow_empty=False,
)
model = serializers.CharField(required=True)

class AITranslateSerializer(serializers.Serializer):
"""Serializer for AI translate requests."""
def validate_messages(self, messages):
"""Validate messages structure."""
# Ensure each message has the required fields
for message in messages:
if (
not isinstance(message, dict)
or "role" not in message
or "content" not in message
):
raise serializers.ValidationError(
"Each message must have 'role' and 'content' fields"
)

language = serializers.ChoiceField(
choices=tuple(enums.ALL_LANGUAGES.items()), required=True
)
text = serializers.CharField(required=True)
return messages

def validate_text(self, value):
"""Ensure the text field is not empty."""
def validate_model(self, value):
"""Validate model value is the same than settings.AI_MODEL"""
if value != settings.AI_MODEL:
raise serializers.ValidationError(f"{value} is not a valid model")

if len(value.strip()) == 0:
raise serializers.ValidationError("Text field cannot be empty.")
return value


Expand Down
82 changes: 26 additions & 56 deletions src/backend/core/api/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,21 +387,8 @@ class DocumentViewSet(
9. **Media Auth**: Authorize access to document media.
Example: GET /documents/media-auth/

10. **AI Transform**: Apply a transformation action on a piece of text with AI.
Example: POST /documents/{id}/ai-transform/
Expected data:
- text (str): The input text.
- action (str): The transformation type, one of [prompt, correct, rephrase, summarize].
Returns: JSON response with the processed text.
Throttled by: AIDocumentRateThrottle, AIUserRateThrottle.

11. **AI Translate**: Translate a piece of text with AI.
Example: POST /documents/{id}/ai-translate/
Expected data:
- text (str): The input text.
- language (str): The target language, chosen from settings.LANGUAGES.
Returns: JSON response with the translated text.
Throttled by: AIDocumentRateThrottle, AIUserRateThrottle.
10. **AI Proxy**: Proxy an AI request to an external AI service.
Example: POST /api/v1.0/documents/<resource_id>/ai-proxy

### Ordering: created_at, updated_at, is_favorite, title

Expand Down Expand Up @@ -438,7 +425,6 @@ class DocumentViewSet(
]
queryset = models.Document.objects.all()
serializer_class = serializers.DocumentSerializer
ai_translate_serializer_class = serializers.AITranslateSerializer
children_serializer_class = serializers.ListDocumentSerializer
descendants_serializer_class = serializers.ListDocumentSerializer
list_serializer_class = serializers.ListDocumentSerializer
Expand Down Expand Up @@ -1356,58 +1342,39 @@ def media_check(self, request, *args, **kwargs):
@drf.decorators.action(
detail=True,
methods=["post"],
name="Apply a transformation action on a piece of text with AI",
url_path="ai-transform",
name="Proxy AI requests to the AI provider",
url_path="ai-proxy",
throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle],
)
def ai_transform(self, request, *args, **kwargs):
def ai_proxy(self, request, *args, **kwargs):
"""
POST /api/v1.0/documents/<resource_id>/ai-transform
with expected data:
- text: str
- action: str [prompt, correct, rephrase, summarize]
Return JSON response with the processed text.
POST /api/v1.0/documents/<resource_id>/ai-proxy
Proxy AI requests to the configured AI provider.
This endpoint forwards requests to the AI provider and returns either the complete response or, when settings.AI_STREAM is enabled, a "text/event-stream" streaming response.
"""
# Check permissions first
self.get_object()

serializer = serializers.AITransformSerializer(data=request.data)
serializer.is_valid(raise_exception=True)

text = serializer.validated_data["text"]
action = serializer.validated_data["action"]

response = AIService().transform(text, action)

return drf.response.Response(response, status=drf.status.HTTP_200_OK)

@drf.decorators.action(
detail=True,
methods=["post"],
name="Translate a piece of text with AI",
url_path="ai-translate",
throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle],
)
def ai_translate(self, request, *args, **kwargs):
"""
POST /api/v1.0/documents/<resource_id>/ai-translate
with expected data:
- text: str
- language: str [settings.LANGUAGES]
Return JSON response with the translated text.
"""
# Check permissions first
self.get_object()
if not settings.AI_FEATURE_ENABLED:
raise ValidationError("AI feature is not enabled.")

serializer = self.get_serializer(data=request.data)
serializer = serializers.AIProxySerializer(data=request.data)
serializer.is_valid(raise_exception=True)

text = serializer.validated_data["text"]
language = serializer.validated_data["language"]
ai_service = AIService()

response = AIService().translate(text, language)
if settings.AI_STREAM:
return StreamingHttpResponse(
ai_service.stream(request.data),
content_type="text/event-stream",
status=drf.status.HTTP_200_OK,
)

return drf.response.Response(response, status=drf.status.HTTP_200_OK)
ai_response = ai_service.proxy(request.data)
return drf.response.Response(
ai_response.model_dump(),
status=drf.status.HTTP_200_OK,
)

@drf.decorators.action(
detail=True,
Expand Down Expand Up @@ -1863,7 +1830,10 @@ def get(self, request):
Return a dictionary of public settings.
"""
array_settings = [
"AI_BOT",
"AI_FEATURE_ENABLED",
"AI_MODEL",
"AI_STREAM",
"COLLABORATION_WS_URL",
"COLLABORATION_WS_NOT_CONNECTED_READY_ONLY",
"CRISP_WEBSITE_ID",
Expand Down
3 changes: 1 addition & 2 deletions src/backend/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,8 +832,7 @@ def get_abilities(self, user, ancestors_links=None):
return {
"accesses_manage": is_owner_or_admin,
"accesses_view": has_access_role,
"ai_transform": ai_access,
"ai_translate": ai_access,
"ai_proxy": ai_access,
"attachment_upload": can_update,
"media_check": can_get,
"children_list": can_get,
Expand Down
92 changes: 20 additions & 72 deletions src/backend/core/services/ai_services.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,14 @@
"""AI services."""

import logging
from typing import Generator

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured

from openai import OpenAI

from core import enums
from openai import OpenAI, OpenAIError

AI_ACTIONS = {
"prompt": (
"Answer the prompt in markdown format. "
"Preserve the language and markdown formatting. "
"Do not provide any other information. "
"Preserve the language."
),
"correct": (
"Correct grammar and spelling of the markdown text, "
"preserving language and markdown formatting. "
"Do not provide any other information. "
"Preserve the language."
),
"rephrase": (
"Rephrase the given markdown text, "
"preserving language and markdown formatting. "
"Do not provide any other information. "
"Preserve the language."
),
"summarize": (
"Summarize the markdown text, preserving language and markdown formatting. "
"Do not provide any other information. "
"Preserve the language."
),
"beautify": (
"Add formatting to the text to make it more readable. "
"Do not provide any other information. "
"Preserve the language."
),
"emojify": (
"Add emojis to the important parts of the text. "
"Do not provide any other information. "
"Preserve the language."
),
}

AI_TRANSLATE = (
"Keep the same html structure and formatting. "
"Translate the content in the html to the specified language {language:s}. "
"Check the translation for accuracy and make any necessary corrections. "
"Do not provide any other information."
)
log = logging.getLogger(__name__)


class AIService:
Expand All @@ -64,30 +24,18 @@ def __init__(self):
raise ImproperlyConfigured("AI configuration not set")
self.client = OpenAI(base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY)

def call_ai_api(self, system_content, text):
"""Helper method to call the OpenAI API and process the response."""
response = self.client.chat.completions.create(
model=settings.AI_MODEL,
messages=[
{"role": "system", "content": system_content},
{"role": "user", "content": text},
],
)

content = response.choices[0].message.content

if not content:
raise RuntimeError("AI response does not contain an answer")

return {"answer": content}

def transform(self, text, action):
"""Transform text based on specified action."""
system_content = AI_ACTIONS[action]
return self.call_ai_api(system_content, text)

def translate(self, text, language):
"""Translate text to a specified language."""
language_display = enums.ALL_LANGUAGES.get(language, language)
system_content = AI_TRANSLATE.format(language=language_display)
return self.call_ai_api(system_content, text)
def proxy(self, data: dict, stream: bool = False):
    """Forward a chat completion request to the configured AI provider.

    Args:
        data: Keyword arguments handed as-is to
            ``self.client.chat.completions.create``.
        stream: Value sent as the ``stream`` argument. It overrides any
            ``stream`` key already present in ``data``.

    Returns:
        Whatever ``self.client.chat.completions.create`` returns. This
        method does not yield strings itself, so the former
        ``Generator[str, None, None]`` return annotation was dropped.

    Raises:
        RuntimeError: If the client raises ``OpenAIError``; the original
            error is chained as the cause.
    """
    # Build a fresh dict rather than assigning into ``data``: the caller's
    # payload must not gain or lose a "stream" key as a side effect.
    payload = {**data, "stream": stream}
    try:
        return self.client.chat.completions.create(**payload)
    except OpenAIError as e:
        raise RuntimeError(f"Failed to proxy AI request: {e}") from e

def stream(self, data: dict) -> Generator[str, None, None]:
    """Yield the provider's streamed answer as ``data: ...`` event lines.

    The request is sent through ``self.proxy`` with ``stream=True``. Each
    chunk it produces is serialized with ``model_dump_json()`` and emitted
    as one ``data: <json>`` line followed by a blank line. After the last
    chunk a final ``data: [DONE]`` event is emitted.
    """
    chunks = self.proxy(data, stream=True)
    yield from (f"data: {chunk.model_dump_json()}\n\n" for chunk in chunks)
    # Sentinel event telling the consumer the stream is finished.
    yield "data: [DONE]\n\n"
Loading
Loading