@@ -17,6 +17,7 @@
Type,
Union,
Callable,
+ Literal,
)


@@ -141,9 +142,9 @@ class GoogleGenAI(FunctionCallingLLM):
default=None,
description="Google GenAI tool to use for the model to augment responses.",
)
- use_file_api: bool = Field(
- default=True,
- description="Whether or not to use the FileAPI for large files (>20MB).",
+ file_mode: Literal["inline", "fileapi", "hybrid"] = Field(
+ default="hybrid",
+ description="Whether to handle files inline only, via the File API only, or with both ('hybrid').",
)

_max_tokens: int = PrivateAttr()
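
For context, the new `file_mode` field replaces the old boolean `use_file_api` flag with three explicit modes. A minimal usage sketch follows; the import path and model name are assumptions for illustration, not taken from this diff.

```python
# Hedged sketch: constructing the LLM with the new file_mode setting.
# Package path and model name are assumed, not confirmed by this diff.
from llama_index.llms.google_genai import GoogleGenAI

llm = GoogleGenAI(
    model="gemini-2.0-flash",  # placeholder model name
    file_mode="hybrid",        # one of "inline", "fileapi", "hybrid"
)
```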
@@ -167,7 +168,7 @@ def __init__(
is_function_calling_model: bool = True,
cached_content: Optional[str] = None,
built_in_tool: Optional[types.Tool] = None,
- use_file_api: bool = True,
+ file_mode: Literal["inline", "fileapi", "hybrid"] = "hybrid",
**kwargs: Any,
):
# API keys are optional. The API can be authorised via OAuth (detected
@@ -216,7 +217,7 @@ def __init__(
max_retries=max_retries,
cached_content=cached_content,
built_in_tool=built_in_tool,
- use_file_api=use_file_api,
+ file_mode=file_mode,
**kwargs,
)

@@ -309,20 +310,17 @@ def _chat(self, messages: Sequence[ChatMessage], **kwargs: Any):
**kwargs.pop("generation_config", {}),
}
params = {**kwargs, "generation_config": generation_config}
- next_msg, chat_kwargs = asyncio.run(
+ next_msg, chat_kwargs, file_api_names = asyncio.run(
prepare_chat_params(
- self.model, messages, self.use_file_api, self._client, **params
+ self.model, messages, self.file_mode, self._client, **params
)
)
chat = self._client.chats.create(**chat_kwargs)
response = chat.send_message(
next_msg.parts if isinstance(next_msg, types.Content) else next_msg
)

- if self.use_file_api:
- asyncio.run(
- delete_uploaded_files([*chat_kwargs["history"], next_msg], self._client)
- )
+ asyncio.run(delete_uploaded_files(file_api_names, self._client))

return chat_from_gemini_response(response)
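
With this change, `prepare_chat_params` also returns the File API resource names it uploaded, so cleanup no longer needs to walk the chat history. The helper below is a sketch of what `delete_uploaded_files` is assumed to do under that contract; its real implementation is not part of this diff, and the `client.aio.files.delete` call is an assumption about the google-genai SDK.

```python
# Hedged sketch (not from this diff): delete uploaded File API resources by name.
# Assumes the google-genai client exposes an async files.delete(name=...) method.
from typing import Sequence

from google import genai


async def delete_uploaded_files(
    file_api_names: Sequence[str], client: genai.Client
) -> None:
    for name in file_api_names:
        await client.aio.files.delete(name=name)
```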

@@ -333,18 +331,15 @@ async def _achat(self, messages: Sequence[ChatMessage], **kwargs: Any):
**kwargs.pop("generation_config", {}),
}
params = {**kwargs, "generation_config": generation_config}
- next_msg, chat_kwargs = await prepare_chat_params(
- self.model, messages, self.use_file_api, self._client, **params
+ next_msg, chat_kwargs, file_api_names = await prepare_chat_params(
+ self.model, messages, self.file_mode, self._client, **params
)
chat = self._client.aio.chats.create(**chat_kwargs)
response = await chat.send_message(
next_msg.parts if isinstance(next_msg, types.Content) else next_msg
)

- if self.use_file_api:
- await delete_uploaded_files(
- [*chat_kwargs["history"], next_msg], self._client
- )
+ await delete_uploaded_files(file_api_names, self._client)

return chat_from_gemini_response(response)

@@ -366,9 +361,9 @@ def _stream_chat(
**kwargs.pop("generation_config", {}),
}
params = {**kwargs, "generation_config": generation_config}
- next_msg, chat_kwargs = asyncio.run(
+ next_msg, chat_kwargs, file_api_names = asyncio.run(
prepare_chat_params(
- self.model, messages, self.use_file_api, self._client, **params
+ self.model, messages, self.file_mode, self._client, **params
)
)
chat = self._client.chats.create(**chat_kwargs)
@@ -402,12 +397,8 @@ def gen() -> ChatResponseGen:
llama_resp.message.blocks = [ThinkingBlock(content=thoughts)]
yield llama_resp

- if self.use_file_api:
- asyncio.run(
- delete_uploaded_files(
- [*chat_kwargs["history"], next_msg], self._client
- )
- )
+ if self.file_mode in ("fileapi", "hybrid"):
+ asyncio.run(delete_uploaded_files(file_api_names, self._client))

return gen()
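
The `file_mode` value itself is consumed inside `prepare_chat_params` / `chat_message_to_gemini`, which are not shown in this diff. The sketch below illustrates one way such a dispatch could look; the 20MB threshold is taken from the old `use_file_api` description, and every name in the sketch is hypothetical rather than the repository's actual helper.

```python
# Hypothetical sketch of a file_mode dispatch; none of these helpers are from this diff.
import io
from typing import List, Tuple

from google import genai
from google.genai import types

# ~20MB inline cap, as cited by the old use_file_api field description.
INLINE_LIMIT_BYTES = 20 * 1024 * 1024


async def _bytes_to_part(
    data: bytes, mime_type: str, file_mode: str, client: genai.Client
) -> Tuple[types.Part, List[str]]:
    """Return a Part plus the File API names created for it (empty when inlined)."""
    inline = file_mode == "inline" or (
        file_mode == "hybrid" and len(data) <= INLINE_LIMIT_BYTES
    )
    if inline:
        return types.Part.from_bytes(data=data, mime_type=mime_type), []
    uploaded = await client.aio.files.upload(
        file=io.BytesIO(data), config={"mime_type": mime_type}
    )
    return types.Part.from_uri(file_uri=uploaded.uri, mime_type=mime_type), [uploaded.name]
```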

@@ -425,8 +416,8 @@ async def _astream_chat(
**kwargs.pop("generation_config", {}),
}
params = {**kwargs, "generation_config": generation_config}
- next_msg, chat_kwargs = await prepare_chat_params(
- self.model, messages, self.use_file_api, self._client, **params
+ next_msg, chat_kwargs, file_api_names = await prepare_chat_params(
+ self.model, messages, self.file_mode, self._client, **params
)
chat = self._client.aio.chats.create(**chat_kwargs)

@@ -463,10 +454,7 @@ async def gen() -> ChatResponseAsyncGen:
]
yield llama_resp

- if self.use_file_api:
- await delete_uploaded_files(
- [*chat_kwargs["history"], next_msg], self._client
- )
+ await delete_uploaded_files(file_api_names, self._client)

return gen()
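
Note that in the streaming paths the `delete_uploaded_files` call sits inside `gen()`, after the response loop, so uploaded files are only removed once the stream has been fully consumed. A toy, self-contained illustration of that ordering (no SDK calls involved):

```python
import asyncio


async def fake_stream():
    # Stand-in for the model's streaming response.
    for chunk in ("a", "b", "c"):
        yield chunk


async def gen():
    async for chunk in fake_stream():
        yield chunk
    # Stand-in for: await delete_uploaded_files(file_api_names, client)
    print("cleanup runs only after the last chunk")


async def main():
    async for chunk in gen():
        print("chunk:", chunk)


asyncio.run(main())
```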

@@ -589,12 +577,13 @@ def structured_predict_without_function_calling(
llm_kwargs = llm_kwargs or {}

messages = prompt.format_messages(**prompt_args)
- contents = [
- asyncio.run(
- chat_message_to_gemini(message, self.use_file_api, self._client)
- )
+ contents_and_names = [
+ asyncio.run(chat_message_to_gemini(message, self.file_mode, self._client))
for message in messages
]
+ contents = [it[0] for it in contents_and_names]
+ file_api_names = [name for it in contents_and_names for name in it[1]]
+
response = self._client.models.generate_content(
model=self.model,
contents=contents,
@@ -609,8 +598,7 @@ def structured_predict_without_function_calling(
},
)

- if self.use_file_api:
- asyncio.run(delete_uploaded_files(contents, self._client))
+ asyncio.run(delete_uploaded_files(file_api_names, self._client))

if isinstance(response.parsed, BaseModel):
return response.parsed
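
The unpacking above relies on `chat_message_to_gemini` now returning a pair per message: the converted content plus the list of File API names it uploaded. A small self-contained illustration of the flatten pattern, using toy data instead of real SDK objects:

```python
# Toy data standing in for (content, file_api_names) pairs returned per message.
contents_and_names = [
    ("content-a", ["files/abc"]),
    ("content-b", []),
    ("content-c", ["files/def", "files/ghi"]),
]

contents = [it[0] for it in contents_and_names]
file_api_names = [name for it in contents_and_names for name in it[1]]

assert contents == ["content-a", "content-b", "content-c"]
assert file_api_names == ["files/abc", "files/def", "files/ghi"]
```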
@@ -639,20 +627,22 @@ def structured_predict(
generation_config["response_schema"] = output_cls

messages = prompt.format_messages(**prompt_args)
- contents = [
+ contents_and_names = [
asyncio.run(
- chat_message_to_gemini(message, self.use_file_api, self._client)
+ chat_message_to_gemini(message, self.file_mode, self._client)
)
for message in messages
]
+ contents = [it[0] for it in contents_and_names]
+ file_api_names = [name for it in contents_and_names for name in it[1]]
+
response = self._client.models.generate_content(
model=self.model,
contents=contents,
config=generation_config,
)

- if self.use_file_api:
- asyncio.run(delete_uploaded_files(contents, self._client))
+ asyncio.run(delete_uploaded_files(file_api_names, self._client))

if isinstance(response.parsed, BaseModel):
return response.parsed
@@ -686,20 +676,22 @@ async def astructured_predict(
generation_config["response_schema"] = output_cls

messages = prompt.format_messages(**prompt_args)
- contents = await asyncio.gather(
+ contents_and_names = await asyncio.gather(
*[
- chat_message_to_gemini(message, self.use_file_api, self._client)
+ chat_message_to_gemini(message, self.file_mode, self._client)
for message in messages
]
)
+ contents = [it[0] for it in contents_and_names]
+ file_api_names = [name for it in contents_and_names for name in it[1]]
+
response = await self._client.aio.models.generate_content(
model=self.model,
contents=contents,
config=generation_config,
)

- if self.use_file_api:
- await delete_uploaded_files(contents, self._client)
+ await delete_uploaded_files(file_api_names, self._client)

if isinstance(response.parsed, BaseModel):
return response.parsed
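
For reference, a hedged usage sketch of the structured-prediction path; the output class, prompt, and model name are invented for illustration and are not part of this change.

```python
# Hypothetical usage of structured_predict; schema, prompt, and model name are examples only.
from pydantic import BaseModel

from llama_index.core.prompts import PromptTemplate
from llama_index.llms.google_genai import GoogleGenAI


class Invoice(BaseModel):
    vendor: str
    total: float


llm = GoogleGenAI(model="gemini-2.0-flash", file_mode="fileapi")
invoice = llm.structured_predict(
    Invoice,
    PromptTemplate("Extract the vendor and total from: {text}"),
    text="ACME Corp, total due $120.50",
)
print(invoice.vendor, invoice.total)
```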
@@ -733,12 +725,14 @@ def stream_structured_predict(
generation_config["response_schema"] = output_cls

messages = prompt.format_messages(**prompt_args)
- contents = [
+ contents_and_names = [
asyncio.run(
- chat_message_to_gemini(message, self.use_file_api, self._client)
+ chat_message_to_gemini(message, self.file_mode, self._client)
)
for message in messages
]
+ contents = [it[0] for it in contents_and_names]
+ file_api_names = [name for it in contents_and_names for name in it[1]]

def gen() -> Generator[Union[Model, FlexibleModel], None, None]:
flexible_model = create_flexible_model(output_cls)
@@ -762,8 +756,7 @@ def gen() -> Generator[Union[Model, FlexibleModel], None, None]:
if streaming_model:
yield streaming_model

- if self.use_file_api:
- asyncio.run(delete_uploaded_files(contents, self._client))
+ asyncio.run(delete_uploaded_files(file_api_names, self._client))

return gen()
else:
@@ -793,12 +786,14 @@ async def astream_structured_predict(
generation_config["response_schema"] = output_cls

messages = prompt.format_messages(**prompt_args)
- contents = await asyncio.gather(
+ contents_and_names = await asyncio.gather(
*[
- chat_message_to_gemini(message, self.use_file_api, self._client)
+ chat_message_to_gemini(message, self.file_mode, self._client)
for message in messages
]
)
+ contents = [it[0] for it in contents_and_names]
+ file_api_names = [name for it in contents_and_names for name in it[1]]

async def gen() -> AsyncGenerator[Union[Model, FlexibleModel], None]:
flexible_model = create_flexible_model(output_cls)
@@ -822,8 +817,7 @@ async def gen() -> AsyncGenerator[Union[Model, FlexibleModel], None]:
if streaming_model:
yield streaming_model

- if self.use_file_api:
- await delete_uploaded_files(contents, self._client)
+ await delete_uploaded_files(file_api_names, self._client)

return gen()
else: