Skip to content

Commit ef6c299

Browse files
committed
feat: add memorize and retrieve methods
1 parent a9095fa commit ef6c299

File tree

24 files changed

+1288
-0
lines changed

24 files changed

+1288
-0
lines changed

src/memu/app/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .service import MemoryService
2+
from .settings import AppSettings
3+
4+
__all__ = ["AppSettings", "MemoryService"]

src/memu/app/service.py

Lines changed: 480 additions & 0 deletions
Large diffs are not rendered by default.

src/memu/app/settings.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from pydantic import BaseModel, Field
2+
3+
from memu.prompts.memory_type import DEFAULT_MEMORY_TYPES
4+
from memu.prompts.memory_type import PROMPTS as DEFAULT_MEMORY_TYPE_PROMPTS
5+
6+
7+
def _default_memory_types() -> list[str]:
    """Return a fresh, mutable copy of the built-in memory-type ordering."""
    return [*DEFAULT_MEMORY_TYPES]
9+
10+
11+
def _default_memory_type_prompts() -> dict[str, str]:
    """Return a fresh, mutable copy of the built-in per-type prompt mapping."""
    return {**DEFAULT_MEMORY_TYPE_PROMPTS}
13+
14+
15+
def _default_memory_categories() -> list[dict[str, str]]:
16+
return [
17+
{"name": "personal_info", "description": "Personal information about the user"},
18+
{"name": "preferences", "description": "User preferences, likes and dislikes"},
19+
{"name": "relationships", "description": "Information about relationships with others"},
20+
{"name": "activities", "description": "Activities, hobbies, and interests"},
21+
{"name": "goals", "description": "Goals, aspirations, and objectives"},
22+
{"name": "experiences", "description": "Past experiences and events"},
23+
{"name": "knowledge", "description": "Knowledge, facts, and learned information"},
24+
{"name": "opinions", "description": "Opinions, viewpoints, and perspectives"},
25+
{"name": "habits", "description": "Habits, routines, and patterns"},
26+
{"name": "work_life", "description": "Work-related information and professional life"},
27+
]
28+
29+
30+
class AppSettings(BaseModel):
    """Application-wide settings for the memory service (pydantic model).

    Every field carries a default, so ``AppSettings()`` is usable out of the
    box; deployments override individual fields as needed.
    """

    # where to store raw resources on the local filesystem
    resources_dir: str = Field(default="./resources")
    # OpenAI-compatible API base URL, and the NAME of the env var that holds the key
    openai_base: str = Field(default="https://api.openai.com/v1")
    openai_api_key_env: str = Field(default="OPENAI_API_KEY")
    # models: chat model for summarization, embedding model for vectors
    chat_model: str = Field(default="gpt-4o-mini")
    embed_model: str = Field(default="text-embedding-3-small")
    llm_client_backend: str = Field(
        default="httpx",
        description="Which OpenAI client backend to use: 'httpx' (httpx) or 'sdk' (official OpenAI).",
    )
    llm_http_provider: str = Field(
        default="openai",
        description="Name of the HTTP LLM provider implementation (e.g. 'openai').",
    )
    llm_http_endpoints: dict[str, str] = Field(
        default_factory=dict,
        description="Optional overrides for HTTP endpoints (keys: 'chat'/'summary', 'embeddings'/'embed').",
    )
    # thresholds
    # NOTE(review): presumably a minimum similarity score for assigning items to
    # categories — confirm metric and comparison direction in the service code.
    category_assign_threshold: float = Field(default=0.25)
    # summarization prompts: a global fallback plus optional per-modality overrides
    default_summary_prompt: str = Field(default="Summarize the text in one short paragraph.")
    summary_prompts: dict[str, str] = Field(
        default_factory=dict,
        description="Optional mapping of modality -> summary system prompt.",
    )
    memory_categories: list[dict[str, str]] = Field(
        default_factory=_default_memory_categories,
        description="Global memory category definitions embedded at service startup.",
    )
    category_summary_target_length: int = Field(
        default=400,
        description="Target max length for auto-generated category summaries.",
    )
    memory_types: list[str] = Field(
        default_factory=_default_memory_types,
        description="Ordered list of memory types (profile/event/knowledge/behavior by default).",
    )
    memory_type_prompts: dict[str, str] = Field(
        default_factory=_default_memory_type_prompts,
        description="System prompt overrides for each memory type extraction.",
    )

src/memu/llm/backends/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Re-export common backends for convenience.
2+
from .base import HTTPBackend
3+
from .openai import OpenAIHTTPBackend
4+
5+
__all__ = ["HTTPBackend", "OpenAIHTTPBackend"]

src/memu/llm/backends/base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
6+
class HTTPBackend:
    """Protocol-style base class describing how to talk to one HTTP LLM provider.

    Subclasses override the class attributes and the four payload/response
    hooks; the HTTP transport itself lives in the client, not here.
    """

    # Provider identifier used for registry lookup.
    name: str = "base"
    # Relative endpoint paths appended to the client's base URL.
    summary_endpoint: str = "/chat/completions"
    embedding_endpoint: str = "/embeddings"

    def build_summary_payload(
        self, *, text: str, system_prompt: str | None, chat_model: str, max_tokens: int
    ) -> dict[str, Any]:
        """Build the JSON body for a summarization request."""
        raise NotImplementedError

    def parse_summary_response(self, data: dict[str, Any]) -> str:
        """Extract the summary text from a provider's JSON response."""
        raise NotImplementedError

    def build_embedding_payload(self, *, inputs: list[str], embed_model: str) -> dict[str, Any]:
        """Build the JSON body for an embeddings request."""
        raise NotImplementedError

    def parse_embedding_response(self, data: dict[str, Any]) -> list[list[float]]:
        """Extract embedding vectors from a provider's JSON response."""
        raise NotImplementedError

src/memu/llm/backends/openai.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, cast
4+
5+
from .base import HTTPBackend
6+
7+
8+
class OpenAIHTTPBackend(HTTPBackend):
    """HTTPBackend speaking the OpenAI REST dialect (chat completions + embeddings)."""

    name = "openai"
    summary_endpoint = "/chat/completions"
    embedding_endpoint = "/embeddings"

    def build_summary_payload(
        self, *, text: str, system_prompt: str | None, chat_model: str, max_tokens: int
    ) -> dict[str, Any]:
        """Assemble a chat-completions request that asks the model to summarize *text*."""
        system = system_prompt or "Summarize the text in one short paragraph."
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": text},
        ]
        return {
            "model": chat_model,
            "messages": messages,
            "temperature": 0.2,
            "max_tokens": max_tokens,
        }

    def parse_summary_response(self, data: dict[str, Any]) -> str:
        """Pull the assistant message text out of the first choice."""
        first_choice = data["choices"][0]
        return cast(str, first_choice["message"]["content"])

    def build_embedding_payload(self, *, inputs: list[str], embed_model: str) -> dict[str, Any]:
        """Assemble an embeddings request for a batch of input strings."""
        return {"model": embed_model, "input": inputs}

    def parse_embedding_response(self, data: dict[str, Any]) -> list[list[float]]:
        """Collect embedding vectors in the order the API returned them."""
        return [cast(list[float], row["embedding"]) for row in data["data"]]

src/memu/llm/http_client.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from collections.abc import Callable
5+
from typing import cast
6+
7+
import httpx
8+
import numpy as np
9+
10+
from memu.llm.backends.base import HTTPBackend
11+
from memu.llm.backends.openai import OpenAIHTTPBackend
12+
13+
# Registry mapping provider name -> zero-arg backend factory; add new
# providers here so HTTPLLMClient._load_backend can find them.
HTTP_BACKENDS: dict[str, Callable[[], HTTPBackend]] = {
    OpenAIHTTPBackend.name: OpenAIHTTPBackend,
}
16+
17+
18+
class HTTPLLMClient:
    """LLM client that talks to a provider over raw HTTP using httpx.

    Provider-specific payload shapes live in an HTTPBackend; this class owns
    transport, auth headers, and an offline "fake" mode that is enabled when
    MEMUFLOW_FAKE_OPENAI is set or no API key is configured.
    """

    def __init__(
        self,
        *,
        base_url: str,
        api_key: str,
        chat_model: str,
        embed_model: str,
        provider: str = "openai",
        endpoint_overrides: dict[str, str] | None = None,
        timeout: int = 60,
    ):
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key or ""
        self.chat_model = chat_model
        self.embed_model = embed_model
        self.provider = provider.lower()
        self.backend = self._load_backend(self.provider)
        overrides = endpoint_overrides or {}
        # Each endpoint accepts a couple of alias keys; first truthy value wins,
        # otherwise the backend's default path is used.
        self.summary_endpoint = next(
            (overrides[k] for k in ("chat", "summary") if overrides.get(k)),
            self.backend.summary_endpoint,
        )
        self.embedding_endpoint = next(
            (overrides[k] for k in ("embeddings", "embedding", "embed") if overrides.get(k)),
            self.backend.embedding_endpoint,
        )
        # Offline mode: forced via env var, or implied by a missing API key.
        self.fake = bool(os.getenv("MEMUFLOW_FAKE_OPENAI")) or not bool(self.api_key)
        self.timeout = timeout

    async def summarize(self, text: str, max_tokens: int = 160, system_prompt: str | None = None) -> str:
        """Summarize *text* via the provider's chat endpoint (offline stub in fake mode)."""
        if self.fake:
            collapsed = " ".join(text.strip().split())
            return collapsed[:200] + ("..." if len(collapsed) > 200 else "")

        payload = self.backend.build_summary_payload(
            text=text, system_prompt=system_prompt, chat_model=self.chat_model, max_tokens=max_tokens
        )
        data = await self._post(self.summary_endpoint, payload)
        return self.backend.parse_summary_response(data)

    async def embed(self, inputs: list[str]) -> list[list[float]]:
        """Embed each input string (deterministic pseudo-vectors in fake mode)."""
        if self.fake:
            return [self._fake_vec(text) for text in inputs]
        payload = self.backend.build_embedding_payload(inputs=inputs, embed_model=self.embed_model)
        data = await self._post(self.embedding_endpoint, payload)
        return self.backend.parse_embedding_response(data)

    async def _post(self, endpoint: str, payload: dict[str, Any]) -> dict[str, Any]:
        """POST *payload* as JSON to *endpoint* under base_url; raise on HTTP errors."""
        async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout) as client:
            resp = await client.post(endpoint, json=payload, headers=self._headers())
            resp.raise_for_status()
            return resp.json()

    def _headers(self) -> dict[str, str]:
        """Bearer-token auth header attached to every request."""
        return {"Authorization": f"Bearer {self.api_key}"}

    def _fake_vec(self, s: str, dim: int = 256) -> list[float]:
        """Deterministic, unit-norm pseudo-embedding derived from sha256(s)."""
        import hashlib

        digest = hashlib.sha256(s.encode("utf-8")).digest()
        raw = (digest * (dim // len(digest) + 1))[:dim]
        vec = np.frombuffer(raw, dtype=np.uint8).astype(np.float32)
        vec = (vec - vec.mean()) / (vec.std() + 1e-6)
        vec = vec / (np.linalg.norm(vec) + 1e-9)
        return cast(list[float], vec.tolist())

    def _load_backend(self, provider: str) -> HTTPBackend:
        """Look up and instantiate the backend for *provider*; ValueError if unknown."""
        factory = HTTP_BACKENDS.get(provider)
        if factory is None:
            msg = f"Unsupported HTTP LLM provider '{provider}'. Available: {', '.join(HTTP_BACKENDS.keys())}"
            raise ValueError(msg)
        return factory()

src/memu/llm/openai_sdk.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import os
2+
from typing import TYPE_CHECKING, cast
3+
4+
import numpy as np
5+
6+
if TYPE_CHECKING:
    from openai import AsyncOpenAI  # for static type checking only; not imported at runtime

try:
    import openai
except ImportError:
    openai = None  # sentinel used at runtime to detect whether the SDK is installed
13+
14+
15+
class OpenAISDKClient:
16+
"""OpenAI client that relies on the official Python SDK."""
17+
18+
def __init__(self, *, base_url: str, api_key: str, chat_model: str, embed_model: str):
19+
self.base_url = base_url.rstrip("/")
20+
self.api_key = api_key or ""
21+
self.chat_model = chat_model
22+
self.embed_model = embed_model
23+
self.fake = bool(os.getenv("MEMUFLOW_FAKE_OPENAI")) or not bool(self.api_key)
24+
self.client: AsyncOpenAI | None = None
25+
if self.fake:
26+
self.client = None
27+
else:
28+
if openai is None:
29+
msg = "The 'openai' Python package is required for the SDK client. Install it via `pip install openai` or switch to the httpx backend."
30+
raise RuntimeError(msg)
31+
self.client = openai.AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
32+
33+
async def summarize(
34+
self,
35+
text: str,
36+
*,
37+
max_tokens: int = 160,
38+
system_prompt: str | None = None,
39+
) -> str:
40+
prompt = system_prompt or "Summarize the text in one short paragraph."
41+
if self.fake:
42+
s = " ".join(text.strip().split())
43+
return s[:200] + ("..." if len(s) > 200 else "")
44+
if self.client is None:
45+
msg = "The 'openai' Python package is required for the SDK client. Install it via `pip install openai` or switch to the httpx backend."
46+
raise RuntimeError(msg)
47+
response = await self.client.chat.completions.create(
48+
model=self.chat_model,
49+
messages=[
50+
{"role": "system", "content": prompt},
51+
{"role": "user", "content": text},
52+
],
53+
temperature=0.2,
54+
max_tokens=max_tokens,
55+
)
56+
content = response.choices[0].message.content
57+
return content or ""
58+
59+
async def embed(self, inputs: list[str]) -> list[list[float]]:
60+
if self.fake:
61+
return [self._fake_vec(x) for x in inputs]
62+
if self.client is None:
63+
msg = "The 'openai' Python package is required for the SDK client. Install it via `pip install openai` or switch to the httpx backend."
64+
raise RuntimeError(msg)
65+
response = await self.client.embeddings.create(model=self.embed_model, input=inputs)
66+
return [cast(list[float], d.embedding) for d in response.data]
67+
68+
def _fake_vec(self, s: str, dim: int = 256) -> list[float]:
69+
# deterministic pseudo-embedding for offline demo
70+
import hashlib
71+
72+
h = hashlib.sha256(s.encode("utf-8")).digest()
73+
b = (h * (dim // len(h) + 1))[:dim]
74+
arr = np.frombuffer(b, dtype=np.uint8).astype(np.float32)
75+
arr = (arr - arr.mean()) / (arr.std() + 1e-6)
76+
arr = arr / (np.linalg.norm(arr) + 1e-9)
77+
return arr.tolist()

src/memu/memory/repo.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from __future__ import annotations
2+
3+
import uuid
4+
5+
from memu.models import CategoryItem, MemoryCategory, MemoryItem, MemoryType, Resource
6+
7+
8+
class InMemoryStore:
    """Process-local, dict-backed store for resources, memory items and categories.

    Not persistent and not thread-safe; suitable for demos and tests.
    """

    def __init__(self) -> None:
        self.resources: dict[str, Resource] = {}
        self.items: dict[str, MemoryItem] = {}
        self.categories: dict[str, MemoryCategory] = {}
        self.relations: list[CategoryItem] = []

    @staticmethod
    def _new_id() -> str:
        """Generate a new record id.

        BUGFIX: the original called uuid.uuid7() unconditionally, which exists
        only on Python 3.14+ and raised AttributeError on earlier interpreters.
        Prefer time-ordered UUIDv7 when available, fall back to uuid4 otherwise.
        """
        factory = getattr(uuid, "uuid7", uuid.uuid4)
        return str(factory())

    def create_resource(self, *, url: str, modality: str, local_path: str) -> Resource:
        """Register a raw resource and return the stored record."""
        rid = self._new_id()
        res = Resource(id=rid, url=url, modality=modality, local_path=local_path)
        self.resources[rid] = res
        return res

    def get_or_create_category(self, *, name: str, description: str, embedding: list[float]) -> MemoryCategory:
        """Return the category named *name*, creating it if absent.

        An existing category keeps its stored description/embedding; the
        arguments only fill in fields that are currently empty.
        Lookup is a linear scan — fine for this in-memory demo store.
        """
        for c in self.categories.values():
            if c.name == name:
                if not c.embedding:
                    c.embedding = embedding
                if not c.description:
                    c.description = description
                return c
        cid = self._new_id()
        cat = MemoryCategory(id=cid, name=name, description=description, embedding=embedding)
        self.categories[cid] = cat
        return cat

    def create_item(
        self, *, resource_id: str, memory_type: MemoryType, summary: str, embedding: list[float]
    ) -> MemoryItem:
        """Create a memory item tied to a resource; starts with no category links."""
        mid = self._new_id()
        it = MemoryItem(
            id=mid,
            resource_id=resource_id,
            memory_type=memory_type,
            summary=summary,
            embedding=embedding,
            category_ids=[],
        )
        self.items[mid] = it
        return it

    def link_item_category(self, item_id: str, cat_id: str) -> CategoryItem:
        """Associate an item with a category and record the relation row.

        Raises KeyError for an unknown *item_id*. NOTE(review): repeated calls
        keep category_ids unique but still append duplicate relation rows —
        behavior preserved as-is; confirm whether relations should dedupe too.
        """
        it = self.items[item_id]
        if cat_id not in it.category_ids:
            it.category_ids.append(cat_id)
        rel = CategoryItem(item_id=item_id, category_id=cat_id)
        self.relations.append(rel)
        return rel

0 commit comments

Comments
 (0)