Skip to content

Commit aff3b43

Browse files
AntoLCsampaccoud
authored andcommitted
✨(backend) create ai endpoint
We created 2 new action endpoints on the document to perform AI operations: - POST /api/v1.0/documents/{uuid}/ai-transform - POST /api/v1.0/documents/{uuid}/ai-translate
1 parent e8d95fa commit aff3b43

17 files changed

+1440
-22
lines changed

env.d/development/common.dist

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,8 @@ LOGOUT_REDIRECT_URL=http://localhost:3000
3939

4040
OIDC_REDIRECT_ALLOWED_HOSTS=["http://localhost:8083", "http://localhost:3000"]
4141
OIDC_AUTH_REQUEST_EXTRA_PARAMS={"acr_values": "eidas1"}
42+
43+
# AI
44+
AI_BASE_URL=https://openaiendpoint.com
45+
AI_API_KEY=password
46+
AI_MODEL=llama

src/backend/core/api/serializers.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
import magic
1010
from rest_framework import exceptions, serializers
1111

12-
from core import models
12+
from core import enums, models
13+
from core.services.ai_services import AI_ACTIONS
1314

1415

1516
class UserSerializer(serializers.ModelSerializer):
@@ -378,3 +379,33 @@ class VersionFilterSerializer(serializers.Serializer):
378379
page_size = serializers.IntegerField(
379380
required=False, min_value=1, max_value=50, default=20
380381
)
382+
383+
384+
class AITransformSerializer(serializers.Serializer):
385+
"""Serializer for AI transform requests."""
386+
387+
action = serializers.ChoiceField(choices=AI_ACTIONS, required=True)
388+
text = serializers.CharField(required=True)
389+
390+
def validate_text(self, value):
391+
"""Ensure the text field is not empty."""
392+
393+
if len(value.strip()) == 0:
394+
raise serializers.ValidationError("Text field cannot be empty.")
395+
return value
396+
397+
398+
class AITranslateSerializer(serializers.Serializer):
399+
"""Serializer for AI translate requests."""
400+
401+
language = serializers.ChoiceField(
402+
choices=tuple(enums.ALL_LANGUAGES.items()), required=True
403+
)
404+
text = serializers.CharField(required=True)
405+
406+
def validate_text(self, value):
407+
"""Ensure the text field is not empty."""
408+
409+
if len(value.strip()) == 0:
410+
raise serializers.ValidationError("Text field cannot be empty.")
411+
return value

src/backend/core/api/utils.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
"""Util to generate S3 authorization headers for object storage access control"""
22

3+
import time
4+
from abc import ABC, abstractmethod
5+
6+
from django.conf import settings
7+
from django.core.cache import cache
38
from django.core.files.storage import default_storage
49

510
import botocore
11+
from rest_framework.throttling import BaseThrottle
612

713

814
def generate_s3_authorization_headers(key):
@@ -31,3 +37,93 @@ def generate_s3_authorization_headers(key):
3137
auth.add_auth(request)
3238

