Skip to content

Commit 85e5629

Browse files
committed
add: multimodal processing
1 parent 0e8404b commit 85e5629

File tree

16 files changed

+1036
-13992
lines changed

16 files changed

+1036
-13992
lines changed

scripts/evals/locomo/PROMPT/eval

Lines changed: 0 additions & 23 deletions
This file was deleted.

scripts/evals/locomo/result.json

Lines changed: 0 additions & 13917 deletions
This file was deleted.

src/memu/app/service.py

Lines changed: 347 additions & 14 deletions
Large diffs are not rendered by default.

src/memu/llm/backends/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,18 @@ def build_summary_payload(
1717

1818
def parse_summary_response(self, data: dict[str, Any]) -> str:
    """Extract the summary text from a raw chat-completion response body.

    Args:
        data: Decoded JSON response from the backend's chat endpoint.

    Raises:
        NotImplementedError: Always; backend subclasses must override.
    """
    raise NotImplementedError
20+
21+
def build_vision_payload(
    self,
    *,
    prompt: str,
    base64_image: str,
    mime_type: str,
    system_prompt: str | None,
    chat_model: str,
    max_tokens: int | None,
) -> dict[str, Any]:
    """Build the request body for a vision (image + text) chat call.

    Args:
        prompt: Text prompt to pair with the image.
        base64_image: Base64-encoded image bytes (no data-URL prefix).
        mime_type: MIME type of the image, e.g. "image/png".
        system_prompt: Optional system message, or None to omit it.
        chat_model: Name of the chat model to request.
        max_tokens: Response token cap, or None for the server default.

    Raises:
        NotImplementedError: Always; backend subclasses must override.
    """
    raise NotImplementedError
2032

2133
def build_embedding_payload(self, *, inputs: list[str], embed_model: str) -> dict[str, Any]:
    """Build the request body for an embeddings call.

    Args:
        inputs: Texts to embed.
        embed_model: Name of the embedding model to request.

    Raises:
        NotImplementedError: Always; backend subclasses must override.
    """
    raise NotImplementedError

src/memu/llm/backends/openai.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,41 @@ def build_summary_payload(
2626

2727
def parse_summary_response(self, data: dict[str, Any]) -> str:
    """Pull the assistant message text out of a chat-completion response."""
    first_choice = data["choices"][0]
    return cast(str, first_choice["message"]["content"])
29+
30+
def build_vision_payload(
31+
self,
32+
*,
33+
prompt: str,
34+
base64_image: str,
35+
mime_type: str,
36+
system_prompt: str | None,
37+
chat_model: str,
38+
max_tokens: int | None,
39+
) -> dict[str, Any]:
40+
"""Build payload for OpenAI Vision API."""
41+
messages: list[dict[str, Any]] = []
42+
if system_prompt:
43+
messages.append({"role": "system", "content": system_prompt})
44+
45+
messages.append({
46+
"role": "user",
47+
"content": [
48+
{"type": "text", "text": prompt},
49+
{
50+
"type": "image_url",
51+
"image_url": {
52+
"url": f"data:{mime_type};base64,{base64_image}",
53+
},
54+
},
55+
],
56+
})
57+
58+
return {
59+
"model": chat_model,
60+
"messages": messages,
61+
"temperature": 0.2,
62+
"max_tokens": max_tokens,
63+
}
2964

3065
def build_embedding_payload(self, *, inputs: list[str], embed_model: str) -> dict[str, Any]:
    """Build the request body for the OpenAI embeddings endpoint."""
    payload: dict[str, Any] = {"model": embed_model, "input": inputs}
    return payload

src/memu/llm/http_client.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
import base64
34
import logging
45
from collections.abc import Callable
6+
from pathlib import Path
57

68
import httpx
79

@@ -53,6 +55,56 @@ async def summarize(self, text: str, max_tokens: int | None = None, system_promp
5355
data = resp.json()
5456
logger.debug("HTTP LLM summarize response: %s", data)
5557
return self.backend.parse_summary_response(data)
58+
59+
async def vision(
    self,
    prompt: str,
    image_path: str,
    *,
    max_tokens: int | None = None,
    system_prompt: str | None = None,
) -> str:
    """
    Send an image plus a text prompt to the backend's Vision API.

    Args:
        prompt: Text prompt to send with the image
        image_path: Path to the image file
        max_tokens: Maximum tokens in response
        system_prompt: Optional system prompt

    Returns:
        LLM response text
    """
    path = Path(image_path)
    encoded = base64.b64encode(path.read_bytes()).decode("utf-8")

    # Map the file extension to a MIME type; unknown extensions fall back to JPEG.
    mime_by_suffix = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }
    mime_type = mime_by_suffix.get(path.suffix.lower(), "image/jpeg")

    payload = self.backend.build_vision_payload(
        prompt=prompt,
        base64_image=encoded,
        mime_type=mime_type,
        system_prompt=system_prompt,
        chat_model=self.chat_model,
        max_tokens=max_tokens,
    )

    # Vision requests go through the same chat endpoint as summaries, so the
    # summary response parser applies.
    async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout) as client:
        resp = await client.post(self.summary_endpoint, json=payload, headers=self._headers())
        resp.raise_for_status()
        data = resp.json()
    logger.debug("HTTP LLM vision response: %s", data)
    return self.backend.parse_summary_response(data)
56108

57109
async def embed(self, inputs: list[str]) -> list[list[float]]:
58110
payload = self.backend.build_embedding_payload(inputs=inputs, embed_model=self.embed_model)
@@ -62,6 +114,61 @@ async def embed(self, inputs: list[str]) -> list[list[float]]:
62114
data = resp.json()
63115
logger.debug("HTTP LLM embedding response: %s", data)
64116
return self.backend.parse_embedding_response(data)
117+
118+
async def transcribe(
    self,
    audio_path: str,
    *,
    prompt: str | None = None,
    language: str | None = None,
    response_format: str = "text",
) -> str:
    """
    Transcribe audio file using OpenAI Audio API.

    Args:
        audio_path: Path to the audio file
        prompt: Optional prompt to guide the transcription
        language: Optional language code (e.g., 'en', 'zh')
        response_format: Response format ('text', 'json', 'verbose_json')

    Returns:
        Transcribed text
    """
    try:
        # Prepare multipart form data; keep the file handle open for the
        # duration of the upload.
        with open(audio_path, "rb") as audio_file:
            files = {"file": (Path(audio_path).name, audio_file, "application/octet-stream")}
            form: dict[str, str] = {
                "model": "gpt-4o-mini-transcribe",
                "response_format": response_format,
            }
            if prompt:
                form["prompt"] = prompt
            if language:
                form["language"] = language

            # Transcription is slower than chat; allow a longer timeout.
            async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout * 3) as client:
                resp = await client.post(
                    "/v1/audio/transcriptions",
                    files=files,
                    data=form,
                    headers=self._headers(),
                )
                resp.raise_for_status()

        if response_format == "text":
            result = resp.text
        else:
            result_data = resp.json()
            result = result_data.get("text", "")

        logger.debug("HTTP audio transcribe response for %s: %s chars", audio_path, len(result))
        return result or ""

    except Exception:
        # Fix: logger.exception preserves the traceback; the previous
        # logger.error("...: %s", e) dropped it before re-raising.
        logger.exception("Audio transcription failed for %s", audio_path)
        raise
