5
5
PipelineFileCreate ,
6
6
ProjectCreate ,
7
7
CompositeRetrievalMode ,
8
+ LlamaParseParameters ,
8
9
)
9
10
from llama_index .indices .managed .llama_cloud import (
10
11
LlamaCloudIndex ,
11
12
LlamaCloudCompositeRetriever ,
12
13
)
13
14
from llama_index .embeddings .openai import OpenAIEmbedding
14
- from llama_index .core .schema import Document
15
+ from llama_index .core .schema import Document , ImageNode
15
16
import os
16
17
import pytest
17
18
from uuid import uuid4
@@ -35,6 +36,7 @@ def remote_file() -> Tuple[str, str]:
35
36
36
37
def _setup_empty_index (
37
38
client : LlamaCloud ,
39
+ multi_modal_index : bool = False ,
38
40
) -> LlamaCloudIndex :
39
41
# create project if it doesn't exist
40
42
project_create = ProjectCreate (name = project_name )
@@ -47,6 +49,7 @@ def _setup_empty_index(
47
49
name = "test_empty_index_" + str (uuid4 ()),
48
50
embedding_config = {"type" : "OPENAI_EMBEDDING" , "component" : OpenAIEmbedding ()},
49
51
transform_config = AutoTransformConfig (),
52
+ llama_parse_parameters = LlamaParseParameters (take_screenshot = multi_modal_index ),
50
53
)
51
54
return client .pipelines .upsert_pipeline (
52
55
project_id = project .id , request = pipeline_create
@@ -279,6 +282,36 @@ def test_index_from_documents():
279
282
assert "3" not in docs
280
283
281
284
285
+ @pytest .mark .skipif (
286
+ not base_url or not api_key , reason = "No platform base url or api key set"
287
+ )
288
+ @pytest .mark .skipif (not openai_api_key , reason = "No openai api key set" )
289
+ @pytest .mark .integration ()
290
+ def test_image_retrieval () -> None :
291
+ pipeline = _setup_empty_index (
292
+ LlamaCloud (token = api_key , base_url = base_url ), multi_modal_index = True
293
+ )
294
+
295
+ index = LlamaCloudIndex (
296
+ name = pipeline .name ,
297
+ project_name = project_name ,
298
+ api_key = api_key ,
299
+ base_url = base_url ,
300
+ )
301
+
302
+ file_path = "llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/tests/data/Simple PDF Slides.pdf"
303
+ file_id = index .upload_file (file_path , wait_for_ingestion = True )
304
+
305
+ retriever = index .as_retriever (retrieve_image_nodes = True )
306
+ nodes = retriever .retrieve ("1" )
307
+ assert len (nodes ) > 0
308
+
309
+ image_nodes = [n .node for n in nodes if isinstance (n .node , ImageNode )]
310
+ assert len (image_nodes ) > 0
311
+ assert all (n .metadata ["file_id" ] == file_id for n in image_nodes )
312
+ assert all (n .metadata ["page_index" ] >= 0 for n in image_nodes )
313
+
314
+
282
315
@pytest .mark .skipif (
283
316
not base_url or not api_key , reason = "No platform base url or api key set"
284
317
)
0 commit comments