Skip to content

Commit 9c2aca3

Browse files
committed
test: add S3 filtering tests for all model providers
Add unit tests to verify S3 location content blocks are filtered with warning for non-Bedrock providers: - test_gemini.py: tests for Gemini S3 filtering - test_ollama.py: tests for Ollama S3 filtering - test_mistral.py: tests for Mistral S3 filtering - test_llamaapi.py: tests for LlamaAPI S3 filtering - test_llamacpp.py: tests for llama.cpp S3 filtering - test_writer.py: tests for Writer S3 filtering Each test verifies that: 1. Image/document with S3 source is removed from formatted request 2. Text content is preserved 3. Warning message is logged
1 parent e17e8bb commit 9c2aca3

File tree

6 files changed

+358
-0
lines changed

6 files changed

+358
-0
lines changed

tests/strands/models/test_gemini.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,3 +878,60 @@ def test_init_with_both_client_and_client_args_raises_error():
878878

879879
with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"):
880880
GeminiModel(client=mock_client, client_args={"api_key": "test"}, model_id="test-model")
881+
882+
883+
def test_format_request_filters_s3_source_image(model, caplog):
884+
"""Test that images with S3 sources are filtered out with warning."""
885+
caplog.set_level(logging.WARNING, logger="strands.models.gemini")
886+
887+
messages = [
888+
{
889+
"role": "user",
890+
"content": [
891+
{"text": "look at this image"},
892+
{
893+
"image": {
894+
"format": "png",
895+
"source": {"s3Location": {"uri": "s3://my-bucket/image.png"}},
896+
},
897+
},
898+
],
899+
},
900+
]
901+
902+
request = model.format_request(messages)
903+
904+
# Image with S3 source should be filtered, text should remain
905+
formatted_content = request["contents"][0]["parts"]
906+
assert len(formatted_content) == 1
907+
assert "text" in formatted_content[0]
908+
assert "S3 sources are not supported by Gemini" in caplog.text
909+
910+
911+
def test_format_request_filters_s3_source_document(model, caplog):
912+
"""Test that documents with S3 sources are filtered out with warning."""
913+
caplog.set_level(logging.WARNING, logger="strands.models.gemini")
914+
915+
messages = [
916+
{
917+
"role": "user",
918+
"content": [
919+
{"text": "analyze this document"},
920+
{
921+
"document": {
922+
"format": "pdf",
923+
"name": "report.pdf",
924+
"source": {"s3Location": {"uri": "s3://my-bucket/report.pdf"}},
925+
},
926+
},
927+
],
928+
},
929+
]
930+
931+
request = model.format_request(messages)
932+
933+
# Document with S3 source should be filtered, text should remain
934+
formatted_content = request["contents"][0]["parts"]
935+
assert len(formatted_content) == 1
936+
assert "text" in formatted_content[0]
937+
assert "S3 sources are not supported by Gemini" in caplog.text

tests/strands/models/test_llamaapi.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates
2+
import logging
23
import unittest.mock
34

