
Commit d52a4ea

Fix data class caches
1 parent 5194887 commit d52a4ea

File tree

19 files changed: +83 -129 lines changed

llmstack/data/apis.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 from rest_framework.response import Response as DRFResponse
 
 from llmstack.base.models import VectorstoreEmbeddingEndpoint
-from llmstack.data.schemas import DataDocument
+from llmstack.data.sources.base import DataDocument
 from llmstack.data.yaml_loader import (
     get_data_pipeline_template_by_slug,
     get_data_pipelines_from_contrib,
llmstack/data/destinations/__init__.py

Lines changed: 14 additions & 0 deletions

@@ -1,5 +1,19 @@
+from functools import cache
+
 from llmstack.data.destinations.stores.singlestore import SingleStore
+from llmstack.data.destinations.vector_stores.chromadb import ChromaDB
 from llmstack.data.destinations.vector_stores.pinecone import Pinecone
+from llmstack.data.destinations.vector_stores.qdrant import Qdrant
+from llmstack.data.destinations.vector_stores.vector_store import PromptlyVectorStore
 from llmstack.data.destinations.vector_stores.weaviate import Weaviate
 
+
+@cache
+def get_destination_cls(slug, provider_slug):
+    for cls in [ChromaDB, Weaviate, SingleStore, Pinecone, Qdrant, PromptlyVectorStore]:
+        if cls.slug() == slug and cls.provider_slug() == provider_slug:
+            return cls
+    return None
+
+
 __all__ = ["SingleStore", "Pinecone", "Weaviate"]

llmstack/data/destinations/base.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
     CustomGenerateJsonSchema,
     get_ui_schema_from_json_schema,
 )
-from llmstack.data.schemas import DataDocument
+from llmstack.data.sources.base import DataDocument
 
 
 class BaseDestination(BaseModel):

llmstack/data/destinations/utils.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

llmstack/data/destinations/vector_stores/weaviate.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 from weaviate.connect.helpers import connect_to_custom, connect_to_wcs
 
 from llmstack.data.destinations.base import BaseDestination
-from llmstack.data.schemas import DataDocument
+from llmstack.data.sources.base import DataDocument
 from llmstack.processors.providers.weaviate import (
     APIKey,
     WeaviateCloudInstance,

llmstack/data/models.py

Lines changed: 2 additions & 6 deletions
@@ -9,6 +9,8 @@
 
 from llmstack.assets.models import Assets
 from llmstack.base.models import Profile
+from llmstack.data.pipeline import DataIngestionPipeline, DataQueryPipeline
+from llmstack.data.schemas import PipelineBlock
 from llmstack.events.apis import EventsViewSet
 
 logger = logging.getLogger(__name__)
@@ -122,8 +124,6 @@ def type_slug(self):
 
     @property
     def pipeline_obj(self):
-        from llmstack.data.schemas import PipelineBlock
-
         if self.config.get("pipeline"):
            return PipelineBlock(**self.config.get("pipeline"))
 
@@ -147,13 +147,9 @@ def pipeline(self):
        return self.config.get("pipeline", {})
 
    def create_data_ingestion_pipeline(self):
-        from llmstack.data.pipeline import DataIngestionPipeline
-
        return DataIngestionPipeline(self)
 
    def create_data_query_pipeline(self):
-        from llmstack.data.pipeline import DataQueryPipeline
-
        return DataQueryPipeline(self)
 
 
llmstack/data/pipeline.py

Lines changed: 4 additions & 26 deletions
@@ -5,8 +5,7 @@
 from llama_index.core.schema import Document as LlamaDocument
 
 from llmstack.common.blocks.data.store.vectorstore import Document
-from llmstack.data.models import DataSource
-from llmstack.data.schemas import DataDocument
+from llmstack.data.sources.base import DataDocument
 
 logger = logging.getLogger(__name__)
 
@@ -18,17 +17,15 @@ class LlamaDocumentShim(LlamaDocument):
 
 
 class DataIngestionPipeline:
-    def __init__(self, datasource: DataSource):
+    def __init__(self, datasource):
         self.datasource = datasource
         self._source_cls = self.datasource.pipeline_obj.source_cls
         self._destination_cls = self.datasource.pipeline_obj.destination_cls
-        logger.debug("Initializing DataIngestionPipeline")
 
         self._destination = None
         self._transformations = self.datasource.pipeline_obj.transformation_objs
         embedding_cls = self.datasource.pipeline_obj.embedding_cls
         if embedding_cls:
-            logger.debug("Initializing DataIngestionPipeline Transformation")
             embedding_additional_kwargs = {
                 **self.datasource.pipeline_obj.embedding.data.get("additional_kwargs", {}),
                 **{"datasource": datasource},
@@ -41,29 +38,21 @@ def __init__(self, datasource: DataSource):
                    }
                )
            )
-            logger.debug("Finished Initializing DataIngestionPipeline Transformation")
 
         if self._destination_cls:
-            logger.debug("Initializing DataIngestionPipeline Destination")
             self._destination = self._destination_cls(**self.datasource.pipeline_obj.destination_data)
             self._destination.initialize_client(datasource=self.datasource, create_collection=True)
-            logger.debug("Finished Initializing DataIngestionPipeline Destination")
 
     def process(self, document: DataDocument) -> DataDocument:
-        logger.debug(f"Processing document: {document.name}")
         document = self._source_cls.process_document(document)
-        logger.debug(f"Creating IngestionPipeline for document: {document.name}")
         ingestion_pipeline = IngestionPipeline(transformations=self._transformations)
         ldoc = LlamaDocumentShim(**document.model_dump())
         ldoc.metadata = {**ldoc.metadata, **document.metadata}
-        logger.debug(f"Running IngestionPipeline for document: {document.name}")
         document.nodes = ingestion_pipeline.run(documents=[ldoc])
-        logger.debug(f"Finished running IngestionPipeline for document: {document.name}")
         document.node_ids = list(map(lambda x: x.id_, document.nodes))
+
         if self._destination:
-            logger.debug(f"Adding document: {document.name} to destination")
             self._destination.add(document=document)
-            logger.debug(f"Finished adding document: {document.name} to destination")
 
         return document
 
@@ -80,55 +69,44 @@ def delete_all_entries(self) -> None:
 
 
 class DataQueryPipeline:
-    def __init__(self, datasource: DataSource):
+    def __init__(self, datasource):
         self.datasource = datasource
         self._destination_cls = self.datasource.pipeline_obj.destination_cls
         self._destination = None
         self._embedding_generator = None
-        logger.debug("Initializing DataQueryPipeline")
 
         if self._destination_cls:
-            logger.debug("Initializing DataQueryPipeline Destination")
             self._destination = self._destination_cls(**self.datasource.pipeline_obj.destination_data)
             self._destination.initialize_client(datasource=self.datasource, create_collection=False)
-            logger.debug("Finished Initializing DataQueryPipeline Destination")
 
         if self.datasource.pipeline_obj.embedding:
-            logger.debug("Initializing DataQueryPipeline Embedding")
             embedding_data = self.datasource.pipeline_obj.embedding.data
             embedding_data["additional_kwargs"] = {
                 **embedding_data.get("additional_kwargs", {}),
                 **{"datasource": self.datasource},
             }
             self._embedding_generator = self.datasource.pipeline_obj.embedding_cls(**embedding_data)
-            logger.debug("Finished Initializing DataQueryPipeline Embedding")
 
     def search(self, query: str, use_hybrid_search=True, **kwargs) -> List[dict]:
         content_key = self.datasource.destination_text_content_key
         query_embedding = None
 
-        logger.debug(f"Initializing Search for query: {query}")
-
         if kwargs.get("search_filters", None):
             raise NotImplementedError("Search filters are not supported for this data source.")
 
         documents = []
 
         if self._embedding_generator:
-            logger.debug("Generating embedding for query")
             query_embedding = self._embedding_generator.get_embedding(query)
-            logger.debug("Finished generating embedding for query")
 
         if self._destination:
-            logger.debug(f"Searching for query: {query} in destination")
             query_result = self._destination.search(
                 query=query,
                 use_hybrid_search=use_hybrid_search,
                 query_embedding=query_embedding,
                 datasource_uuid=str(self.datasource.uuid),
                 **kwargs,
             )
-            logger.debug(f"Received results for query: {query} from destination")
             documents = list(
                 map(
                     lambda x: Document(page_content_key=content_key, page_content=x.text, metadata=x.metadata),
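
Dropping the datasource: DataSource annotations is what lets pipeline.py stop importing llmstack.data.models; since models.py now imports DataIngestionPipeline and DataQueryPipeline at module level, keeping that import would have completed an import cycle. If the annotations were worth preserving, one common alternative (not what this commit does) is a TYPE_CHECKING guard:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime,
    # so it cannot participate in an import cycle.
    from llmstack.data.models import DataSource


class DataIngestionPipeline:
    def __init__(self, datasource: "DataSource"):
        self.datasource = datasource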

llmstack/data/schemas.py

Lines changed: 6 additions & 28 deletions
@@ -1,7 +1,10 @@
-import uuid
-from typing import Any, List, Optional
+from typing import List, Optional
 
-from pydantic import BaseModel, Field, PrivateAttr
+from pydantic import BaseModel, PrivateAttr
+
+from llmstack.data.destinations import get_destination_cls
+from llmstack.data.sources import get_source_cls
+from llmstack.data.transformations import get_transformer_cls
 
 
 class BaseProcessorBlock(BaseModel):
@@ -36,26 +39,20 @@ def default_dict(self):
 
 class PipelineSource(BaseProcessorBlock):
     def __init__(self, **data):
-        from llmstack.data.sources.utils import get_source_cls
-
         super().__init__(**data)
 
         self._processor_cls = get_source_cls(slug=self.slug, provider_slug=self.provider_slug)
 
 
 class PipelineDestination(BaseProcessorBlock):
     def __init__(self, **data):
-        from llmstack.data.destinations.utils import get_destination_cls
-
         super().__init__(**data)
 
         self._processor_cls = get_destination_cls(slug=self.slug, provider_slug=self.provider_slug)
 
 
 class PipelineTransformation(BaseProcessorBlock):
     def __init__(self, **data):
-        from llmstack.data.transformations.utils import get_transformer_cls
-
         super().__init__(**data)
 
         self._processor_cls = get_transformer_cls(slug=self.slug, provider_slug=self.provider_slug)
@@ -66,8 +63,6 @@ def get_default_data(self, **kwargs):
 
 class PipelineEmbedding(BaseProcessorBlock):
     def __init__(self, **data):
-        from llmstack.data.transformations.utils import get_transformer_cls
-
         super().__init__(**data)
 
         self._processor_cls = get_transformer_cls(slug=self.slug, provider_slug=self.provider_slug)
@@ -128,20 +123,3 @@ def default_dict(self):
             "destination": self.pipeline.destination.default_dict() if self.pipeline.destination else None,
         },
     }
-
-
-class DataDocument(BaseModel):
-    id_: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the document.")
-    name: Optional[str] = None
-    text: Optional[str] = None
-    text_objref: Optional[str] = None
-    content: Optional[str] = None
-    mimetype: str = Field(default="text/plain", description="MIME type of the content.")
-    metadata: Optional[dict] = None
-    extra_info: Optional[dict] = {}
-    nodes: Optional[List[Any]] = None
-    embeddings: Optional[List[float]] = None
-    processing_errors: Optional[List[str]] = None
-    datasource_uuid: Optional[str] = None
-    request_data: Optional[dict] = {}
-    node_ids: Optional[List[str]] = []
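
The in-method imports removed here were presumably working around a cycle: schemas.py imported from the destinations/sources packages, whose modules imported DataDocument back from schemas.py. With DataDocument relocated to sources/base.py, the lookups can be imported eagerly at the top of the module. A self-contained sketch of the resolution a pipeline block now performs at construction time; the toy registry, slug values, and TextSchema stand-in are illustrative:

from typing import Optional

from pydantic import BaseModel, PrivateAttr


class TextSchema:
    """Stand-in source class for the demo."""


# Toy registry; the real lookups scan class lists under @cache.
_SOURCES = {("text", "promptly"): TextSchema}


def get_source_cls(slug, provider_slug):
    return _SOURCES.get((slug, provider_slug))


class PipelineSource(BaseModel):
    slug: str
    provider_slug: str
    _processor_cls: Optional[type] = PrivateAttr(default=None)

    def __init__(self, **data):
        super().__init__(**data)
        # Module-level lookup replaces the old lazy in-method import.
        self._processor_cls = get_source_cls(slug=self.slug, provider_slug=self.provider_slug)


block = PipelineSource(slug="text", provider_slug="promptly")
assert block._processor_cls is TextSchema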

llmstack/data/sources/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -1,5 +1,18 @@
+from functools import cache
+
+from llmstack.data.sources.files.csv import CSVFileSchema
 from llmstack.data.sources.files.file import FileSchema
+from llmstack.data.sources.files.pdf import PdfSchema
 from llmstack.data.sources.text.text_data import TextSchema
 from llmstack.data.sources.website.url import URLSchema
 
 __all__ = ["FileSchema", "TextSchema", "URLSchema"]
+
+
+@cache
+def get_source_cls(slug, provider_slug):
+    for cls in [CSVFileSchema, FileSchema, PdfSchema, URLSchema, TextSchema]:
+        if cls.slug() == slug and cls.provider_slug() == provider_slug:
+            return cls
+
+    return None
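
One behavior of @cache worth noting for both new lookups: misses are memoized too, so a None result for a given (slug, provider_slug) pair sticks until the process restarts or cache_clear() is called. A toy illustration (not the real registry):

from functools import cache

registry = []


@cache
def lookup(slug):
    return next((entry for entry in registry if entry["slug"] == slug), None)


print(lookup("pdf"))   # None, and the miss itself is now cached
registry.append({"slug": "pdf"})
print(lookup("pdf"))   # still None: served from the cache, no rescan
lookup.cache_clear()
print(lookup("pdf"))   # {'slug': 'pdf'} after clearing the cache

That is harmless here because the class lists are fixed at import time.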

llmstack/data/sources/base.py

Lines changed: 20 additions & 3 deletions
@@ -1,12 +1,29 @@
-from typing import List
+import uuid
+from typing import Any, List, Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from llmstack.common.blocks.base.schema import (
     CustomGenerateJsonSchema,
     get_ui_schema_from_json_schema,
 )
-from llmstack.data.schemas import DataDocument
+
+
+class DataDocument(BaseModel):
+    id_: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the document.")
+    name: Optional[str] = None
+    text: Optional[str] = None
+    text_objref: Optional[str] = None
+    content: Optional[str] = None
+    mimetype: str = Field(default="text/plain", description="MIME type of the content.")
+    metadata: Optional[dict] = None
+    extra_info: Optional[dict] = {}
+    nodes: Optional[List[Any]] = None
+    embeddings: Optional[List[float]] = None
+    processing_errors: Optional[List[str]] = None
+    datasource_uuid: Optional[str] = None
+    request_data: Optional[dict] = {}
+    node_ids: Optional[List[str]] = []
 
 
 class BaseSource(BaseModel):
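
Code that previously did from llmstack.data.schemas import DataDocument now imports it from here, as the other files in this commit show. A small usage sketch, assuming the llmstack package is importable; field values are illustrative:

from llmstack.data.sources.base import DataDocument  # new canonical location

doc = DataDocument(name="notes.txt", text="hello world", metadata={"source": "upload"})
print(doc.id_)       # auto-generated uuid4 string
print(doc.mimetype)  # "text/plain" by default
print(doc.node_ids)  # [] until an ingestion pipeline populates it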
