Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion chatbotcore/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ class QdrantDatabase:

def __post_init__(self):
"""Initialize database client"""
self.db_client = QdrantClient(host=self.host, port=self.port)
try:
self.db_client = QdrantClient(host=self.host, port=self.port)
except Exception:
logging.error("Database connection failed", exc_info=True)

def _collection_exists(self, collection_name: str) -> bool:
"""Check if the collection in db already exists"""
Expand Down
84 changes: 55 additions & 29 deletions content/mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@
from asgiref.sync import sync_to_async
from strawberry.file_uploads import Upload

from content.models import Content
from content.serializers import (
ArchiveContentSerializer,
ContentSerializer,
RetriggerContentSerializer,
TagSerializer,
UpdateContentSerializer,
)
from content.types import ContentType, TagType
from main.graphql.context import Info
from utils.strawberry.mutations import (
ModelMutation,
MutationEmptyResponseType,
MutationResponseType,
convert_serializer_to_type,
mutation_is_not_valid,
process_input_data,
)
Expand All @@ -28,9 +28,10 @@ class FolderInput:


CreateContentMutation = ModelMutation("Content", ContentSerializer)
UpdateMutation = ModelMutation("UpdateContent", UpdateContentSerializer)
DeleteContent = ModelMutation("archive", ArchiveContentSerializer)
CreateTagMutation = ModelMutation("CreateTag", TagSerializer)
UpdateContentTitleInput = convert_serializer_to_type(UpdateContentSerializer, name="UpdateContentTitleInput")
RetriggerContentInput = convert_serializer_to_type(RetriggerContentSerializer, name="RetriggerContentInput")
ArchiveContentInput = convert_serializer_to_type(ArchiveContentSerializer, name="ArchiveContentInput")


@strawberry.type
Expand All @@ -48,43 +49,68 @@ def read_file(self, file: Upload) -> str:
return file.read().decode("utf-8")

@strawberry.mutation
async def update_content_title(
@sync_to_async
def update_content_title(
self,
id: strawberry.ID,
data: UpdateMutation.PartialInputType, # type: ignore[reportInvalidTypeForm]
data: UpdateContentTitleInput, # type: ignore[reportInvalidTypeForm]
info: Info,
) -> MutationResponseType[ContentType]:
try:
instance = await Content.objects.aget(id=id)
except Content.DoesNotExist:
return MutationResponseType(ok=False, errors=["Content not found"])
serializer = UpdateContentSerializer(
instance, data=process_input_data(data), context={"request": info.context.request}, partial=True
instance=info.context.request.user,
data=process_input_data(data),
context={"request": info.context.request},
)
if errors := mutation_is_not_valid(serializer):
return MutationResponseType(ok=False, errors=errors)
await sync_to_async(serializer.save)()
return MutationResponseType()
return MutationResponseType(
ok=False,
errors=errors,
)
content = serializer.save() # type: ignore[reportReturnType]
return MutationResponseType(
result=content, # type: ignore[reportReturnType]
)

@strawberry.mutation
async def archive_content(
self, id: strawberry.ID, info: Info, data: DeleteContent.PartialInputType # type: ignore[reportInvalidTypeForm]
) -> MutationEmptyResponseType:
try:
instance = await Content.objects.aget(id=id)
except Content.DoesNotExist:
return MutationEmptyResponseType(ok=False, errors=["Content not found"])
@sync_to_async
def retrigger_content(
self,
data: RetriggerContentInput, # type: ignore[reportInvalidTypeForm]
info: Info,
) -> MutationResponseType[ContentType]:
serializer = RetriggerContentSerializer(
instance=info.context.request.user,
data=process_input_data(data),
context={"request": info.context.request},
)
if errors := mutation_is_not_valid(serializer):
return MutationResponseType(
ok=False,
errors=errors,
)
content = serializer.save() # type: ignore[reportReturnType]
return MutationResponseType(
result=content, # type: ignore[reportReturnType]
)

@strawberry.mutation
@sync_to_async
def archive_content(
self, data: ArchiveContentInput, info: Info # type: ignore[reportInvalidTypeForm]
) -> MutationResponseType[ContentType]:
serializer = ArchiveContentSerializer(
instance, data=process_input_data(data), context={"request": info.context.request}, partial=True
instance=info.context.request.user,
data=process_input_data(data),
context={"request": info.context.request},
)

if errors := mutation_is_not_valid(serializer):
return MutationEmptyResponseType(ok=False, errors=errors)

await sync_to_async(serializer.save)()

return MutationEmptyResponseType(ok=True)
return MutationResponseType(
ok=False,
errors=errors,
)
content = serializer.save() # type: ignore[reportReturnType]
return MutationResponseType(
result=content, # type: ignore[reportReturnType]
)

@strawberry.mutation
async def create_tag(
Expand Down
83 changes: 60 additions & 23 deletions content/serializers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from django.db import transaction
from django.utils import timezone
from rest_framework import serializers

from chat.models import UserChatSession
from content.models import Content, Tag
from content.tasks import (
create_embedding_for_content_task,
delete_content_from_qdrant_task,
)
from utils.file_check import validate_document_size, validate_file_type


Expand All @@ -19,7 +25,7 @@ class Meta:

class ContentSerializer(serializers.ModelSerializer):
tag = serializers.PrimaryKeyRelatedField(queryset=Tag.objects.all(), many=True, required=False)
document_file = serializers.FileField(required=False)
document_file = serializers.FileField(required=True)

class Meta:
model = Content
Expand All @@ -41,33 +47,64 @@ def create(self, validated_data):


class UpdateContentSerializer(serializers.ModelSerializer):
# NOTE: Update only the content title for now
content = serializers.PrimaryKeyRelatedField(queryset=Content.objects.all(), required=True)

class Meta:
model = Content
fields = ["title"]
read_only_fields = ["modified_by"]
fields = ["title", "content"]

def save(self, **_):
assert isinstance(self.validated_data, dict)
content = self.validated_data["content"]
content.title = self.validated_data["title"]
content.modified_by = self.context["request"].user
content.save(update_fields=["title", "modified_by"])
return content

def update(self, instance, validated_data):
validated_data["modified_by"] = self.context["request"].user

return super().update(instance, validated_data)
class ArchiveContentSerializer(serializers.ModelSerializer):
"""NOTE: Update the document status to DELETED_FROM_VECTOR in content model and delete the content from qdrant db"""

content = serializers.PrimaryKeyRelatedField(queryset=Content.objects.all(), required=True)

def validate(self, attrs):
content = attrs["content"]
if content.document_status == Content.DocumentStatus.DELETED_FROM_VECTOR:
raise serializers.ValidationError("Content is already deleted from vector.")
return attrs

class ArchiveContentSerializer(serializers.ModelSerializer):
class Meta:
model = Content
fields = [
"is_deleted",
]
read_only_fields = ["deleted_at", "deleted_by"]

def validate(self, data):
instance = self.instance
if instance.is_deleted:
raise serializers.ValidationError("Content is already deleted.")
return data

def update(self, instance, validated_data):
instance.is_deleted = True
instance.deleted_by = self.context["request"].user
instance.save(update_fields=("is_deleted", "deleted_by"))
return instance
fields = ["content"]

def save(self, **_):
assert isinstance(self.validated_data, dict)
content = self.validated_data["content"]
content.document_status = Content.DocumentStatus.DELETED_FROM_VECTOR
content.deleted_by = self.context["request"].user
content.deleted_at = timezone.now()
content.is_deleted = True
content.save(update_fields=["document_status", "deleted_by", "deleted_at", "is_deleted"])
transaction.on_commit(lambda: delete_content_from_qdrant_task.delay(content.content_id))
return content


class RetriggerContentSerializer(serializers.ModelSerializer):
content = serializers.PrimaryKeyRelatedField(queryset=Content.objects.all(), required=True)

def validate(self, attrs):
content = attrs["content"]
if content.document_status == Content.DocumentStatus.ADDED_TO_VECTOR:
raise serializers.ValidationError("Content has already been added to vector. No need to trigger it again.")
return attrs

class Meta:
model = Content
fields = ["content"]

def save(self, **_):
assert isinstance(self.validated_data, dict)
content = self.validated_data["content"]
transaction.on_commit(lambda: create_embedding_for_content_task.delay(content.id))
return content
24 changes: 20 additions & 4 deletions content/tasks.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import logging

import requests
from celery import shared_task
from django.conf import settings

from chatbotcore.database import QdrantDatabase
from chatbotcore.doc_loaders import LoaderFromText

logger = logging.getLogger(__name__)


@shared_task(blind=True)
def create_embedding_for_content_task(content_id):
@shared_task(bind=True)
def create_embedding_for_content_task(self, content_id):
from content.models import Content

content = Content.objects.get(id=content_id)
Expand All @@ -27,13 +31,25 @@ def create_embedding_for_content_task(content_id):
{"source": "plain-text", "page_content": split_docs[i].page_content, "uuid": content.content_id}
for i in range(len(split_docs))
]
if response.status_code == 200:
try:
db = QdrantDatabase(
host=settings.QDRANT_DB_HOST, port=settings.QDRANT_DB_PORT, collection_name=settings.QDRANT_DB_COLLECTION_NAME
)
db.set_collection()
db.store_data(zip(response.json(), metadata))
content.document_status = Content.DocumentStatus.ADDED_TO_VECTOR
else:

# NOTE: All exceptions have been handled with except
except Exception:
logger.error("An error occurred while creating embeddings", exc_info=True)
content.document_status = Content.DocumentStatus.FAILURE
content.save()


@shared_task(bind=True)
def delete_content_from_qdrant_task(self, content_id):
db = QdrantDatabase(
host=settings.QDRANT_DB_HOST, port=settings.QDRANT_DB_PORT, collection_name=settings.QDRANT_DB_COLLECTION_NAME
)
db.delete_data_by_src_uuid(key="uuid", value=str(content_id))
return logger.info(f"Deleted content {content_id}")
22 changes: 14 additions & 8 deletions schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ type AppEnumCollectionUserDepartment {
label: String!
}

input ArchiveContentInput {
id: String!
}

input BoolBaseFilterLookup {
"""Exact match. Filter will be skipped on `null` value"""
exact: Boolean
Expand Down Expand Up @@ -174,8 +178,9 @@ type PrivateMutation {
updateMe(data: UserMeInput!): UserMeTypeMutationResponseType!
createContent(data: ContentCreateInput!): ContentTypeMutationResponseType!
readFile(file: Upload!): String!
updateContentTitle(id: ID!, data: UpdateContentUpdateInput!): ContentTypeMutationResponseType!
archiveContent(id: ID!, data: archiveUpdateInput!): MutationEmptyResponseType!
updateContentTitle(data: UpdateContentTitleInput!): ContentTypeMutationResponseType!
retriggerContent(data: RetriggerContentInput!): ContentTypeMutationResponseType!
archiveContent(data: ArchiveContentInput!): ContentTypeMutationResponseType!
createTag(data: CreateTagCreateInput!): TagTypeMutationResponseType!
addOrganization(data: AddOrganizationInputType!): OrganizationTypeMutationResponseType!
updateOrganization(data: UpdateOrganizationInputType!): OrganizationTypeMutationResponseType!
Expand Down Expand Up @@ -242,6 +247,10 @@ input ResetUserPassword {
email: String!
}

input RetriggerContentInput {
id: String!
}

input StrFilterLookup {
"""Exact match. Filter will be skipped on `null` value"""
exact: String
Expand Down Expand Up @@ -322,8 +331,9 @@ type TagTypeMutationResponseType {
result: TagType
}

input UpdateContentUpdateInput {
title: String
input UpdateContentTitleInput {
title: String!
content: ID!
}

input UpdateOrganizationInputType {
Expand Down Expand Up @@ -409,8 +419,4 @@ type UserTypeMutationResponseType {
ok: Boolean!
errors: CustomErrorType
result: UserType
}

input archiveUpdateInput {
isDeleted: Boolean
}
Loading