Skip to content

Commit b154e02

Browse files
sairin1202 and ankaisen authored
feat: add multimodal memory (#79)
Co-authored-by: ankaisen <51148505+ankaisen@users.noreply.github.com>
1 parent 0e8404b commit b154e02

File tree

16 files changed

+1169
-14027
lines changed

16 files changed

+1169
-14027
lines changed

scripts/evals/locomo/PROMPT/eval

Lines changed: 0 additions & 23 deletions
This file was deleted.

scripts/evals/locomo/result.json

Lines changed: 0 additions & 13917 deletions
This file was deleted.

src/memu/app/service.py

Lines changed: 427 additions & 44 deletions
Large diffs are not rendered by default.

src/memu/llm/backends/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,18 @@ def build_summary_payload(
1818
def parse_summary_response(self, data: dict[str, Any]) -> str:
    """Extract the summary text from a provider response body.

    Abstract hook: concrete backends must override.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
2020

21+
def build_vision_payload(
    self,
    *,
    prompt: str,
    base64_image: str,
    mime_type: str,
    system_prompt: str | None,
    chat_model: str,
    max_tokens: int | None,
) -> dict[str, Any]:
    """Build a provider-specific request payload for a vision (image) chat call.

    Abstract hook: concrete backends must override.

    Args:
        prompt: Text prompt to accompany the image.
        base64_image: Image bytes already encoded as a base64 string
            (no ``data:`` URI prefix).
        mime_type: MIME type of the image, e.g. ``image/png``.
        system_prompt: Optional system prompt; ``None`` to omit.
        chat_model: Model identifier to place in the payload.
        max_tokens: Response token cap, or ``None`` for the provider default.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
32+
2133
def build_embedding_payload(self, *, inputs: list[str], embed_model: str) -> dict[str, Any]:
    """Build a provider-specific request payload for an embedding call.

    Abstract hook: concrete backends must override.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
2335

src/memu/llm/backends/openai.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,41 @@ def build_summary_payload(
2727
def parse_summary_response(self, data: dict[str, Any]) -> str:
    """Pull the assistant message text out of a chat-completions response body."""
    first_choice = data["choices"][0]
    return cast(str, first_choice["message"]["content"])
2929

30+
def build_vision_payload(
31+
self,
32+
*,
33+
prompt: str,
34+
base64_image: str,
35+
mime_type: str,
36+
system_prompt: str | None,
37+
chat_model: str,
38+
max_tokens: int | None,
39+
) -> dict[str, Any]:
40+
"""Build payload for OpenAI Vision API."""
41+
messages: list[dict[str, Any]] = []
42+
if system_prompt:
43+
messages.append({"role": "system", "content": system_prompt})
44+
45+
messages.append({
46+
"role": "user",
47+
"content": [
48+
{"type": "text", "text": prompt},
49+
{
50+
"type": "image_url",
51+
"image_url": {
52+
"url": f"data:{mime_type};base64,{base64_image}",
53+
},
54+
},
55+
],
56+
})
57+
58+
return {
59+
"model": chat_model,
60+
"messages": messages,
61+
"temperature": 0.2,
62+
"max_tokens": max_tokens,
63+
}
64+
3065
def build_embedding_payload(self, *, inputs: list[str], embed_model: str) -> dict[str, Any]:
    """Build the request body for the OpenAI embeddings endpoint."""
    body: dict[str, Any] = {"model": embed_model, "input": inputs}
    return body
3267

src/memu/llm/http_client.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3+
import base64
34
import logging
45
from collections.abc import Callable
6+
from pathlib import Path
57

68
import httpx
79

@@ -54,6 +56,56 @@ async def summarize(self, text: str, max_tokens: int | None = None, system_promp
5456
logger.debug("HTTP LLM summarize response: %s", data)
5557
return self.backend.parse_summary_response(data)
5658

59+
async def vision(
    self,
    prompt: str,
    image_path: str,
    *,
    max_tokens: int | None = None,
    system_prompt: str | None = None,
) -> str:
    """
    Send an image plus a text prompt to the vision-capable chat endpoint.

    Args:
        prompt: Text prompt to send with the image
        image_path: Path to the image file
        max_tokens: Maximum tokens in response
        system_prompt: Optional system prompt

    Returns:
        LLM response text
    """
    # Inline the image as base64 so it can travel in the JSON payload.
    path = Path(image_path)
    encoded_image = base64.b64encode(path.read_bytes()).decode("utf-8")

    # Map the file extension to a MIME type; default to JPEG for unknown ones.
    mime_by_suffix = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }
    mime_type = mime_by_suffix.get(path.suffix.lower(), "image/jpeg")

    payload = self.backend.build_vision_payload(
        prompt=prompt,
        base64_image=encoded_image,
        mime_type=mime_type,
        system_prompt=system_prompt,
        chat_model=self.chat_model,
        max_tokens=max_tokens,
    )

    # Vision requests go through the same chat-completions endpoint as summaries.
    async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout) as client:
        response = await client.post(self.summary_endpoint, json=payload, headers=self._headers())
        response.raise_for_status()
        data = response.json()
    logger.debug("HTTP LLM vision response: %s", data)
    return self.backend.parse_summary_response(data)
108+
57109
async def embed(self, inputs: list[str]) -> list[list[float]]:
58110
payload = self.backend.build_embedding_payload(inputs=inputs, embed_model=self.embed_model)
59111
async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout) as client:
@@ -63,6 +115,61 @@ async def embed(self, inputs: list[str]) -> list[list[float]]:
63115
logger.debug("HTTP LLM embedding response: %s", data)
64116
return self.backend.parse_embedding_response(data)
65117

118+
async def transcribe(
    self,
    audio_path: str,
    *,
    prompt: str | None = None,
    language: str | None = None,
    response_format: str = "text",
) -> str:
    """
    Transcribe audio file using OpenAI Audio API.

    Args:
        audio_path: Path to the audio file
        prompt: Optional prompt to guide the transcription
        language: Optional language code (e.g., 'en', 'zh')
        response_format: Response format ('text', 'json', 'verbose_json')

    Returns:
        Transcribed text

    Raises:
        Exception: Any failure (I/O or HTTP) is logged and re-raised.
    """
    try:
        # Prepare multipart form data; the file handle must stay open for the
        # duration of the POST, so the request happens inside this `with`.
        with open(audio_path, "rb") as audio_file:
            files = {"file": (Path(audio_path).name, audio_file, "application/octet-stream")}
            data = {
                # NOTE(review): model is hard-coded rather than taken from config — confirm intended.
                "model": "gpt-4o-mini-transcribe",
                "response_format": response_format,
            }
            if prompt:
                data["prompt"] = prompt
            if language:
                data["language"] = language

            # Audio uploads are slow; allow triple the normal request timeout.
            async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout * 3) as client:
                resp = await client.post(
                    "/v1/audio/transcriptions",
                    files=files,
                    data=data,
                    headers=self._headers(),
                )
                resp.raise_for_status()

                # 'text' responses come back as a plain body; JSON variants
                # carry the transcript under a "text" key.
                if response_format == "text":
                    result = resp.text
                else:
                    result_data = resp.json()
                    result = result_data.get("text", "")

        logger.debug("HTTP audio transcribe response for %s: %s chars", audio_path, len(result))
    except Exception:
        logger.exception("Audio transcription failed for %s", audio_path)
        raise
    else:
        return result or ""
66173
def _headers(self) -> dict[str, str]:
67174
return {"Authorization": f"Bearer {self.api_key}"}
68175

src/memu/llm/openai_sdk.py

Lines changed: 130 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
1+
import base64
12
import logging
2-
from typing import cast
3+
from pathlib import Path
4+
from typing import Any, Literal, cast
35

46
from openai import AsyncOpenAI
7+
from openai.types.chat import (
8+
ChatCompletionContentPartImageParam,
9+
ChatCompletionContentPartTextParam,
10+
ChatCompletionMessageParam,
11+
ChatCompletionSystemMessageParam,
12+
ChatCompletionUserMessageParam,
13+
)
514

615
logger = logging.getLogger(__name__)
716

@@ -25,19 +34,134 @@ async def summarize(
2534
) -> str:
2635
prompt = system_prompt or "Summarize the text in one short paragraph."
2736

37+
system_message: ChatCompletionSystemMessageParam = {"role": "system", "content": prompt}
38+
user_message: ChatCompletionUserMessageParam = {"role": "user", "content": text}
39+
messages: list[ChatCompletionMessageParam] = [system_message, user_message]
40+
2841
response = await self.client.chat.completions.create(
2942
model=self.chat_model,
30-
messages=[
31-
{"role": "system", "content": prompt},
32-
{"role": "user", "content": text},
33-
],
43+
messages=messages,
3444
temperature=1,
35-
max_completion_tokens=max_tokens,
45+
max_tokens=max_tokens,
3646
)
3747
content = response.choices[0].message.content
3848
logger.debug("OpenAI summarize response: %s", response)
3949
return content or ""
4050

51+
async def vision(
    self,
    prompt: str,
    image_path: str,
    *,
    max_tokens: int | None = None,
    system_prompt: str | None = None,
) -> str:
    """
    Call OpenAI Vision API with an image.

    Args:
        prompt: Text prompt to send with the image
        image_path: Path to the image file
        max_tokens: Maximum tokens in response
        system_prompt: Optional system prompt

    Returns:
        LLM response text
    """
    # Inline the image as a base64 data URI for the chat-completions call.
    path = Path(image_path)
    encoded_image = base64.b64encode(path.read_bytes()).decode("utf-8")

    # Map the file extension to a MIME type, defaulting to JPEG.
    mime_type = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }.get(path.suffix.lower(), "image/jpeg")

    text_part: ChatCompletionContentPartTextParam = {"type": "text", "text": prompt}
    image_part: ChatCompletionContentPartImageParam = {
        "type": "image_url",
        "image_url": {
            "url": f"data:{mime_type};base64,{encoded_image}",
        },
    }
    user_message: ChatCompletionUserMessageParam = {
        "role": "user",
        "content": [text_part, image_part],
    }

    # Prepend the optional system prompt ahead of the user turn.
    messages: list[ChatCompletionMessageParam]
    if system_prompt:
        system_message: ChatCompletionSystemMessageParam = {
            "role": "system",
            "content": system_prompt,
        }
        messages = [system_message, user_message]
    else:
        messages = [user_message]

    response = await self.client.chat.completions.create(
        model=self.chat_model,
        messages=messages,
        temperature=1,
        max_tokens=max_tokens,
    )
    logger.debug("OpenAI vision response: %s", response)
    return response.choices[0].message.content or ""
41117
async def embed(self, inputs: list[str]) -> list[list[float]]:
42118
response = await self.client.embeddings.create(model=self.embed_model, input=inputs)
43119
return [cast(list[float], d.embedding) for d in response.data]
120+
121+
async def transcribe(
    self,
    audio_path: str,
    *,
    prompt: str | None = None,
    language: str | None = None,
    response_format: Literal["text", "json", "verbose_json"] = "text",
) -> str:
    """
    Transcribe audio file using OpenAI Audio API.

    Args:
        audio_path: Path to the audio file
        prompt: Optional prompt to guide the transcription
        language: Optional language code (e.g., 'en', 'zh')
        response_format: Response format ('text', 'json', 'verbose_json')

    Returns:
        Transcribed text
    """
    try:
        # Forward only the optional parameters the caller actually supplied.
        extra: dict[str, Any] = {}
        if prompt is not None:
            extra["prompt"] = prompt
        if language is not None:
            extra["language"] = language

        with open(audio_path, "rb") as audio_stream:
            transcription = await self.client.audio.transcriptions.create(
                file=audio_stream,
                model="gpt-4o-mini-transcribe",
                response_format=response_format,
                **extra,
            )

        # 'text' format may come back as a bare string; JSON formats carry
        # the transcript on a `.text` attribute.
        if response_format == "text":
            text = transcription if isinstance(transcription, str) else transcription.text
        elif hasattr(transcription, "text"):
            text = transcription.text
        else:
            text = str(transcription)

        logger.debug("OpenAI transcribe response for %s: %s chars", audio_path, len(text))
    except Exception:
        logger.exception("Audio transcription failed for %s", audio_path)
        raise
    else:
        return text or ""
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
from memu.prompts.preprocess import conversation
1+
from memu.prompts.preprocess import audio, conversation, document, image, video

# Modality name -> module holding its preprocessing prompt; insertion order
# deliberately matches the original mapping.
_PROMPT_MODULES = {
    "conversation": conversation,
    "video": video,
    "image": image,
    "document": document,
    "audio": audio,
}

# Public registry of whitespace-trimmed preprocessing prompts per modality.
PROMPTS: dict[str, str] = {name: module.PROMPT.strip() for name, module in _PROMPT_MODULES.items()}

__all__ = ["PROMPTS"]

0 commit comments

Comments
 (0)