
Commit 3c39596 (1 parent: 43545ca)
Authored by Florian-BACHO (Florian Bacho)

Add GoogleGenAI FileAPI support for large files (#19853)

* Add GoogleGenAI FileAPI support for large files
* Bump version
* Update dependencies
* Bump version to 0.4.0
* Bump version
* Bump version to 0.5.0

---------

Co-authored-by: Florian Bacho <[email protected]>

5 files changed: +247 additions, −49 deletions

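For context, a minimal usage sketch of the new flag. Only `use_file_api` comes from this commit; the model name, file path, and block types below are illustrative and assume a recent llama-index-core that exposes content blocks.

from llama_index.core.llms import ChatMessage, ImageBlock, TextBlock
from llama_index.llms.google_genai import GoogleGenAI

# use_file_api defaults to True: message parts larger than ~20MB are uploaded
# through the Gemini File API instead of being inlined, and the uploads are
# deleted again once the response has been received.
llm = GoogleGenAI(model="gemini-2.0-flash", use_file_api=True)

msg = ChatMessage(
    role="user",
    blocks=[
        TextBlock(text="Summarize what is in this scan."),
        ImageBlock(path="large_scan.png"),  # hypothetical local file
    ],
)

print(llm.chat([msg]).message.content)

Passing `use_file_api=False` opts out and keeps all parts inline in the request, as before.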

llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py

Lines changed: 106 additions & 13 deletions
@@ -1,5 +1,6 @@
 """Google's hosted Gemini API."""
 
+import asyncio
 import functools
 import os
 import typing
@@ -52,6 +53,7 @@
     prepare_chat_params,
     handle_streaming_flexible_model,
     create_retry_decorator,
+    delete_uploaded_files,
 )
 
 import google.genai
@@ -135,6 +137,10 @@ class GoogleGenAI(FunctionCallingLLM):
         default=None,
         description="Google GenAI tool to use for the model to augment responses.",
     )
+    use_file_api: bool = Field(
+        default=True,
+        description="Whether or not to use the FileAPI for large files (>20MB).",
+    )
 
     _max_tokens: int = PrivateAttr()
     _client: google.genai.Client = PrivateAttr()
@@ -157,6 +163,7 @@ def __init__(
         is_function_calling_model: bool = True,
         cached_content: Optional[str] = None,
         built_in_tool: Optional[types.Tool] = None,
+        use_file_api: bool = True,
         **kwargs: Any,
     ):
         # API keys are optional. The API can be authorised via OAuth (detected
@@ -205,6 +212,7 @@ def __init__(
             max_retries=max_retries,
             cached_content=cached_content,
             built_in_tool=built_in_tool,
+            use_file_api=use_file_api,
             **kwargs,
         )
 
@@ -297,9 +305,21 @@ def _chat(self, messages: Sequence[ChatMessage], **kwargs: Any):
             **kwargs.pop("generation_config", {}),
         }
         params = {**kwargs, "generation_config": generation_config}
-        next_msg, chat_kwargs = prepare_chat_params(self.model, messages, **params)
+        next_msg, chat_kwargs = asyncio.run(
+            prepare_chat_params(
+                self.model, messages, self.use_file_api, self._client, **params
+            )
+        )
         chat = self._client.chats.create(**chat_kwargs)
-        response = chat.send_message(next_msg.parts)
+        response = chat.send_message(
+            next_msg.parts if isinstance(next_msg, types.Content) else next_msg
+        )
+
+        if self.use_file_api:
+            asyncio.run(
+                delete_uploaded_files([*chat_kwargs["history"], next_msg], self._client)
+            )
+
         return chat_from_gemini_response(response)
 
     @llm_retry_decorator
@@ -309,9 +329,19 @@ async def _achat(self, messages: Sequence[ChatMessage], **kwargs: Any):
             **kwargs.pop("generation_config", {}),
         }
         params = {**kwargs, "generation_config": generation_config}
-        next_msg, chat_kwargs = prepare_chat_params(self.model, messages, **params)
+        next_msg, chat_kwargs = await prepare_chat_params(
+            self.model, messages, self.use_file_api, self._client, **params
+        )
         chat = self._client.aio.chats.create(**chat_kwargs)
-        response = await chat.send_message(next_msg.parts)
+        response = await chat.send_message(
+            next_msg.parts if isinstance(next_msg, types.Content) else next_msg
+        )
+
+        if self.use_file_api:
+            await delete_uploaded_files(
+                [*chat_kwargs["history"], next_msg], self._client
+            )
+
         return chat_from_gemini_response(response)
 
     @llm_chat_callback()
@@ -332,9 +362,15 @@ def _stream_chat(
             **kwargs.pop("generation_config", {}),
         }
         params = {**kwargs, "generation_config": generation_config}
-        next_msg, chat_kwargs = prepare_chat_params(self.model, messages, **params)
+        next_msg, chat_kwargs = asyncio.run(
+            prepare_chat_params(
+                self.model, messages, self.use_file_api, self._client, **params
+            )
+        )
         chat = self._client.chats.create(**chat_kwargs)
-        response = chat.send_message_stream(next_msg.parts)
+        response = chat.send_message_stream(
+            next_msg.parts if isinstance(next_msg, types.Content) else next_msg
+        )
 
         def gen() -> ChatResponseGen:
             content = ""
@@ -361,6 +397,13 @@ def gen() -> ChatResponseGen:
                 llama_resp.message.additional_kwargs["tool_calls"] = existing_tool_calls
                 yield llama_resp
 
+            if self.use_file_api:
+                asyncio.run(
+                    delete_uploaded_files(
+                        [*chat_kwargs["history"], next_msg], self._client
+                    )
+                )
+
         return gen()
 
     @llm_chat_callback()
@@ -377,14 +420,18 @@ async def _astream_chat(
             **kwargs.pop("generation_config", {}),
         }
         params = {**kwargs, "generation_config": generation_config}
-        next_msg, chat_kwargs = prepare_chat_params(self.model, messages, **params)
+        next_msg, chat_kwargs = await prepare_chat_params(
+            self.model, messages, self.use_file_api, self._client, **params
+        )
         chat = self._client.aio.chats.create(**chat_kwargs)
 
         async def gen() -> ChatResponseAsyncGen:
             content = ""
             existing_tool_calls = []
             thoughts = ""
-            async for r in await chat.send_message_stream(next_msg.parts):
+            async for r in await chat.send_message_stream(
+                next_msg.parts if isinstance(next_msg, types.Content) else next_msg
+            ):
                 if candidates := r.candidates:
                     if not candidates:
                         continue
@@ -412,6 +459,11 @@ async def gen() -> ChatResponseAsyncGen:
                     )
                     yield llama_resp
 
+            if self.use_file_api:
+                await delete_uploaded_files(
+                    [*chat_kwargs["history"], next_msg], self._client
+                )
+
         return gen()
 
     @llm_chat_callback()
@@ -529,9 +581,15 @@ def structured_predict_without_function_calling(
         llm_kwargs = llm_kwargs or {}
 
         messages = prompt.format_messages(**prompt_args)
+        contents = [
+            asyncio.run(
+                chat_message_to_gemini(message, self.use_file_api, self._client)
+            )
+            for message in messages
+        ]
         response = self._client.models.generate_content(
             model=self.model,
-            contents=list(map(chat_message_to_gemini, messages)),
+            contents=contents,
             **{
                 **llm_kwargs,
                 **{
@@ -543,6 +601,9 @@ def structured_predict_without_function_calling(
             },
         )
 
+        if self.use_file_api:
+            asyncio.run(delete_uploaded_files(contents, self._client))
+
         if isinstance(response.parsed, BaseModel):
             return response.parsed
         else:
@@ -570,13 +631,21 @@ def structured_predict(
         generation_config["response_schema"] = output_cls
 
         messages = prompt.format_messages(**prompt_args)
-        contents = list(map(chat_message_to_gemini, messages))
+        contents = [
+            asyncio.run(
+                chat_message_to_gemini(message, self.use_file_api, self._client)
+            )
+            for message in messages
+        ]
         response = self._client.models.generate_content(
             model=self.model,
             contents=contents,
             config=generation_config,
         )
 
+        if self.use_file_api:
+            asyncio.run(delete_uploaded_files(contents, self._client))
+
         if isinstance(response.parsed, BaseModel):
             return response.parsed
         else:
@@ -609,13 +678,21 @@ async def astructured_predict(
         generation_config["response_schema"] = output_cls
 
         messages = prompt.format_messages(**prompt_args)
-        contents = list(map(chat_message_to_gemini, messages))
+        contents = await asyncio.gather(
+            *[
+                chat_message_to_gemini(message, self.use_file_api, self._client)
+                for message in messages
+            ]
+        )
         response = await self._client.aio.models.generate_content(
             model=self.model,
             contents=contents,
             config=generation_config,
         )
 
+        if self.use_file_api:
+            await delete_uploaded_files(contents, self._client)
+
         if isinstance(response.parsed, BaseModel):
             return response.parsed
         else:
@@ -648,7 +725,12 @@ def stream_structured_predict(
             generation_config["response_schema"] = output_cls
 
             messages = prompt.format_messages(**prompt_args)
-            contents = list(map(chat_message_to_gemini, messages))
+            contents = [
+                asyncio.run(
+                    chat_message_to_gemini(message, self.use_file_api, self._client)
+                )
+                for message in messages
+            ]
 
             def gen() -> Generator[Union[Model, FlexibleModel], None, None]:
                 flexible_model = create_flexible_model(output_cls)
@@ -672,6 +754,9 @@ def gen() -> Generator[Union[Model, FlexibleModel], None, None]:
                     if streaming_model:
                         yield streaming_model
 
+                if self.use_file_api:
+                    asyncio.run(delete_uploaded_files(contents, self._client))
+
             return gen()
         else:
            return super().stream_structured_predict(
@@ -700,7 +785,12 @@ async def astream_structured_predict(
            generation_config["response_schema"] = output_cls
 
            messages = prompt.format_messages(**prompt_args)
-            contents = list(map(chat_message_to_gemini, messages))
+            contents = await asyncio.gather(
+                *[
+                    chat_message_to_gemini(message, self.use_file_api, self._client)
+                    for message in messages
+                ]
+            )
 
            async def gen() -> AsyncGenerator[Union[Model, FlexibleModel], None]:
                flexible_model = create_flexible_model(output_cls)
@@ -724,6 +814,9 @@ async def gen() -> AsyncGenerator[Union[Model, FlexibleModel], None]:
                    if streaming_model:
                        yield streaming_model
 
+                if self.use_file_api:
+                    await delete_uploaded_files(contents, self._client)
+
            return gen()
        else:
            return await super().astream_structured_predict(
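A note on the pattern used in the synchronous paths above: `prepare_chat_params` and `chat_message_to_gemini` are now coroutines (uploading via the File API is asynchronous), so the sync methods bridge to them with `asyncio.run` and delete the uploaded files once the response is in. A stand-alone sketch of that bridge follows, with hypothetical helpers (`_upload_large_parts`, `_delete_uploads`) in place of the real ones.

import asyncio

async def _upload_large_parts(parts: list[str]) -> list[str]:
    # Hypothetical stand-in for prepare_chat_params: pretend anything >20MB
    # was pushed to the File API and replaced by an upload handle.
    return [f"uploaded::{p}" for p in parts]

async def _delete_uploads(handles: list[str]) -> None:
    # Hypothetical stand-in for delete_uploaded_files.
    print(f"deleting {len(handles)} uploaded file(s)")

def send_sync(parts: list[str]) -> list[str]:
    # Mirrors _chat(): bridge into the async upload helper from sync code,
    # use the result, then delete the uploads after the response is received.
    handles = asyncio.run(_upload_large_parts(parts))
    response = handles  # stand-in for chat.send_message(...)
    asyncio.run(_delete_uploads(handles))
    return response

As in the diff, `asyncio.run` assumes no event loop is already running in the calling thread; the async code paths simply `await` the helpers instead.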