65172

66173
def _headers(self) -> dict[str, str]:
    """Return the bearer-token authorization header for every request."""
    return {"Authorization": f"Bearer {self.api_key}"}

src/memu/llm/openai_sdk.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import base64
12
import logging
3+
from pathlib import Path
24
from typing import cast
35

46
from openai import AsyncOpenAI
@@ -38,6 +40,112 @@ async def summarize(
3840
logger.debug("OpenAI summarize response: %s", response)
3941
return content or ""
4042

43+
async def vision(
    self,
    prompt: str,
    image_path: str,
    *,
    max_tokens: int | None = None,
    system_prompt: str | None = None,
) -> str:
    """
    Ask the OpenAI Vision API about a single image file.

    Args:
        prompt: Text prompt to send with the image
        image_path: Path to the image file
        max_tokens: Maximum tokens in response
        system_prompt: Optional system prompt

    Returns:
        LLM response text
    """
    path = Path(image_path)
    encoded = base64.b64encode(path.read_bytes()).decode("utf-8")

    # Map the file extension to a MIME type; unknown extensions fall back to JPEG.
    mime_by_suffix = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }
    mime_type = mime_by_suffix.get(path.suffix.lower(), "image/jpeg")

    user_message = {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:{mime_type};base64,{encoded}",
                },
            },
        ],
    }
    messages: list[dict] = (
        [{"role": "system", "content": system_prompt}] if system_prompt else []
    )
    messages.append(user_message)

    # NOTE(review): temperature here is 1, but the HTTP backend's vision
    # payload uses 0.2 — confirm which is intended.
    response = await self.client.chat.completions.create(
        model=self.chat_model,
        messages=messages,
        temperature=1,
        max_completion_tokens=max_tokens,
    )
    reply = response.choices[0].message.content
    logger.debug("OpenAI vision response: %s", response)
    return reply or ""
104+
41105
async def embed(self, inputs: list[str]) -> list[list[float]]:
    """Embed the given texts and return one vector per input, in order."""
    response = await self.client.embeddings.create(model=self.embed_model, input=inputs)
    vectors: list[list[float]] = []
    for item in response.data:
        vectors.append(cast(list[float], item.embedding))
    return vectors
108+
109+
async def transcribe(
    self,
    audio_path: str,
    *,
    prompt: str | None = None,
    language: str | None = None,
    response_format: str = "text",
) -> str:
    """
    Transcribe audio file using OpenAI Audio API.

    Args:
        audio_path: Path to the audio file
        prompt: Optional prompt to guide the transcription
        language: Optional language code (e.g., 'en', 'zh')
        response_format: Response format ('text', 'json', 'verbose_json')

    Returns:
        Transcribed text
    """
    # Fix: only forward optional params when set. Passing prompt=None /
    # language=None explicitly overrides the SDK's NOT_GIVEN sentinel and
    # serializes nulls into the request, which some servers reject.
    extra: dict[str, str] = {}
    if prompt is not None:
        extra["prompt"] = prompt
    if language is not None:
        extra["language"] = language

    try:
        with open(audio_path, "rb") as audio_file:
            # Use gpt-4o-mini-transcribe for better performance and cost
            transcription = await self.client.audio.transcriptions.create(
                model="gpt-4o-mini-transcribe",
                file=audio_file,
                response_format=response_format,
                **extra,
            )

        # Handle different response formats: "text" may come back as a plain
        # string; structured formats carry a .text attribute.
        if response_format == "text":
            result = transcription if isinstance(transcription, str) else transcription.text
        else:
            result = transcription.text if hasattr(transcription, "text") else str(transcription)

        logger.debug("OpenAI transcribe response for %s: %s chars", audio_path, len(result))
        return result or ""

    except Exception:
        # Fix: logger.exception keeps the traceback; logger.error("...: %s", e)
        # discarded it before re-raising.
        logger.exception("Audio transcription failed for %s", audio_path)
        raise
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
from memu.prompts.preprocess import conversation
1+
from memu.prompts.preprocess import audio, conversation, document, image, video

# Modality name -> prompt module; each module exposes a PROMPT template string.
_SOURCES = {
    "conversation": conversation,
    "video": video,
    "image": image,
    "document": document,
    "audio": audio,
}

# Public registry of preprocessing prompts, whitespace-stripped.
PROMPTS: dict[str, str] = {name: module.PROMPT.strip() for name, module in _SOURCES.items()}

__all__ = ["PROMPTS"]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
PROMPT = """
2+
Analyze the following audio transcription and provide two outputs:
3+
4+
## Transcription:
5+
<transcription>
6+
{transcription}
7+
</transcription>
8+
9+
## Task:
10+
1. **Processed Content**: Provide a clean, well-formatted version of the transcription with proper punctuation and paragraph breaks if needed
11+
2. **Caption**: Provide a one-sentence summary describing what the audio is about
12+
13+
## Output Format:
14+
<processed_content>
15+
[Provide the cleaned and formatted transcription here]
16+
</processed_content>
17+
18+
<caption>
19+
[Provide a one-sentence summary of what the audio is about]
20+
</caption>
21+
"""

0 commit comments

Comments
 (0)