45
import pytest
@@ -414,3 +415,62 @@ async def test_tool_choice_none_no_warning(model, messages, captured_warnings, a
414415
await alist(response)
415416

416417
assert len(captured_warnings) == 0
418+
419+
420+
def test_format_request_filters_s3_source_image(model, caplog):
421+
"""Test that images with S3 sources are filtered out with warning."""
422+
caplog.set_level(logging.WARNING, logger="strands.models.llamaapi")
423+
424+
messages = [
425+
{
426+
"role": "user",
427+
"content": [
428+
{"text": "look at this image"},
429+
{
430+
"image": {
431+
"format": "png",
432+
"source": {"s3Location": {"uri": "s3://my-bucket/image.png"}},
433+
},
434+
},
435+
],
436+
},
437+
]
438+
439+
request = model.format_request(messages)
440+
441+
# Image with S3 source should be filtered, text should remain
442+
formatted_messages = request["messages"]
443+
user_content = formatted_messages[0]["content"]
444+
assert len(user_content) == 1
445+
assert user_content[0]["type"] == "text"
446+
assert "S3 sources are not supported by LlamaAPI" in caplog.text
447+
448+
449+
def test_format_request_filters_s3_source_document(model, caplog):
450+
"""Test that documents with S3 sources are filtered out with warning."""
451+
caplog.set_level(logging.WARNING, logger="strands.models.llamaapi")
452+
453+
messages = [
454+
{
455+
"role": "user",
456+
"content": [
457+
{"text": "analyze this document"},
458+
{
459+
"document": {
460+
"format": "pdf",
461+
"name": "report.pdf",
462+
"source": {"s3Location": {"uri": "s3://my-bucket/report.pdf"}},
463+
},
464+
},
465+
],
466+
},
467+
]
468+
469+
request = model.format_request(messages)
470+
471+
# Document with S3 source should be filtered, text should remain
472+
formatted_messages = request["messages"]
473+
user_content = formatted_messages[0]["content"]
474+
assert len(user_content) == 1
475+
assert user_content[0]["type"] == "text"
476+
assert "S3 sources are not supported by LlamaAPI" in caplog.text

tests/strands/models/test_llamacpp.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import base64
44
import json
5+
import logging
56
from unittest.mock import AsyncMock, patch
67

78
import httpx
@@ -637,3 +638,64 @@ def test_format_messages_with_mixed_content() -> None:
637638
assert result[0]["content"][2]["type"] == "image_url"
638639
assert "image_url" in result[0]["content"][2]
639640
assert result[0]["content"][2]["image_url"]["url"].startswith("data:image/jpeg;base64,")
641+
642+
643+
def test_format_request_filters_s3_source_image(caplog) -> None:
644+
"""Test that images with S3 sources are filtered out with warning."""
645+
model = LlamaCppModel()
646+
caplog.set_level(logging.WARNING, logger="strands.models.llamacpp")
647+
648+
messages = [
649+
{
650+
"role": "user",
651+
"content": [
652+
{"text": "look at this image"},
653+
{
654+
"image": {
655+
"format": "png",
656+
"source": {"s3Location": {"uri": "s3://my-bucket/image.png"}},
657+
},
658+
},
659+
],
660+
},
661+
]
662+
663+
request = model._format_request(messages)
664+
665+
# Image with S3 source should be filtered, text should remain
666+
formatted_messages = request["messages"]
667+
user_content = formatted_messages[0]["content"]
668+
assert len(user_content) == 1
669+
assert user_content[0]["type"] == "text"
670+
assert "S3 sources are not supported by llama.cpp" in caplog.text
671+
672+
673+
def test_format_request_filters_s3_source_document(caplog) -> None:
674+
"""Test that documents with S3 sources are filtered out with warning."""
675+
model = LlamaCppModel()
676+
caplog.set_level(logging.WARNING, logger="strands.models.llamacpp")
677+
678+
messages = [
679+
{
680+
"role": "user",
681+
"content": [
682+
{"text": "analyze this document"},
683+
{
684+
"document": {
685+
"format": "pdf",
686+
"name": "report.pdf",
687+
"source": {"s3Location": {"uri": "s3://my-bucket/report.pdf"}},
688+
},
689+
},
690+
],
691+
},
692+
]
693+
694+
request = model._format_request(messages)
695+
696+
# Document with S3 source should be filtered, text should remain
697+
formatted_messages = request["messages"]
698+
user_content = formatted_messages[0]["content"]
699+
assert len(user_content) == 1
700+
assert user_content[0]["type"] == "text"
701+
assert "S3 sources are not supported by llama.cpp" in caplog.text

tests/strands/models/test_mistral.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import unittest.mock
23

34
import pydantic
@@ -592,3 +593,62 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings
592593
assert len(captured_warnings) == 1
593594
assert "Invalid configuration parameters" in str(captured_warnings[0].message)
594595
assert "wrong_param" in str(captured_warnings[0].message)
596+
597+
598+
def test_format_request_filters_s3_source_image(model, caplog):
599+
"""Test that images with S3 sources are filtered out with warning."""
600+
caplog.set_level(logging.WARNING, logger="strands.models.mistral")
601+
602+
messages = [
603+
{
604+
"role": "user",
605+
"content": [
606+
{"text": "look at this image"},
607+
{
608+
"image": {
609+
"format": "png",
610+
"source": {"s3Location": {"uri": "s3://my-bucket/image.png"}},
611+
},
612+
},
613+
],
614+
},
615+
]
616+
617+
request = model.format_request(messages)
618+
619+
# Image with S3 source should be filtered, text should remain
620+
formatted_messages = request["messages"]
621+
user_content = formatted_messages[0]["content"]
622+
assert len(user_content) == 1
623+
assert user_content[0]["type"] == "text"
624+
assert "S3 sources are not supported by Mistral" in caplog.text
625+
626+
627+
def test_format_request_filters_s3_source_document(model, caplog):
628+
"""Test that documents with S3 sources are filtered out with warning."""
629+
caplog.set_level(logging.WARNING, logger="strands.models.mistral")
630+
631+
messages = [
632+
{
633+
"role": "user",
634+
"content": [
635+
{"text": "analyze this document"},
636+
{
637+
"document": {
638+
"format": "pdf",
639+
"name": "report.pdf",
640+
"source": {"s3Location": {"uri": "s3://my-bucket/report.pdf"}},
641+
},
642+
},
643+
],
644+
},
645+
]
646+
647+
request = model.format_request(messages)
648+
649+
# Document with S3 source should be filtered, text should remain
650+
formatted_messages = request["messages"]
651+
user_content = formatted_messages[0]["content"]
652+
assert len(user_content) == 1
653+
assert user_content[0]["type"] == "text"
654+
assert "S3 sources are not supported by Mistral" in caplog.text

tests/strands/models/test_ollama.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import logging
23
import unittest.mock
34

45
import pydantic
@@ -559,3 +560,61 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings
559560
assert len(captured_warnings) == 1
560561
assert "Invalid configuration parameters" in str(captured_warnings[0].message)
561562
assert "wrong_param" in str(captured_warnings[0].message)
563+
564+
565+
def test_format_request_filters_s3_source_image(model, caplog):
566+
"""Test that images with S3 sources are filtered out with warning."""
567+
caplog.set_level(logging.WARNING, logger="strands.models.ollama")
568+
569+
messages = [
570+
{
571+
"role": "user",
572+
"content": [
573+
{"text": "look at this image"},
574+
{
575+
"image": {
576+
"format": "png",
577+
"source": {"s3Location": {"uri": "s3://my-bucket/image.png"}},
578+
},
579+
},
580+
],
581+
},
582+
]
583+
584+
request = model.format_request(messages)
585+
586+
# Image with S3 source should be filtered, text should remain
587+
formatted_messages = request["messages"]
588+
user_message = formatted_messages[0]
589+
assert user_message["content"] == "look at this image"
590+
assert "images" not in user_message or user_message.get("images") == []
591+
assert "S3 sources are not supported by Ollama" in caplog.text
592+
593+
594+
def test_format_request_filters_s3_source_document(model, caplog):
595+
"""Test that documents with S3 sources are filtered out with warning."""
596+
caplog.set_level(logging.WARNING, logger="strands.models.ollama")
597+
598+
messages = [
599+
{
600+
"role": "user",
601+
"content": [
602+
{"text": "analyze this document"},
603+
{
604+
"document": {
605+
"format": "pdf",
606+
"name": "report.pdf",
607+
"source": {"s3Location": {"uri": "s3://my-bucket/report.pdf"}},
608+
},
609+
},
610+
],
611+
},
612+
]
613+
614+
request = model.format_request(messages)
615+
616+
# Document with S3 source should be filtered, text should remain
617+
formatted_messages = request["messages"]
618+
user_message = formatted_messages[0]
619+
assert user_message["content"] == "analyze this document"
620+
assert "S3 sources are not supported by Ollama" in caplog.text

0 commit comments

Comments
 (0)