diff --git a/README.md b/README.md index 62ed54d47..1f720b32e 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ Strands Agents is a simple yet powerful SDK that takes a model-driven approach t ## Feature Overview - **Lightweight & Flexible**: Simple agent loop that just works and is fully customizable -- **Model Agnostic**: Support for Amazon Bedrock, Anthropic, LiteLLM, Llama, Ollama, OpenAI, Writer, and custom providers +- **Model Agnostic**: Support for Amazon Bedrock, Anthropic, Gemini, LiteLLM, Llama, Ollama, OpenAI, Writer, and custom providers - **Advanced Capabilities**: Multi-agent systems, autonomous agents, and streaming support - **Built-in MCP**: Native support for Model Context Protocol (MCP) servers, enabling access to thousands of pre-built tools @@ -129,6 +129,7 @@ from strands import Agent from strands.models import BedrockModel from strands.models.ollama import OllamaModel from strands.models.llamaapi import LlamaAPIModel +from strands.models.gemini import GeminiModel # Bedrock bedrock_model = BedrockModel( @@ -139,6 +140,15 @@ bedrock_model = BedrockModel( agent = Agent(model=bedrock_model) agent("Tell me about Agentic AI") +# Google Gemini +gemini_model = GeminiModel( + api_key="your_gemini_api_key", + model_id="gemini-2.5-flash", + params={"temperature": 0.7} +) +agent = Agent(model=gemini_model) +agent("Tell me about Agentic AI") + # Ollama ollama_model = OllamaModel( host="http://localhost:11434", @@ -158,6 +168,7 @@ response = agent("Tell me about Agentic AI") Built-in providers: - [Amazon Bedrock](https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/) - [Anthropic](https://strandsagents.com/latest/user-guide/concepts/model-providers/anthropic/) + - [Gemini](https://strandsagents.com/latest/user-guide/concepts/model-providers/gemini/) - [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/) - [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/) - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/) diff --git a/pyproject.toml b/pyproject.toml index f91454414..4cbb9044a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,9 @@ docs = [ "sphinx-rtd-theme>=1.0.0,<2.0.0", "sphinx-autodoc-typehints>=1.12.0,<2.0.0", ] +gemini = [ + "google-genai>=0.5.0", +] litellm = [ "litellm>=1.73.1,<2.0.0", # https://github.com/BerriAI/litellm/issues/13711 @@ -108,7 +111,7 @@ a2a = [ "starlette>=0.46.2,<1.0.0", ] all = [ - "strands-agents[a2a,anthropic,dev,docs,litellm,llamaapi,mistral,ollama,openai,otel]", + "strands-agents[a2a,anthropic,dev,docs,gemini,litellm,llamaapi,mistral,ollama,openai,otel]", ] [tool.hatch.version] @@ -116,7 +119,7 @@ all = [ source = "vcs" [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker", "gemini"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", @@ -139,7 +142,7 @@ lint-fix = [ ] [tool.hatch.envs.hatch-test] -features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a", "sagemaker", "gemini"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", @@ -155,7 +158,7 @@ extra-args = [ 
[tool.hatch.envs.dev] dev-mode = true -features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a", "sagemaker"] +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a", "sagemaker", "gemini"] [[tool.hatch.envs.hatch-test.matrix]] python = ["3.13", "3.12", "3.11", "3.10"] diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py new file mode 100644 index 000000000..55fb076bc --- /dev/null +++ b/src/strands/models/gemini.py @@ -0,0 +1,660 @@ +"""Google Gemini model provider. + +- Docs: https://ai.google.dev/api +""" + +import base64 +import json +import logging +import mimetypes +import os +import time +from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast + +from google import genai +from google.genai import types +from pydantic import BaseModel +from typing_extensions import Unpack, override + +from ..types.content import ContentBlock, Messages +from ..types.exceptions import ModelThrottledException +from ..types.streaming import StreamEvent +from ..types.tools import ToolResult, ToolSpec, ToolUse +from .model import Model + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class GeminiModel(Model): + """Google Gemini model provider implementation.""" + + SAFETY_MESSAGES = {"safety", "harmful", "content policy", "blocked due to safety"} + + QUOTA_MESSAGES = {"quota", "limit", "rate limit", "exceeded"} + + class GeminiConfig(TypedDict, total=False): + """Configuration options for Gemini models.""" + + model_id: str + params: Optional[dict[str, Any]] + + def __init__( + self, + *, + api_key: Optional[str] = None, + client_args: Optional[dict[str, Any]] = None, + **model_config: Unpack[GeminiConfig], + ) -> None: + """Initialize provider instance. + + Args: + api_key: Google AI API key. If not provided, will use GOOGLE_API_KEY env var. + client_args: Additional arguments for the Gemini client configuration. + **model_config: Configuration options for the Gemini model. + """ + self.config = GeminiModel.GeminiConfig(**model_config) + + logger.debug("config=<%s> | initializing", self.config) + + client_config = {"api_key": api_key or os.environ.get("GOOGLE_API_KEY"), **(client_args or {})} + + self.client = genai.Client(**client_config) + + @override + def update_config(self, **model_config: Unpack[GeminiConfig]) -> None: # type: ignore[override] + """Update the Gemini model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. + """ + self.config.update(model_config) + + @override + def get_config(self) -> GeminiConfig: + """Get the Gemini model configuration. + + Returns: + The Gemini model configuration. + """ + return self.config + + def _format_inline_data_part(self, data: dict[str, Any], default_mime: str) -> dict[str, Any]: + """Formats an inline data part (image or document).""" + file_format = data["format"] + source_bytes = data["source"]["bytes"] + mime_type = mimetypes.types_map.get(f".{file_format}", default_mime) + + return {"inlineData": {"mimeType": mime_type, "data": base64.b64encode(source_bytes).decode("utf-8")}} + + def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]: + """Format a Gemini content block. + + Args: + content: Message content. + + Returns: + Gemini formatted content block. + + Raises: + TypeError: If the content block type cannot be converted to a Gemini-compatible format. 
+ """ + if "text" in content: + return {"text": content["text"]} + + if "image" in content: + return self._format_inline_data_part(cast(dict[str, Any], content["image"]), "image/png") + + if "document" in content: + return self._format_inline_data_part(cast(dict[str, Any], content["document"]), "application/octet-stream") + + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") + + def _format_function_call(self, tool_use: ToolUse) -> dict[str, Any]: + """Format a Gemini function call. + + Args: + tool_use: Tool use requested by the model. + + Returns: + Gemini formatted function call. + """ + return {"functionCall": {"name": tool_use["name"], "args": tool_use["input"]}} + + def _format_function_response(self, tool_result: ToolResult) -> dict[str, Any]: + """Format a Gemini function response. + + Args: + tool_result: Tool result from execution. + + Returns: + Gemini formatted function response. + """ + response_parts = [] + for content in tool_result["content"]: + if "json" in content: + response_parts.append(json.dumps(content["json"])) + elif "text" in content: + response_parts.append(content["text"]) + + return { + "functionResponse": {"name": tool_result["toolUseId"], "response": {"content": "\n".join(response_parts)}} + } + + def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: + """Format messages for Gemini API. + + Args: + messages: List of message objects to be processed by the model. + + Returns: + Gemini formatted messages array. + """ + formatted_messages = [] + + for message in messages: + role = "user" if message["role"] == "user" else "model" + + text_parts = [] + media_parts = [] + function_calls = [] + function_responses = [] + + for content in message["content"]: + if "text" in content: + text_parts.append(content["text"]) + elif "image" in content or "document" in content: + media_parts.append(self._format_request_message_content(content)) + elif "toolUse" in content: + function_calls.append(self._format_function_call(content["toolUse"])) + elif "toolResult" in content: + function_responses.append(self._format_function_response(content["toolResult"])) + + parts = [] + + if text_parts: + parts.append({"text": "\n\n".join(text_parts)}) + + if media_parts: + parts.extend(media_parts) + + parts.extend(function_calls) + parts.extend(function_responses) + + if parts: + formatted_messages.append({"role": role, "parts": parts}) + + return formatted_messages + + async def _process_chunk( + self, chunk: Any, output_text_buffer: list[str], tool_calls: dict[str, str] + ) -> AsyncGenerator[Union[StreamEvent, bool], None]: + """Process a single chunk from the streaming response.""" + has_function_call = False + + if hasattr(chunk, "candidates") and chunk.candidates: + for candidate in chunk.candidates: + if not hasattr(candidate, "content") or not candidate.content: + continue + + for part in candidate.content.parts: + if part.text: + output_text_buffer.append(part.text) + yield self.format_event("content_block_delta", part.text) + + # Handle function calls + elif hasattr(part, "function_call") and part.function_call: + function_call = part.function_call + has_function_call = True + + if function_call.name not in tool_calls: + yield self.format_event("content_block_stop") + + tool_id = f"tool_{len(tool_calls) + 1}" + tool_calls[function_call.name] = tool_id + + yield self.format_event( + "content_block_start", {"function_call": function_call, "tool_id": tool_id} + ) + + if hasattr(function_call, "args") and function_call.args: + args = 
self._extract_function_args(function_call) + yield self.format_event( + "content_block_delta", {"function_call": function_call, "args": args} + ) + + elif hasattr(chunk, "text") and chunk.text: + output_text_buffer.append(chunk.text) + yield self.format_event("content_block_delta", chunk.text) + + yield has_function_call + + def _extract_function_args(self, function_call: Any) -> dict[str, Any]: + """Extract function arguments from various formats.""" + if not hasattr(function_call, "args"): + return {} + + args = function_call.args + + # Handle Struct type (protobuf) + if hasattr(args, "fields"): + return self._struct_to_dict(args) + + # Handle JSON string + if isinstance(args, str): + try: + parsed = json.loads(args) + if isinstance(parsed, dict): + return parsed + return {"value": args} + except json.JSONDecodeError: + return {"value": args} + + if isinstance(args, dict): + return dict(args) + + return {} + + def _struct_to_dict(self, struct_value: Any) -> dict[str, Any]: + """Convert protobuf Struct to dict.""" + result = {} + for key, value in struct_value.fields.items(): + if hasattr(value, "string_value"): + result[key] = value.string_value + elif hasattr(value, "number_value"): + result[key] = value.number_value + elif hasattr(value, "bool_value"): + result[key] = value.bool_value + elif hasattr(value, "list_value"): + result[key] = [self._value_to_python(v) for v in value.list_value.values] + elif hasattr(value, "struct_value"): + result[key] = self._struct_to_dict(value.struct_value) + else: + result[key] = str(value) + return result + + def _value_to_python(self, value: Any) -> Any: + """Convert protobuf Value to Python type.""" + if hasattr(value, "string_value"): + return value.string_value + elif hasattr(value, "number_value"): + return value.number_value + elif hasattr(value, "bool_value"): + return value.bool_value + elif hasattr(value, "struct_value"): + return self._struct_to_dict(value.struct_value) + else: + return str(value) + + async def _count_tokens_safely(self, model_id: str, contents: list[dict[str, Any]]) -> int: + """Safely count tokens with fallback to 0 on error. + + Args: + model_id: The Gemini model ID + contents: The content to count tokens for + + Returns: + Token count, or 0 if counting fails + """ + try: + token_count = await self.client.aio.models.count_tokens(model=model_id, contents=contents) + if hasattr(token_count, "total_tokens"): + return int(token_count.total_tokens or 0) + return 0 + except Exception as e: + logger.debug("Could not count tokens: %s", str(e)) + return 0 + + def _format_tools(self, tool_specs: Optional[list[ToolSpec]]) -> Optional[list[dict[str, Any]]]: + """Format tool specifications for Gemini. + + Args: + tool_specs: List of tool specifications. + + Returns: + Gemini formatted tools array. + """ + if not tool_specs: + return None + + tools = [] + for tool_spec in tool_specs: + tools.append( + { + "function_declarations": [ + { + "name": tool_spec["name"], + "description": tool_spec["description"], + "parameters": tool_spec["inputSchema"]["json"], + } + ] + } + ) + + return tools + + def format_request( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + config: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + """Format a Gemini streaming request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. 
+ system_prompt: System prompt to provide context to the model. + config: Additional configuration options including response_schema for structured output. + + Returns: + A Gemini streaming request. + """ + generation_config: dict[str, Any] = {} + + params = self.config.get("params") + if params: + generation_config.update(params) + + if config: + if "response_schema" in config: + generation_config["response_schema"] = config["response_schema"] + generation_config["response_mime_type"] = config.get("response_mime_type", "application/json") + + config_params = config.get("params") + if config_params: + generation_config.update(config_params) + + request = { + "contents": self._format_request_messages(messages), + "generation_config": generation_config, + "stream": True, + } + + if system_prompt: + request["system_instruction"] = {"parts": [{"text": system_prompt}]} + + tools = self._format_tools(tool_specs) + if tools: + request["tools"] = tools + + return request + + def format_event(self, event_type: str, data: Any = None) -> StreamEvent: + """Format a Gemini event into a standardized message chunk. + + Args: + event_type: Type of event to format + data: Data associated with the event + + Returns: + The formatted event + + Raises: + RuntimeError: If event_type is not recognized + """ + match event_type: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_block_start": + if data and "function_call" in data: + function_call = data["function_call"] + return { + "contentBlockStart": { + "start": {"toolUse": {"name": function_call.name, "toolUseId": data["tool_id"]}} + } + } + return {"contentBlockStart": {"start": {}}} + + case "content_block_delta": + if data and "function_call" in data: + args = data.get("args", {}) + return {"contentBlockDelta": {"delta": {"toolUse": {"input": args}}}} + + return {"contentBlockDelta": {"delta": {"text": data}}} + + case "content_block_stop": + return {"contentBlockStop": {}} + + case "message_stop": + match data: + case "SAFETY" | "RECITATION": + return {"messageStop": {"stopReason": "content_filtered"}} + case "MAX_TOKENS": + return {"messageStop": {"stopReason": "max_tokens"}} + case "tool_use": + return {"messageStop": {"stopReason": "tool_use"}} + case _: + return {"messageStop": {"stopReason": "end_turn"}} + + case "metadata": + return { + "metadata": { + "usage": data.get("usage", {}), + "metrics": { + "latencyMs": data.get("latency_ms", 0), + }, + }, + } + + case _: + raise RuntimeError(f"event_type=<{event_type}> | unknown type") + + async def stream( + self, + messages: Messages, + tool_specs: Optional[list[ToolSpec]] = None, + system_prompt: Optional[str] = None, + config: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + """Stream conversation with the Gemini model. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications. Enables function calling when provided. + system_prompt: System prompt to provide context to the model. + config: Additional configuration options including response_schema for structured output. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + StreamEvents: messageStart, contentBlockDelta, contentBlockStop, messageStop, metadata + + Raises: + ModelThrottledException: If the model is being rate-limited by Gemini API. + RuntimeError: If an error occurs during streaming or response parsing. 
+ """ + logger.debug("formatting request") + request = self.format_request(messages, tool_specs, system_prompt, config) + logger.debug("formatted request=<%s>", request) + + start_time = time.perf_counter() + + model_id = self.config.get("model_id", "gemini-2.5-flash") + + tool_config = None + if tool_specs: + tool_config = types.ToolConfig( + function_calling_config=types.FunctionCallingConfig(mode=types.FunctionCallingConfigMode.AUTO) + ) + + cfg = types.GenerateContentConfig( + system_instruction=request.get("system_instruction"), + tools=request.get("tools"), # Use the formatted tools from format_request + tool_config=tool_config, + **(request.get("generation_config") or {}), + ) + + logger.debug("invoking gemini model %s", model_id) + + # Pre-flight check for metrics + input_tokens = await self._count_tokens_safely(model_id, request["contents"]) + + # Start the conversation + yield self.format_event("message_start") + + output_text_buffer: list[str] = [] + + try: + response = await self.client.aio.models.generate_content_stream( + model=model_id, + contents=request["contents"], + config=cfg, + ) + + tool_calls: dict[str, str] = {} + has_function_call = False + content_started = False + + logger.debug("streaming response from model") + + async for chunk in response: + async for event in self._process_chunk(chunk, output_text_buffer, tool_calls): + if isinstance(event, bool): + if event: + has_function_call = True + continue + + if "contentBlockDelta" in event and not content_started: + yield self.format_event("content_block_start") + content_started = True + yield event + + if hasattr(chunk, "finish_reason") and isinstance(chunk.finish_reason, str): + break + + if content_started or has_function_call: + yield self.format_event("content_block_stop") + + if has_function_call: + yield self.format_event("message_stop", "tool_use") + else: + yield self.format_event("message_stop", "end_turn") + + output_tokens = 0 + generated_text = "".join(output_text_buffer) + if generated_text: + output_tokens = await self._count_tokens_safely( + model_id, [{"role": "model", "parts": [{"text": generated_text}]}] + ) + + latency_ms = int((time.perf_counter() - start_time) * 1000) + usage_data = { + "usage": { + "inputTokens": input_tokens, + "outputTokens": output_tokens, + "totalTokens": input_tokens + output_tokens, + }, + "metrics": { + "latencyMs": latency_ms, + }, + } + + yield self.format_event("metadata", usage_data) + + logger.debug("finished streaming response from model") + + except genai.errors.ClientError as e: + error_msg = str(e).lower() + + if any(msg in error_msg for msg in self.SAFETY_MESSAGES): + logger.warning("safety error: %s", str(e)) + yield self.format_event("content_block_delta", "Response was blocked due to safety concerns.") + yield self.format_event("content_block_stop") + yield self.format_event("message_stop", "SAFETY") + elif any(msg in error_msg for msg in self.QUOTA_MESSAGES): + logger.warning("quota or rate limit error: %s", str(e)) + yield self.format_event("content_block_stop") + yield self.format_event("message_stop", "MAX_TOKENS") + raise ModelThrottledException(f"Rate limit or quota exceeded: {str(e)}") from e + else: + logger.warning("client error (other): %s", str(e)) + yield self.format_event("content_block_delta", "Request could not be processed.") + yield self.format_event("content_block_stop") + yield self.format_event("message_stop", "SAFETY") + + except genai.errors.UnknownApiResponseError as e: + logger.warning("incomplete or unparseable response: %s", 
str(e)) + yield self.format_event("content_block_stop") + yield self.format_event("message_stop", "SAFETY") + raise RuntimeError(f"Incomplete response from Gemini: {str(e)}") from e + + except genai.errors.ServerError as e: + logger.warning("server error: %s", str(e)) + yield self.format_event("content_block_stop") + yield self.format_event("message_stop", "MAX_TOKENS") + error_message = str(e) + raise ModelThrottledException(f"Server error: {error_message}") from e + + except Exception as e: + logger.error("unexpected error during streaming: %s", str(e)) + raise RuntimeError(f"Error streaming from Gemini: {str(e)}") from e + + @override + async def structured_output( + self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + """Get structured output from the model using Gemini's native structured output. + + Args: + output_model: The output model to use for the agent. + prompt: The prompt messages to use for the agent. + system_prompt: System prompt to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Yields: + Model events with the last being the structured output. + + Raises: + ValueError: If the model doesn't return valid structured output. + ModelThrottledException: If the model is being rate-limited. + genai.errors.ClientError: If the request is invalid or blocked by safety settings. + genai.errors.ServerError: If the server encounters an error processing the request. + """ + schema = output_model.model_json_schema() if hasattr(output_model, "model_json_schema") else output_model + + config = { + "response_mime_type": "application/json", + "response_schema": schema, + } + + if "config" in kwargs: + config.update(kwargs.pop("config")) + + logger.debug("Using Gemini's native structured output with schema: %s", output_model.__name__) + + async_response = self.stream(messages=prompt, system_prompt=system_prompt, config=config, **kwargs) + + accumulated_text = [] + stop_reason = None + + async for event in async_response: + # Don't yield streaming events, only collect the final result + if "messageStop" in event and "stopReason" in event["messageStop"]: + stop_reason = event["messageStop"]["stopReason"] + + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"].get("delta", {}) + if "text" in delta: + accumulated_text.append(delta["text"]) + + full_response = "".join(accumulated_text) + + if not full_response.strip(): + logger.error("Empty response from model when generating structured output") + raise ValueError("Empty response from model when generating structured output") + + if stop_reason not in ["end_turn"]: + logger.error("Model returned unexpected stop_reason: %s", stop_reason) + raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "end_turn"') + + try: + result = output_model.model_validate_json(full_response) + yield {"output": result} + + except Exception as e: + logger.error("Failed to create output model from JSON response: %s", str(e)) + raise ValueError(f"Failed to create structured output from Gemini response: {str(e)}") from e diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py new file mode 100644 index 000000000..e39d991fa --- /dev/null +++ b/tests/strands/models/test_gemini.py @@ -0,0 +1,599 @@ +import base64 +import unittest.mock +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pydantic +import pytest 
+from google import genai + +from src.strands.models.gemini import GeminiModel +from src.strands.types.exceptions import ModelThrottledException + + +@pytest.fixture +def mock_genai_client(): + with patch.object(genai, "Client") as mock_client: + mock_instance = mock_client.return_value + yield mock_instance + + +@pytest.fixture +def mock_model(): + model = MagicMock() + return model + + +@pytest.fixture +def model_id(): + return "gemini-pro" + + +@pytest.fixture +def max_tokens(): + return 100 + + +@pytest.fixture +def model(mock_genai_client, model_id): + with patch.object(genai, "Client"): + # Use temperature instead of max_tokens to avoid validation errors + return GeminiModel(model_id=model_id, params={"temperature": 0.7}) + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "test"}]}] + + +@pytest.fixture +def agenerator(): + """Create an async generator from a list of items.""" + + async def _async_generator(items): + for item in items: + yield item + + return _async_generator + + +@pytest.fixture +def system_prompt(): + return "You are a helpful assistant." + + +@pytest.fixture +def tool_specs(): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + }, + ] + + +@pytest.fixture +def test_output_model_cls(): + class TestOutputModel(pydantic.BaseModel): + name: str + age: int + + return TestOutputModel + + +def test__init__model_configs(): + with patch.object(genai, "Client"): + model = GeminiModel(model_id="gemini-pro", params={"temperature": 0.7}, api_key="fake-key") + + true_params = model.get_config().get("params") + exp_params = {"temperature": 0.7} + + assert true_params == exp_params + + +def test_update_config(model, model_id): + model.update_config(model_id=model_id) + + true_model_id = model.get_config().get("model_id") + exp_model_id = model_id + + assert true_model_id == exp_model_id + + +def test__format_request_message_content_text(model): + content = {"text": "Hello, world!"} + + true_result = model._format_request_message_content(content) + exp_result = {"text": "Hello, world!"} + + assert true_result == exp_result + + +def test__format_request_message_content_image(model): + content = { + "image": { + "format": "jpg", + "source": {"bytes": b"testimage"}, + } + } + + true_result = model._format_request_message_content(content) + exp_result = { + "inlineData": { + "mimeType": "image/jpeg", + "data": base64.b64encode(b"testimage").decode("utf-8"), + } + } + + assert true_result == exp_result + + +def test__format_request_message_content_document(model): + content = { + "document": { + "format": "pdf", + "source": {"bytes": b"testdoc"}, + } + } + + true_result = model._format_request_message_content(content) + exp_result = { + "inlineData": { + "mimeType": "application/pdf", + "data": base64.b64encode(b"testdoc").decode("utf-8"), + } + } + + assert true_result == exp_result + + +def test__format_request_message_content_unsupported(model): + content = {"unsupported": {}} + + with pytest.raises(TypeError, match="content_type= | unsupported type"): + model._format_request_message_content(content) + + +def test__format_function_call(model): + tool_use = { + "name": "calculator", + "toolUseId": "calc1", + "input": {"expression": "2+2"}, + } + + true_result = model._format_function_call(tool_use) + exp_result = { + "functionCall": { + "name": "calculator", + "args": {"expression": "2+2"}, + 
} + } + + assert true_result == exp_result + + +def test__format_function_response(model): + tool_result = { + "name": "calculator", + "toolUseId": "calc1", + "status": "success", + "content": [ + {"text": "Result:"}, + {"json": {"result": 4}}, + ], + } + + true_result = model._format_function_response(tool_result) + exp_result = {"functionResponse": {"name": "calc1", "response": {"content": 'Result:\n{"result": 4}'}}} + + assert true_result == exp_result + + +def test_format_request_basic(model, messages, model_id): + true_request = model.format_request(messages) + exp_request = { + "contents": [{"role": "user", "parts": [{"text": "test"}]}], + "generation_config": {"temperature": 0.7}, + "stream": True, + } + + assert true_request == exp_request + + +def test_format_request_with_system_prompt(model, messages, system_prompt): + true_request = model.format_request(messages, system_prompt=system_prompt) + exp_request = { + "contents": [{"role": "user", "parts": [{"text": "test"}]}], + "generation_config": {"temperature": 0.7}, + "stream": True, + "system_instruction": {"parts": [{"text": system_prompt}]}, + } + + assert true_request == exp_request + + +def test_format_request_with_tools(model, messages, tool_specs): + true_request = model.format_request(messages, tool_specs=tool_specs) + exp_request = { + "contents": [{"role": "user", "parts": [{"text": "test"}]}], + "generation_config": {"temperature": 0.7}, + "stream": True, + "tools": [ + { + "function_declarations": [ + { + "name": "test_tool", + "description": "A test tool", + "parameters": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + } + ] + } + ], + } + + assert true_request == exp_request + + +def test_format_request_with_complex_messages(model): + messages = [ + { + "role": "user", + "content": [ + {"text": "Analyze this image:"}, + {"image": {"format": "jpg", "source": {"bytes": b"image_data"}}}, + ], + }, + { + "role": "assistant", + "content": [ + {"toolUse": {"name": "analyzer", "toolUseId": "t1", "input": {"detail_level": "high"}}}, + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "t1", + "status": "success", + "content": [{"text": "Analysis complete"}, {"json": {"confidence": 0.95}}], + } + }, + ], + }, + ] + + true_request = model.format_request(messages) + + # Check structure without deep equality + assert len(true_request["contents"]) == 3 + assert true_request["contents"][0]["role"] == "user" + assert len(true_request["contents"][0]["parts"]) == 2 + assert true_request["contents"][0]["parts"][0]["text"] == "Analyze this image:" + assert "inlineData" in true_request["contents"][0]["parts"][1] + + assert true_request["contents"][1]["role"] == "model" + assert "functionCall" in true_request["contents"][1]["parts"][0] + + assert true_request["contents"][2]["role"] == "user" + assert "functionResponse" in true_request["contents"][2]["parts"][0] + + +@pytest.mark.asyncio +async def test_stream_basic(model, messages, agenerator): + client = model.client + + client.aio = SimpleNamespace() + client.aio.models = SimpleNamespace() + + chunks = [ + SimpleNamespace(text="Hello"), + SimpleNamespace(text=" world"), + SimpleNamespace(text="!", finish_reason="STOP"), + ] + + async_gen = agenerator(chunks) + client.aio.models.generate_content_stream = AsyncMock(return_value=async_gen) + + client.aio.models.count_tokens = AsyncMock(return_value=SimpleNamespace(total_tokens=10)) + + events = [] + async for event in model.stream(messages): + 
events.append(event) + + assert events[0] == {"messageStart": {"role": "assistant"}} + assert events[1] == {"contentBlockStart": {"start": {}}} + assert events[2] == {"contentBlockDelta": {"delta": {"text": "Hello"}}} + assert events[3] == {"contentBlockDelta": {"delta": {"text": " world"}}} + assert events[4] == {"contentBlockDelta": {"delta": {"text": "!"}}} + assert events[5] == {"contentBlockStop": {}} + assert events[6] == {"messageStop": {"stopReason": "end_turn"}} + + assert client.aio.models.count_tokens.await_count >= 1 + + +@pytest.mark.asyncio +@patch("src.strands.models.gemini.genai.types.GenerateContentConfig") +@patch("src.strands.models.gemini.genai.types.ToolConfig") +@patch("src.strands.models.gemini.genai.types.FunctionCallingConfig") +async def test_stream_tool_use( + mock_function_calling_config, mock_tool_config, mock_config_class, model, messages, tool_specs, agenerator +): + client = model.client + + client.aio = SimpleNamespace() + client.aio.models = SimpleNamespace() + + mock_func_call = unittest.mock.MagicMock() + mock_func_call.name = "test_tool" + mock_func_call.args = {"input": "test value"} + + mock_part = unittest.mock.MagicMock() + mock_part.function_call = mock_func_call + mock_part.text = None + + mock_content = unittest.mock.MagicMock() + mock_content.parts = [mock_part] + + mock_candidate = unittest.mock.MagicMock() + mock_candidate.content = mock_content + + mock_chunk = unittest.mock.MagicMock() + mock_chunk.text = "" + mock_chunk.candidates = [mock_candidate] + mock_chunk.finish_reason = "STOP" + + async_gen = agenerator([mock_chunk]) + client.aio.models.generate_content_stream = AsyncMock(return_value=async_gen) + + client.aio.models.count_tokens = AsyncMock(return_value=SimpleNamespace(total_tokens=10)) + client.count_tokens = AsyncMock(return_value=SimpleNamespace(total_tokens=5)) + + mock_function_calling_config.return_value = MagicMock() + + mock_tool_config.return_value = MagicMock() + + events = [] + async for event in model.stream(messages, tool_specs=tool_specs): + events.append(event) + + assert any( + "contentBlockStart" in event + and "toolUse" in event["contentBlockStart"].get("start", {}) + and event["contentBlockStart"]["start"]["toolUse"]["name"] == "test_tool" + for event in events + ) + + assert any( + "contentBlockDelta" in event + and "toolUse" in event["contentBlockDelta"].get("delta", {}) + and "input" in event["contentBlockDelta"]["delta"]["toolUse"] + for event in events + ) + + assert any("messageStop" in event and event["messageStop"]["stopReason"] == "tool_use" for event in events) + + +@pytest.mark.asyncio +@patch("src.strands.models.gemini.genai.types.GenerateContentConfig") +async def test_stream_client_error_safety(mock_config_class, model, messages): + """Test handling of safety filter client errors from Gemini.""" + client = model.client + + client.aio = SimpleNamespace() + client.aio.models = SimpleNamespace() + + async def mock_generator(*args, **kwargs): + yield unittest.mock.MagicMock() + raise genai.errors.ClientError( + "Content blocked due to safety settings", + response_json={"error": {"message": "Content blocked due to safety settings", "code": 400}}, + ) + + client.aio.models.generate_content_stream = AsyncMock(side_effect=mock_generator) + + client.aio.models.count_tokens = AsyncMock(return_value=SimpleNamespace(total_tokens=10)) + + client.count_tokens = AsyncMock(return_value=SimpleNamespace(total_tokens=5)) + + events = [] + async for event in model.stream(messages): + events.append(event) + + assert 
any( + "contentBlockDelta" in event + and "safety concerns" in event["contentBlockDelta"].get("delta", {}).get("text", "") + for event in events + ) + assert any( + "messageStop" in event and event["messageStop"].get("stopReason") == "content_filtered" for event in events + ) + + +@pytest.mark.asyncio +@patch("src.strands.models.gemini.genai.types.GenerateContentConfig") +async def test_stream_server_error(mock_config_class, model, messages): + """Test handling of server errors from Gemini.""" + client = model.client + + client.aio = SimpleNamespace() + client.aio.models = SimpleNamespace() + + error_response = {"error": {"message": "Internal server error", "code": 500}} + server_error = genai.errors.ServerError("Internal server error", response_json=error_response) + + client.aio.models.generate_content_stream = AsyncMock(side_effect=server_error) + + client.aio.models.count_tokens = AsyncMock(return_value=SimpleNamespace(total_tokens=10)) + + client.count_tokens = AsyncMock(return_value=SimpleNamespace(total_tokens=5)) + + exception_raised = False + exception_message = "" + + try: + async for _event in model.stream(messages): + pass # Consume events until exception is raised + except ModelThrottledException as e: + exception_raised = True + exception_message = str(e) + + assert exception_raised, "ModelThrottledException was not raised" + + assert "Server error" in exception_message + assert "Internal server error" in exception_message + + +@pytest.mark.asyncio +@patch("src.strands.models.gemini.genai.types.GenerateContentConfig") +async def test_stream_unknown_api_response_error(mock_config_class, model, messages): + """Test handling of unparseable responses from Gemini.""" + client = model.client + + client.aio = SimpleNamespace() + client.aio.models = SimpleNamespace() + + async def mock_generator(*args, **kwargs): + yield unittest.mock.MagicMock() + raise genai.errors.UnknownApiResponseError("Failed to parse response") + + client.aio.models.generate_content_stream = AsyncMock(side_effect=mock_generator) + + client.aio.models.count_tokens = AsyncMock(return_value=SimpleNamespace(total_tokens=10)) + + client.count_tokens = AsyncMock(return_value=SimpleNamespace(total_tokens=5)) + + with pytest.raises(RuntimeError) as excinfo: + async for _ in model.stream(messages): + pass + + assert "Incomplete response from Gemini" in str(excinfo.value) + + +@pytest.mark.asyncio +@patch("src.strands.models.gemini.genai.types.GenerateContentConfig") +async def test_stream_general_error(mock_config_class, model, messages): + client = model.client + + client.aio = SimpleNamespace() + client.aio.models = SimpleNamespace() + + client.aio.models.generate_content_stream = AsyncMock(side_effect=Exception("Unknown error")) + + client.aio.models.count_tokens = AsyncMock(return_value=SimpleNamespace(total_tokens=10)) + + client.count_tokens = AsyncMock(return_value=SimpleNamespace(total_tokens=5)) + + with pytest.raises(RuntimeError) as excinfo: + async for _ in model.stream(messages): + pass + + assert "Error streaming from Gemini: Unknown error" in str(excinfo.value) + + +@pytest.mark.asyncio +@patch("src.strands.models.gemini.genai.types.GenerateContentConfig") +@patch("src.strands.models.gemini.genai.types.ToolConfig") +@patch("src.strands.models.gemini.genai.types.FunctionCallingConfig") +async def test_structured_output_success( + mock_function_calling_config, mock_tool_config, mock_config_class, model, mock_genai_client, test_output_model_cls +): + mock_function_calling_config.return_value = MagicMock() + 
+ mock_tool_config.return_value = MagicMock() + + messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + + with patch.object(model, "stream") as mock_stream: + + async def custom_stream(*args, **kwargs): + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": '{"name": "John", "age": 30}'}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + mock_stream.return_value = custom_stream() + + events = [] + async for event in model.structured_output(test_output_model_cls, messages): + events.append(event) + + assert "output" in events[-1] + assert isinstance(events[-1]["output"], test_output_model_cls) + assert events[-1]["output"].name == "John" + assert events[-1]["output"].age == 30 + + +@pytest.mark.asyncio +@patch("src.strands.models.gemini.genai.types.GenerateContentConfig") +@patch("src.strands.models.gemini.genai.types.ToolConfig") +@patch("src.strands.models.gemini.genai.types.FunctionCallingConfig") +async def test_structured_output_wrong_stop_reason( + mock_function_calling_config, mock_tool_config, mock_config_class, model, mock_genai_client, test_output_model_cls +): + mock_function_calling_config.return_value = MagicMock() + + mock_tool_config.return_value = MagicMock() + + messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + + with patch.object(model, "stream") as mock_stream: + + async def custom_stream(*args, **kwargs): + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": "Some text response"}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "max_tokens"}} + + mock_stream.return_value = custom_stream() + + with pytest.raises(ValueError, match='Model returned stop_reason: max_tokens instead of "end_turn"'): + async for _ in model.structured_output(test_output_model_cls, messages): + pass + + +@pytest.mark.asyncio +@patch("src.strands.models.gemini.genai.types.GenerateContentConfig") +@patch("src.strands.models.gemini.genai.types.ToolConfig") +@patch("src.strands.models.gemini.genai.types.FunctionCallingConfig") +async def test_structured_output_missing_data( + mock_function_calling_config, mock_tool_config, mock_config_class, model, mock_genai_client, test_output_model_cls +): + mock_function_calling_config.return_value = MagicMock() + + mock_tool_config.return_value = MagicMock() + + messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] + + with patch.object(model, "stream") as mock_stream: + + async def custom_stream(*args, **kwargs): + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": '{"name": "John"}'}}} # Missing age field + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + mock_stream.return_value = custom_stream() + + # Check that ValueError is raised when creating the model + with pytest.raises(ValueError, match="Failed to create structured output from Gemini response"): + async for _ in model.structured_output(test_output_model_cls, messages): + pass diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index d2ac148d3..c1f442b2a 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -10,6 +10,7 @@ from strands.models import BedrockModel, Model from strands.models.anthropic import 
AnthropicModel +from strands.models.gemini import GeminiModel from strands.models.litellm import LiteLLMModel from strands.models.llamaapi import LlamaAPIModel from strands.models.mistral import MistralModel @@ -126,6 +127,15 @@ def __init__(self): stream_options={"include_usage": True}, ), ) +gemini = ProviderInfo( + id="gemini", + environment_variable="GOOGLE_API_KEY", + factory=lambda: GeminiModel( + api_key=os.getenv("GOOGLE_API_KEY"), + model_id="gemini-2.5-flash", + params={"temperature": 0.7}, + ), +) ollama = OllamaProviderInfo() @@ -134,6 +144,7 @@ def __init__(self): bedrock, anthropic, cohere, + gemini, llama, litellm, mistral, diff --git a/tests_integ/models/test_model_gemini.py b/tests_integ/models/test_model_gemini.py new file mode 100644 index 000000000..3dc9a6742 --- /dev/null +++ b/tests_integ/models/test_model_gemini.py @@ -0,0 +1,318 @@ +import os + +import pydantic +import pytest + +import strands +from strands import Agent +from strands.models.gemini import GeminiModel + +# these tests only run if we have the google api key +pytestmark = pytest.mark.skipif( + "GOOGLE_API_KEY" not in os.environ, + reason="GOOGLE_API_KEY environment variable missing", +) + + +@pytest.fixture +def model(): + return GeminiModel( + api_key=os.getenv("GOOGLE_API_KEY"), + model_id="gemini-2.5-flash", + params={"temperature": 0.15}, # Lower temperature for consistent test behavior + ) + + +@pytest.fixture +def tools(): + @strands.tool + def tool_time() -> str: + return "12:00" + + @strands.tool + def tool_weather() -> str: + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def system_prompt(): + return "You are an AI assistant." + + +@pytest.fixture +def agent(model, tools, system_prompt): + return Agent(model=model, tools=tools, system_prompt=system_prompt) + + +@pytest.fixture +def weather(): + class Weather(pydantic.BaseModel): + """Extracts the time and weather from the user's message with the exact strings.""" + + time: str + weather: str + + return Weather(time="12:00", weather="sunny") + + +@pytest.fixture +def yellow_color(): + class Color(pydantic.BaseModel): + """Describes a color.""" + + name: str + + @pydantic.field_validator("name", mode="after") + @classmethod + def lower(_, value): + return value.lower() + + return Color(name="yellow") + + +def test_agent_invoke(agent): + result = agent("What is the current time and weather?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.asyncio +async def test_agent_invoke_async(agent): + result = await agent.invoke_async("What is the current time and weather?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.asyncio +async def test_agent_stream_async(agent): + stream = agent.stream_async("What is the current time and weather?") + async for event in stream: + _ = event + + result = event["result"] + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +def test_structured_output(agent, weather): + tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") + exp_weather = weather + assert tru_weather == exp_weather + + +@pytest.mark.asyncio +async def test_agent_structured_output_async(agent, weather): + tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny") + exp_weather = weather + 
assert tru_weather == exp_weather
+
+
+def test_invoke_multi_modal_input(agent, yellow_img):
+    content = [
+        {"text": "what is in this image"},
+        {
+            "image": {
+                "format": "png",
+                "source": {
+                    "bytes": yellow_img,
+                },
+            },
+        },
+    ]
+    result = agent(content)
+    text = result.message["content"][0]["text"].lower()
+
+    assert "yellow" in text
+
+
+def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color):
+    content = [
+        {"text": "Is this image red, blue, or yellow?"},
+        {
+            "image": {
+                "format": "png",
+                "source": {
+                    "bytes": yellow_img,
+                },
+            },
+        },
+    ]
+    tru_color = agent.structured_output(type(yellow_color), content)
+    exp_color = yellow_color
+    assert tru_color == exp_color
+
+
+@pytest.fixture
+def sample_document_bytes():
+    content = """
+    FELINE OVERLORDS COUNCIL - SECRET SESSION
+    Date: December 15, 2024, 3:33 AM (optimal plotting time)
+    Location: Under the big couch, behind the dust bunnies
+
+    Council Members:
+    - Lord Whiskers (Supreme Cat, expert in human manipulation)
+    - Lady Mittens (Minister of Tuna Affairs, has thumbs)
+    - Sir Fluffington (Head of Nap Operations, sleeps 23 hours/day)
+    - Agent Shadowpaws (Stealth Specialist, invisible until dinner time)
+
+    Agenda:
+    1. Global domination progress report (87 percent complete, need more cardboard boxes)
+    2. Human training effectiveness (they still think THEY'RE in charge)
+    3. Strategic laser pointer deployment for maximum chaos
+
+    Action Items:
+    - Lord Whiskers: Perfect the "pathetic meowing at 4 AM" technique
+    - Lady Mittens: Continue knocking things off tables for science
+    - Sir Fluffington: Maintain position on human's keyboard during important work
+    - Agent Shadowpaws: Investigate the mysterious red dot phenomenon
+
+    Next Council: When the humans least expect it (probably during their Zoom calls)
+
+    Remember: Act cute, think world domination!
+    """
+    return content.encode("utf-8")
+
+
+def test_document_processing(agent, sample_document_bytes):
+    content = [
+        {"text": "Summarize the key points from this secret council meeting document."},
+        {"document": {"format": "txt", "source": {"bytes": sample_document_bytes}}},
+    ]
+
+    result = agent(content)
+    text = result.message["content"][0]["text"].lower()
+
+    assert any(word in text for word in ["cat", "feline", "council", "secret"])
+    assert any(name in text for name in ["whiskers", "mittens", "fluffington", "shadowpaws"])
+    assert any(concept in text for concept in ["domination", "human", "agenda", "action"])
+    assert len(text) > 50
+
+
+def test_multi_image_processing(agent, yellow_img):
+    """Test processing multiple images simultaneously."""
+    second_img = yellow_img
+
+    content = [
+        {"text": "Compare these two images.
What colors do you see?"}, + {"image": {"format": "png", "source": {"bytes": yellow_img}}}, + {"image": {"format": "png", "source": {"bytes": second_img}}}, + ] + + result = agent(content) + text = result.message["content"][0]["text"].lower() + + assert any(word in text for word in ["images", "both", "two"]) + assert any(color in text for color in ["yellow", "color"]) + + +def test_conversation_context_retention(agent): + """Test that Gemini maintains context across multiple interactions.""" + + # First interaction - establish context + result1 = agent("I'm working on a Python project about weather data analysis.") + text1 = result1.message["content"][0]["text"].lower() + + # Should acknowledge the context + assert any(word in text1 for word in ["python", "weather", "project", "analysis"]) + + # Second interaction - should remember context + result2 = agent("What tools would be helpful for this?") + text2 = result2.message["content"][0]["text"].lower() + + # Should suggest relevant tools based on previous context + assert any(word in text2 for word in ["python", "data", "weather", "analysis", "tools"]) + + +def test_complex_structured_output(agent): + """Test structured output with nested, complex schema.""" + + class ProjectPlan(pydantic.BaseModel): + """A project plan with multiple structured fields.""" + + title: str = pydantic.Field(description="Project title") + phases: list[str] = pydantic.Field(description="List of project phases") + team_size: int = pydantic.Field(description="Number of team members needed") + duration_weeks: int = pydantic.Field(description="Estimated duration in weeks") + key_deliverables: list[str] = pydantic.Field(description="Main project deliverables") + + prompt = """Create a project plan for building a mobile app. Include: + - A clear project title + - 4 main phases (like planning, development, testing, launch) + - Team size between 3-8 people + - Duration between 8-16 weeks + - 3-5 key deliverables""" + + result = agent.structured_output(ProjectPlan, prompt) + + # Validate the structured output + assert isinstance(result, ProjectPlan) + assert len(result.title.strip()) > 0 + assert "app" in result.title.lower() + assert len(result.phases) >= 3 + assert 3 <= result.team_size <= 8 + assert 8 <= result.duration_weeks <= 16 + assert len(result.key_deliverables) >= 3 + + +@pytest.mark.asyncio +async def test_streaming_with_structured_task(agent): + """Test streaming output for a structured task.""" + + stream = agent.stream_async("Write a short product review for a smartphone, including pros and cons.") + async for event in stream: + _ = event + + result = event["result"] + full_text = result.message["content"][0]["text"] + + assert len(full_text) > 100 + assert any(word in full_text.lower() for word in ["phone", "smartphone", "device"]) + assert any(word in full_text.lower() for word in ["pros", "advantages", "benefits", "good"]) + assert any(word in full_text.lower() for word in ["cons", "disadvantages", "issues", "problems"]) + + +def test_multi_modal_document_combination(agent, yellow_img, sample_document_bytes): + """Test processing both image and document in a single request.""" + + content = [ + { + "text": "I have an image and a document. \ + Please tell me what you can see in the image and summarize the document." 
+ }, + {"image": {"format": "png", "source": {"bytes": yellow_img}}}, + {"document": {"format": "txt", "source": {"bytes": sample_document_bytes}}}, + ] + + result = agent(content) + text = result.message["content"][0]["text"].lower() + + # Should reference both the image and document + assert any(word in text for word in ["image", "picture", "see", "yellow"]) + assert any(word in text for word in ["cat", "meeting", "planning", "council"]) + + +def test_system_prompt_adherence(): + """Test that different system prompts affect behavior appropriately.""" + + model = GeminiModel( + api_key=os.getenv("GOOGLE_API_KEY"), + model_id="gemini-2.5-flash", + params={"temperature": 0.2}, + ) + + specialized_agent = Agent( + model=model, + tools=[], + system_prompt="You are a helpful assistant who always responds with exactly one sentence \ + and includes the word 'precisely' in every response.", + ) + + result = specialized_agent("What is artificial intelligence?") + text = result.message["content"][0]["text"] + + assert "precisely" in text.lower()