3339
return request
40+
41+
42+
class AIBaseRateThrottle(BaseThrottle, ABC):
43+
"""Base throttle class for AI-related rate limiting with backoff."""
44+
45+
def __init__(self, rates):
46+
"""Initialize instance attributes with configurable rates."""
47+
super().__init__()
48+
self.rates = rates
49+
self.cache_key = None
50+
self.recent_requests_minute = 0
51+
self.recent_requests_hour = 0
52+
self.recent_requests_day = 0
53+
54+
@abstractmethod
55+
def get_cache_key(self, request, view):
56+
"""Abstract method to generate cache key for throttling."""
57+
58+
def allow_request(self, request, view):
59+
"""Check if the request is allowed based on rate limits."""
60+
self.cache_key = self.get_cache_key(request, view)
61+
if not self.cache_key:
62+
return True # Allow if no cache key is generated
63+
64+
now = time.time()
65+
history = cache.get(self.cache_key, [])
66+
# Keep requests within the last 24 hours
67+
history = [req for req in history if req > now - 86400]
68+
69+
# Calculate recent requests
70+
self.recent_requests_minute = len([req for req in history if req > now - 60])
71+
self.recent_requests_hour = len([req for req in history if req > now - 3600])
72+
self.recent_requests_day = len(history)
73+
74+
# Check rate limits
75+
if self.recent_requests_minute >= self.rates["minute"]:
76+
return False
77+
if self.recent_requests_hour >= self.rates["hour"]:
78+
return False
79+
if self.recent_requests_day >= self.rates["day"]:
80+
return False
81+
82+
# Log the request
83+
history.append(now)
84+
cache.set(self.cache_key, history, timeout=86400)
85+
return True
86+
87+
def wait(self):
88+
"""Implement a backoff strategy by increasing wait time based on limits hit."""
89+
if self.recent_requests_day >= self.rates["day"]:
90+
return 86400
91+
if self.recent_requests_hour >= self.rates["hour"]:
92+
return 3600
93+
if self.recent_requests_minute >= self.rates["minute"]:
94+
return 60
95+
return None
96+
97+
98+
class AIDocumentRateThrottle(AIBaseRateThrottle):
99+
"""Throttle for limiting AI requests per document with backoff."""
100+
101+
def __init__(self, *args, **kwargs):
102+
super().__init__(settings.AI_DOCUMENT_RATE_THROTTLE_RATES)
103+
104+
def get_cache_key(self, request, view):
105+
"""Include document ID in the cache key."""
106+
document_id = view.kwargs["pk"]
107+
return f"document_{document_id}_throttle_ai"
108+
109+
110+
class AIUserRateThrottle(AIBaseRateThrottle):
111+
"""Throttle that limits requests per user or IP with backoff and rate limits."""
112+
113+
def __init__(self, *args, **kwargs):
114+
super().__init__(settings.AI_USER_RATE_THROTTLE_RATES)
115+
116+
def get_cache_key(self, request, view=None):
117+
"""Generate a cache key based on the user ID or IP for anonymous users."""
118+
if request.user.is_authenticated:
119+
return f"user_{request.user.id!s}_throttle_ai"
120+
return f"anonymous_{self.get_ident(request)}_throttle_ai"
121+
122+
def get_ident(self, request):
123+
"""Return the request IP address."""
124+
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
125+
return (
126+
x_forwarded_for.split(",")[0]
127+
if x_forwarded_for
128+
else request.META.get("REMOTE_ADDR")
129+
)

src/backend/core/api/viewsets.py

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
decorators,
2222
exceptions,
2323
filters,
24+
metadata,
2425
mixins,
2526
pagination,
2627
status,
@@ -30,7 +31,8 @@
3031
response as drf_response,
3132
)
3233

33-
from core import models
34+
from core import enums, models
35+
from core.services.ai_services import AIService
3436

3537
from . import permissions, serializers, utils
3638

@@ -302,6 +304,23 @@ def perform_update(self, serializer):
302304
serializer.save()
303305

304306

