Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions src/celeste/modalities/text/protocols/openresponses/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,18 @@ def _init_request(self, inputs: TextInput) -> dict[str, Any]:
else [inputs.document]
)
for doc in docs:
file_data = build_document_data_url(doc)
content.append(
{
"type": "input_file",
"filename": doc.path.rsplit("/", 1)[-1]
if doc.path
else "document",
"file_data": file_data,
}
)
if doc.url and not doc.data and not doc.path:
content.append({"type": "input_file", "file_url": doc.url})
else:
content.append(
{
"type": "input_file",
"filename": doc.path.rsplit("/", 1)[-1]
if doc.path
else "document",
"file_data": build_document_data_url(doc),
}
)

content.append({"type": "input_text", "text": inputs.prompt or ""})
return {"input": [{"role": "user", "content": content}]}
Expand Down
5 changes: 4 additions & 1 deletion src/celeste/providers/google/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
def build_media_part(artifact: Artifact) -> dict[str, Any]:
"""Convert any media artifact to a Gemini inline_data/file_data part."""
if artifact.url:
return {"file_data": {"file_uri": artifact.url}}
part: dict[str, Any] = {"file_data": {"file_uri": artifact.url}}
if artifact.mime_type:
part["file_data"]["mime_type"] = artifact.mime_type.value
return part
media_bytes = artifact.get_bytes()
b64 = base64.b64encode(media_bytes).decode("utf-8")
mime = artifact.mime_type or detect_mime_type(media_bytes)
Expand Down
Binary file modified tests/integration_tests/text/assets/test_document.pdf
Binary file not shown.
54 changes: 54 additions & 0 deletions tests/unit_tests/test_text_modality_analyze_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,57 @@ def test_mistral_init_request_includes_document_url_block() -> None:
assert content[0]["type"] == "document_url"
assert content[0]["document_url"].startswith("data:application/pdf;base64,")
assert content[-1] == {"type": "text", "text": "Summarize this document"}


def test_openai_init_request_uses_file_url_for_url_document() -> None:
    """Check the OpenAI request built for a document that only carries a URL.

    The first content entry must be an ``input_file`` block whose ``file_url``
    is the document URL, and it must not contain a ``file_data`` key.
    """
    openai_model = Model(
        id="gpt-4o",
        provider=Provider.OPENAI,
        display_name="GPT-4o",
        operations={Modality.TEXT: {Operation.GENERATE, Operation.ANALYZE}},
    )
    openai_client = OpenAITextClient(
        model=openai_model,
        provider=Provider.OPENAI,
        auth=AuthHeader(secret=SecretStr("test")),
    )
    text_input = TextInput(
        prompt="Summarize this document",
        document=DocumentArtifact(url="https://example.com/doc.pdf"),
    )

    payload = openai_client._init_request(text_input)

    # Only the leading block describes the document; the prompt text follows it.
    file_block = payload["input"][0]["content"][0]
    assert file_block["type"] == "input_file"
    assert file_block["file_url"] == "https://example.com/doc.pdf"
    assert "file_data" not in file_block


def test_google_init_request_includes_mime_type_for_url_document() -> None:
    """Check the Google request built for a URL document with an explicit MIME type.

    The first part must contain a ``file_data`` entry whose ``file_uri`` is the
    document URL and whose ``mime_type`` is ``"application/pdf"``.
    """
    gemini_model = Model(
        id="gemini-2.5-pro",
        provider=Provider.GOOGLE,
        display_name="Gemini 2.5 Pro",
        operations={Modality.TEXT: {Operation.GENERATE, Operation.ANALYZE}},
    )
    google_auth = AuthHeader(
        secret=SecretStr("test"), header="x-goog-api-key", prefix=""
    )
    google_client = GoogleTextClient(
        model=gemini_model,
        provider=Provider.GOOGLE,
        auth=google_auth,
    )
    pdf_document = DocumentArtifact(
        url="https://example.com/doc.pdf", mime_type=DocumentMimeType.PDF
    )

    payload = google_client._init_request(
        TextInput(prompt="Summarize this document", document=pdf_document)
    )

    # The document part is expected first in the single user message.
    first_part = payload["contents"][0]["parts"][0]
    assert "file_data" in first_part
    assert first_part["file_data"]["file_uri"] == "https://example.com/doc.pdf"
    assert first_part["file_data"]["mime_type"] == "application/pdf"
Loading