@@ -1,5 +1,6 @@
 """Google's hosted Gemini API."""
 
+import asyncio
 import functools
 import os
 import typing
@@ -52,6 +53,7 @@
     prepare_chat_params,
     handle_streaming_flexible_model,
     create_retry_decorator,
+    delete_uploaded_files,
 )
 
 import google.genai
@@ -135,6 +137,10 @@ class GoogleGenAI(FunctionCallingLLM):
         default=None,
         description="Google GenAI tool to use for the model to augment responses.",
     )
+    use_file_api: bool = Field(
+        default=True,
+        description="Whether or not to use the FileAPI for large files (>20MB).",
+    )
 
     _max_tokens: int = PrivateAttr()
     _client: google.genai.Client = PrivateAttr()
@@ -157,6 +163,7 @@ def __init__(
         is_function_calling_model: bool = True,
         cached_content: Optional[str] = None,
         built_in_tool: Optional[types.Tool] = None,
+        use_file_api: bool = True,
         **kwargs: Any,
     ):
         # API keys are optional. The API can be authorised via OAuth (detected
@@ -205,6 +212,7 @@ def __init__(
             max_retries=max_retries,
             cached_content=cached_content,
             built_in_tool=built_in_tool,
+            use_file_api=use_file_api,
             **kwargs,
         )
 
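The `use_file_api` flag added above defaults to `True` and is simply forwarded from the constructor into the Pydantic field. A minimal usage sketch, assuming the llama-index core message/block types; the model name, file path, and import paths below are placeholders and assumptions, not part of this diff:

```python
# Illustrative sketch only; model name, file path, and block imports are assumptions.
from llama_index.core.llms import ChatMessage, ImageBlock, TextBlock
from llama_index.llms.google_genai import GoogleGenAI

llm = GoogleGenAI(
    model="gemini-2.0-flash",  # placeholder model name
    use_file_api=True,         # route large attachments (>20MB) through the File API
)

message = ChatMessage(
    role="user",
    blocks=[
        TextBlock(text="Describe this image."),
        ImageBlock(path="big_scan.png"),  # placeholder path; large files get uploaded
    ],
)
print(llm.chat([message]).message.content)
```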
@@ -297,9 +305,21 @@ def _chat(self, messages: Sequence[ChatMessage], **kwargs: Any):
             **kwargs.pop("generation_config", {}),
         }
         params = {**kwargs, "generation_config": generation_config}
-        next_msg, chat_kwargs = prepare_chat_params(self.model, messages, **params)
+        next_msg, chat_kwargs = asyncio.run(
+            prepare_chat_params(
+                self.model, messages, self.use_file_api, self._client, **params
+            )
+        )
         chat = self._client.chats.create(**chat_kwargs)
-        response = chat.send_message(next_msg.parts)
+        response = chat.send_message(
+            next_msg.parts if isinstance(next_msg, types.Content) else next_msg
+        )
+
+        if self.use_file_api:
+            asyncio.run(
+                delete_uploaded_files([*chat_kwargs["history"], next_msg], self._client)
+            )
+
         return chat_from_gemini_response(response)
 
     @llm_retry_decorator
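For orientation, here is a rough sketch of the upload/cleanup cycle that `prepare_chat_params` and `delete_uploaded_files` are being wired into. This is not the integration's actual helper code, only an illustration built on the `google-genai` Files API (`client.aio.files.upload` / `client.aio.files.delete`); the helper names and threshold handling are assumptions:

```python
# Hypothetical illustration, not the real prepare_chat_params/delete_uploaded_files.
import asyncio
import os
from typing import List, Optional, Tuple

from google import genai
from google.genai import types

TWENTY_MB = 20 * 1024 * 1024  # threshold mentioned in the field description


async def to_part(
    client: genai.Client, path: str, mime_type: str
) -> Tuple[types.Part, Optional[types.File]]:
    """Inline small payloads; push large ones through the File API."""
    if os.path.getsize(path) <= TWENTY_MB:
        with open(path, "rb") as f:
            return types.Part.from_bytes(data=f.read(), mime_type=mime_type), None
    uploaded = await client.aio.files.upload(file=path)
    return types.Part.from_uri(file_uri=uploaded.uri, mime_type=mime_type), uploaded


async def cleanup(client: genai.Client, uploaded: List[types.File]) -> None:
    """Delete File API uploads once the response has been received."""
    await asyncio.gather(*(client.aio.files.delete(name=f.name) for f in uploaded))
```

In the synchronous code paths the diff wraps these awaitables in `asyncio.run`, while the async paths await them directly.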
@@ -309,9 +329,19 @@ async def _achat(self, messages: Sequence[ChatMessage], **kwargs: Any):
             **kwargs.pop("generation_config", {}),
         }
         params = {**kwargs, "generation_config": generation_config}
-        next_msg, chat_kwargs = prepare_chat_params(self.model, messages, **params)
+        next_msg, chat_kwargs = await prepare_chat_params(
+            self.model, messages, self.use_file_api, self._client, **params
+        )
         chat = self._client.aio.chats.create(**chat_kwargs)
-        response = await chat.send_message(next_msg.parts)
+        response = await chat.send_message(
+            next_msg.parts if isinstance(next_msg, types.Content) else next_msg
+        )
+
+        if self.use_file_api:
+            await delete_uploaded_files(
+                [*chat_kwargs["history"], next_msg], self._client
+            )
+
         return chat_from_gemini_response(response)
 
     @llm_chat_callback()
@@ -332,9 +362,15 @@ def _stream_chat(
             **kwargs.pop("generation_config", {}),
         }
         params = {**kwargs, "generation_config": generation_config}
-        next_msg, chat_kwargs = prepare_chat_params(self.model, messages, **params)
+        next_msg, chat_kwargs = asyncio.run(
+            prepare_chat_params(
+                self.model, messages, self.use_file_api, self._client, **params
+            )
+        )
         chat = self._client.chats.create(**chat_kwargs)
-        response = chat.send_message_stream(next_msg.parts)
+        response = chat.send_message_stream(
+            next_msg.parts if isinstance(next_msg, types.Content) else next_msg
+        )
 
         def gen() -> ChatResponseGen:
             content = ""
@@ -361,6 +397,13 @@ def gen() -> ChatResponseGen:
                     llama_resp.message.additional_kwargs["tool_calls"] = existing_tool_calls
                     yield llama_resp
 
+            if self.use_file_api:
+                asyncio.run(
+                    delete_uploaded_files(
+                        [*chat_kwargs["history"], next_msg], self._client
+                    )
+                )
+
         return gen()
 
     @llm_chat_callback()
@@ -377,14 +420,18 @@ async def _astream_chat(
             **kwargs.pop("generation_config", {}),
         }
         params = {**kwargs, "generation_config": generation_config}
-        next_msg, chat_kwargs = prepare_chat_params(self.model, messages, **params)
+        next_msg, chat_kwargs = await prepare_chat_params(
+            self.model, messages, self.use_file_api, self._client, **params
+        )
         chat = self._client.aio.chats.create(**chat_kwargs)
 
         async def gen() -> ChatResponseAsyncGen:
             content = ""
             existing_tool_calls = []
             thoughts = ""
-            async for r in await chat.send_message_stream(next_msg.parts):
+            async for r in await chat.send_message_stream(
+                next_msg.parts if isinstance(next_msg, types.Content) else next_msg
+            ):
                 if candidates := r.candidates:
                     if not candidates:
                         continue
@@ -412,6 +459,11 @@ async def gen() -> ChatResponseAsyncGen:
                     )
                     yield llama_resp
 
+            if self.use_file_api:
+                await delete_uploaded_files(
+                    [*chat_kwargs["history"], next_msg], self._client
+                )
+
         return gen()
 
     @llm_chat_callback()
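In both streaming variants the `delete_uploaded_files` call sits after the `yield` loop, so cleanup only runs once the caller has drained the generator. A hedged usage sketch with a placeholder model and prompt:

```python
# Illustrative only; model name and prompt are placeholders.
from llama_index.core.llms import ChatMessage
from llama_index.llms.google_genai import GoogleGenAI

llm = GoogleGenAI(model="gemini-2.0-flash", use_file_api=True)
messages = [ChatMessage(role="user", content="Summarize the attached report.")]

# Consuming the stream to the end lets the generator reach the cleanup step above.
for chunk in llm.stream_chat(messages):
    print(chunk.delta or "", end="", flush=True)
```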
@@ -529,9 +581,15 @@ def structured_predict_without_function_calling(
         llm_kwargs = llm_kwargs or {}
 
         messages = prompt.format_messages(**prompt_args)
+        contents = [
+            asyncio.run(
+                chat_message_to_gemini(message, self.use_file_api, self._client)
+            )
+            for message in messages
+        ]
         response = self._client.models.generate_content(
             model=self.model,
-            contents=list(map(chat_message_to_gemini, messages)),
+            contents=contents,
             **{
                 **llm_kwargs,
                 **{
@@ -543,6 +601,9 @@ def structured_predict_without_function_calling(
             },
         )
 
+        if self.use_file_api:
+            asyncio.run(delete_uploaded_files(contents, self._client))
+
         if isinstance(response.parsed, BaseModel):
             return response.parsed
         else:
@@ -570,13 +631,21 @@ def structured_predict(
                 generation_config["response_schema"] = output_cls
 
             messages = prompt.format_messages(**prompt_args)
-            contents = list(map(chat_message_to_gemini, messages))
+            contents = [
+                asyncio.run(
+                    chat_message_to_gemini(message, self.use_file_api, self._client)
+                )
+                for message in messages
+            ]
             response = self._client.models.generate_content(
                 model=self.model,
                 contents=contents,
                 config=generation_config,
             )
 
+            if self.use_file_api:
+                asyncio.run(delete_uploaded_files(contents, self._client))
+
             if isinstance(response.parsed, BaseModel):
                 return response.parsed
             else:
@@ -609,13 +678,21 @@ async def astructured_predict(
                 generation_config["response_schema"] = output_cls
 
             messages = prompt.format_messages(**prompt_args)
-            contents = list(map(chat_message_to_gemini, messages))
+            contents = await asyncio.gather(
+                *[
+                    chat_message_to_gemini(message, self.use_file_api, self._client)
+                    for message in messages
+                ]
+            )
             response = await self._client.aio.models.generate_content(
                 model=self.model,
                 contents=contents,
                 config=generation_config,
             )
 
+            if self.use_file_api:
+                await delete_uploaded_files(contents, self._client)
+
             if isinstance(response.parsed, BaseModel):
                 return response.parsed
             else:
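The structured-prediction paths build `contents` per message (via `asyncio.run` in the sync methods and `asyncio.gather` in the async ones) and delete any File API uploads afterwards. A hedged usage sketch with a placeholder schema, prompt, and model name:

```python
# Illustrative only; the schema, prompt, and model name are placeholders.
from pydantic import BaseModel

from llama_index.core.prompts import PromptTemplate
from llama_index.llms.google_genai import GoogleGenAI


class Invoice(BaseModel):
    vendor: str
    total: float


llm = GoogleGenAI(model="gemini-2.0-flash", use_file_api=True)
result = llm.structured_predict(
    Invoice,
    PromptTemplate("Extract the vendor and total from: {text}"),
    text="ACME Corp, total due $123.45",
)
print(result.vendor, result.total)
```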
@@ -648,7 +725,12 @@ def stream_structured_predict(
                 generation_config["response_schema"] = output_cls
 
             messages = prompt.format_messages(**prompt_args)
-            contents = list(map(chat_message_to_gemini, messages))
+            contents = [
+                asyncio.run(
+                    chat_message_to_gemini(message, self.use_file_api, self._client)
+                )
+                for message in messages
+            ]
 
             def gen() -> Generator[Union[Model, FlexibleModel], None, None]:
                 flexible_model = create_flexible_model(output_cls)
@@ -672,6 +754,9 @@ def gen() -> Generator[Union[Model, FlexibleModel], None, None]:
                     if streaming_model:
                         yield streaming_model
 
+                if self.use_file_api:
+                    asyncio.run(delete_uploaded_files(contents, self._client))
+
             return gen()
         else:
             return super().stream_structured_predict(
@@ -700,7 +785,12 @@ async def astream_structured_predict(
                 generation_config["response_schema"] = output_cls
 
             messages = prompt.format_messages(**prompt_args)
-            contents = list(map(chat_message_to_gemini, messages))
+            contents = await asyncio.gather(
+                *[
+                    chat_message_to_gemini(message, self.use_file_api, self._client)
+                    for message in messages
+                ]
+            )
 
             async def gen() -> AsyncGenerator[Union[Model, FlexibleModel], None]:
                 flexible_model = create_flexible_model(output_cls)
@@ -724,6 +814,9 @@ async def gen() -> AsyncGenerator[Union[Model, FlexibleModel], None]:
                     if streaming_model:
                         yield streaming_model
 
+                if self.use_file_api:
+                    await delete_uploaded_files(contents, self._client)
+
             return gen()
         else:
             return await super().astream_structured_predict(