307+
class DocumentMetadata(metadata.SimpleMetadata):
308+
"""Custom metadata class to add information"""
309+
310+
def determine_metadata(self, request, view):
311+
"""Add language choices only for the list endpoint."""
312+
simple_metadata = super().determine_metadata(request, view)
313+
314+
if request.path.endswith("/documents/"):
315+
simple_metadata["actions"]["POST"]["language"] = {
316+
"choices": [
317+
{"value": code, "display_name": name}
318+
for code, name in enums.ALL_LANGUAGES.items()
319+
]
320+
}
321+
return simple_metadata
322+
323+
305324
class DocumentViewSet(
306325
ResourceViewsetMixin,
307326
mixins.CreateModelMixin,
@@ -319,6 +338,7 @@ class DocumentViewSet(
319338
resource_field_name = "document"
320339
queryset = models.Document.objects.all()
321340
ordering = ["-updated_at"]
341+
metadata_class = DocumentMetadata
322342

323343
def list(self, request, *args, **kwargs):
324344
"""Restrict resources returned by the list endpoint"""
@@ -455,10 +475,7 @@ def link_configuration(self, request, *args, **kwargs):
455475
serializer = serializers.LinkDocumentSerializer(
456476
document, data=request.data, partial=True
457477
)
458-
if not serializer.is_valid():
459-
return drf_response.Response(
460-
serializer.errors, status=status.HTTP_400_BAD_REQUEST
461-
)
478+
serializer.is_valid(raise_exception=True)
462479

463480
serializer.save()
464481
return drf_response.Response(serializer.data, status=status.HTTP_200_OK)
@@ -471,24 +488,21 @@ def attachment_upload(self, request, *args, **kwargs):
471488

472489
# Validate metadata in payload
473490
serializer = serializers.FileUploadSerializer(data=request.data)
474-
if not serializer.is_valid():
475-
return drf_response.Response(
476-
serializer.errors, status=status.HTTP_400_BAD_REQUEST
477-
)
491+
serializer.is_valid(raise_exception=True)
478492

479493
# Generate a generic yet unique filename to store the image in object storage
480494
file_id = uuid.uuid4()
481495
extension = serializer.validated_data["expected_extension"]
482496
key = f"{document.key_base}/{ATTACHMENTS_FOLDER:s}/{file_id!s}.{extension:s}"
483497

484498
# Prepare metadata for storage
485-
metadata = {"Metadata": {"owner": str(request.user.id)}}
499+
extra_args = {"Metadata": {"owner": str(request.user.id)}}
486500
if serializer.validated_data["is_unsafe"]:
487-
metadata["Metadata"]["is_unsafe"] = "true"
501+
extra_args["Metadata"]["is_unsafe"] = "true"
488502

489503
file = serializer.validated_data["file"]
490504
default_storage.connection.meta.client.upload_fileobj(
491-
file, default_storage.bucket_name, key, ExtraArgs=metadata
505+
file, default_storage.bucket_name, key, ExtraArgs=extra_args
492506
)
493507

494508
return drf_response.Response(
@@ -537,6 +551,63 @@ def retrieve_auth(self, request, *args, **kwargs):
537551
request = utils.generate_s3_authorization_headers(f"{pk:s}/{attachment_key:s}")
538552
return drf_response.Response("authorized", headers=request.headers, status=200)
539553

554+
@decorators.action(
555+
detail=True,
556+
methods=["post"],
557+
name="Apply a transformation action on a piece of text with AI",
558+
url_path="ai-transform",
559+
throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle],
560+
)
561+
def ai_transform(self, request, *args, **kwargs):
562+
"""
563+
POST /api/v1.0/documents/<resource_id>/ai-transform
564+
with expected data:
565+
- text: str
566+
- action: str [prompt, correct, rephrase, summarize]
567+
Return JSON response with the processed text.
568+
"""
569+
# Check permissions first
570+
self.get_object()
571+
572+
serializer = serializers.AITransformSerializer(data=request.data)
573+
serializer.is_valid(raise_exception=True)
574+
575+
text = serializer.validated_data["text"]
576+
action = serializer.validated_data["action"]
577+
578+
response = AIService().transform(text, action)
579+
580+
return drf_response.Response(response, status=status.HTTP_200_OK)
581+
582+
@decorators.action(
583+
detail=True,
584+
methods=["post"],
585+
name="Translate a piece of text with AI",
586+
serializer_class=serializers.AITranslateSerializer,
587+
url_path="ai-translate",
588+
throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle],
589+
)
590+
def ai_translate(self, request, *args, **kwargs):
591+
"""
592+
POST /api/v1.0/documents/<resource_id>/ai-translate
593+
with expected data:
594+
- text: str
595+
- language: str [settings.LANGUAGES]
596+
Return JSON response with the translated text.
597+
"""
598+
# Check permissions first
599+
self.get_object()
600+
601+
serializer = self.get_serializer(data=request.data)
602+
serializer.is_valid(raise_exception=True)
603+
604+
text = serializer.validated_data["text"]
605+
language = serializer.validated_data["language"]
606+
607+
response = AIService().translate(text, language)
608+
609+
return drf_response.Response(response, status=status.HTTP_200_OK)
610+
540611

541612
class DocumentAccessViewSet(
542613
ResourceAccessViewsetMixin,

src/backend/core/enums.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,11 @@
22
Core application enums declaration
33
"""
44

5-
from django.conf import global_settings, settings
5+
from django.conf import global_settings
66
from django.utils.translation import gettext_lazy as _
77

8-
# Django sets `LANGUAGES` by default with all supported languages. We can use it for
9-
# the choice of languages which should not be limited to the few languages active in
10-
# the app.
8+
# In Django's code base, `LANGUAGES` is set by default with all supported languages.
9+
# We can use it for the choice of languages which should not be limited to the few languages
10+
# active in the app.
1111
# pylint: disable=no-member
12-
ALL_LANGUAGES = getattr(
13-
settings,
14-
"ALL_LANGUAGES",
15-
[(language, _(name)) for language, name in global_settings.LANGUAGES],
16-
)
12+
ALL_LANGUAGES = {language: _(name) for language, name in global_settings.LANGUAGES}

src/backend/core/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,8 @@ def get_abilities(self, user):
508508
can_get = bool(roles)
509509

510510
return {
511+
"ai_transform": is_owner_or_admin or is_editor,
512+
"ai_translate": is_owner_or_admin or is_editor,
511513
"attachment_upload": is_owner_or_admin or is_editor,
512514
"destroy": RoleChoices.OWNER in roles,
513515
"link_configuration": is_owner_or_admin,

src/backend/core/services/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)