4 changes: 3 additions & 1 deletion src/memu/app/memorize.py
@@ -779,7 +779,9 @@ async def _preprocess_conversation(
         preprocessed_text = format_conversation_for_preprocess(text)
         prompt = template.format(conversation=self._escape_prompt_value(preprocessed_text))
         client = llm_client or self._get_llm_client()
-        processed = await client.summarize(prompt, system_prompt=None)
+        processed = await client.summarize(
+            prompt, system_prompt=None, response_format="json_object"
+        )
         _conv, segments = self._parse_conversation_preprocess_with_segments(processed, preprocessed_text)
 
         # Important: always use the original JSON-derived, indexed conversation text for downstream
6 changes: 3 additions & 3 deletions src/memu/app/retrieve.py
@@ -1195,7 +1195,7 @@ async def _llm_rank_categories(
         )
 
         client = llm_client or self._get_llm_client()
-        llm_response = await client.summarize(prompt, system_prompt=None)
+        llm_response = await client.summarize(prompt, system_prompt=None, response_format="json_object")
         return self._parse_llm_category_response(llm_response, store, categories=category_pool)
 
     async def _llm_rank_items(
@@ -1234,7 +1234,7 @@ async def _llm_rank_items(
         )
 
         client = llm_client or self._get_llm_client()
-        llm_response = await client.summarize(prompt, system_prompt=None)
+        llm_response = await client.summarize(prompt, system_prompt=None, response_format="json_object")
         return self._parse_llm_item_response(llm_response, store, items=item_pool)
 
     async def _llm_rank_resources(
@@ -1279,7 +1279,7 @@ async def _llm_rank_resources(
         )
 
         client = llm_client or self._get_llm_client()
-        llm_response = await client.summarize(prompt, system_prompt=None)
+        llm_response = await client.summarize(prompt, system_prompt=None, response_format="json_object")
         return self._parse_llm_resource_response(llm_response, store, resources=resource_pool)
 
     def _parse_llm_category_response(
1 change: 1 addition & 0 deletions src/memu/app/service.py
@@ -105,6 +105,7 @@ def _init_llm_client(self, config: LLMConfig | None = None) -> Any:
                 chat_model=cfg.chat_model,
                 embed_model=cfg.embed_model,
                 embed_batch_size=cfg.embed_batch_size,
+                reasoning_effort=cfg.reasoning_effort,
             )
         elif backend == "httpx":
             return HTTPLLMClient(
4 changes: 4 additions & 0 deletions src/memu/app/settings.py
@@ -113,6 +113,10 @@ class LLMConfig(BaseModel):
         default=1,
         description="Maximum batch size for embedding API calls (used by SDK client backends).",
     )
+    reasoning_effort: Literal["low", "medium", "high"] | None = Field(
+        default="high",
+        description="Reasoning effort level for chat completions (low/medium/high). Default is 'high'.",
+    )
 
 
 class BlobConfig(BaseModel):
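For reference, a minimal standalone sketch of how the new field validates and defaults; the class name `LLMConfigSketch` is illustrative, and the real `LLMConfig` carries many additional fields not shown here (assuming pydantic, which the `BaseModel`/`Field` usage above implies):

```python
# Illustrative sketch only: mirrors the reasoning_effort field added above.
from typing import Literal

from pydantic import BaseModel, Field


class LLMConfigSketch(BaseModel):
    reasoning_effort: Literal["low", "medium", "high"] | None = Field(
        default="high",
        description="Reasoning effort level for chat completions (low/medium/high).",
    )


print(LLMConfigSketch().reasoning_effort)                        # "high" (the default)
print(LLMConfigSketch(reasoning_effort="low").reasoning_effort)  # "low"
print(LLMConfigSketch(reasoning_effort=None).reasoning_effort)   # None: the SDK client then omits the parameter
```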
9 changes: 8 additions & 1 deletion src/memu/llm/http_client.py
@@ -87,11 +87,18 @@ def __init__(
         self.embed_model = embed_model or chat_model
 
     async def summarize(
-        self, text: str, max_tokens: int | None = None, system_prompt: str | None = None
+        self,
+        text: str,
+        max_tokens: int | None = None,
+        system_prompt: str | None = None,
+        response_format: str | None = None,
     ) -> tuple[str, dict[str, Any]]:
         payload = self.backend.build_summary_payload(
             text=text, system_prompt=system_prompt, chat_model=self.chat_model, max_tokens=max_tokens
         )
+        # Add response_format if specified
+        if response_format == "json_object":
+            payload["response_format"] = {"type": "json_object"}
         async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout) as client:
             resp = await client.post(self.summary_endpoint, json=payload, headers=self._headers())
             resp.raise_for_status()
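As a rough illustration of what the HTTP backend ends up sending when JSON mode is requested; the model name and message contents are placeholders, and `build_summary_payload` may include additional keys not shown:

```python
# Approximate request body for an OpenAI-compatible chat completions endpoint
# once payload["response_format"] has been set by the branch above.
payload = {
    "model": "gpt-4o-mini",  # placeholder model name
    "messages": [
        {"role": "system", "content": "Summarize the text in one short paragraph."},
        {"role": "user", "content": "<conversation text>"},
    ],
    "response_format": {"type": "json_object"},
}
```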
41 changes: 29 additions & 12 deletions src/memu/llm/openai_sdk.py
@@ -28,12 +28,14 @@ def __init__(
         chat_model: str,
         embed_model: str,
         embed_batch_size: int = 1,
+        reasoning_effort: Literal["low", "medium", "high"] | None = "high",
     ):
         self.base_url = base_url.rstrip("/")
         self.api_key = api_key or ""
         self.chat_model = chat_model
         self.embed_model = embed_model
         self.embed_batch_size = embed_batch_size
+        self.reasoning_effort = reasoning_effort
         self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
 
     async def summarize(
@@ -42,19 +44,28 @@ async def summarize(
         *,
         max_tokens: int | None = None,
         system_prompt: str | None = None,
+        response_format: Literal["text", "json_object"] | None = None,
     ) -> tuple[str, ChatCompletion]:
         prompt = system_prompt or "Summarize the text in one short paragraph."
 
         system_message: ChatCompletionSystemMessageParam = {"role": "system", "content": prompt}
         user_message: ChatCompletionUserMessageParam = {"role": "user", "content": text}
         messages: list[ChatCompletionMessageParam] = [system_message, user_message]
 
-        response = await self.client.chat.completions.create(
-            model=self.chat_model,
-            messages=messages,
-            temperature=1,
-            max_tokens=max_tokens,
-        )
+        # Build request kwargs with optional reasoning_effort and response_format
+        request_kwargs: dict[str, Any] = {
+            "model": self.chat_model,
+            "messages": messages,
+            "temperature": 1,
+        }
+        if max_tokens is not None:
+            request_kwargs["max_tokens"] = max_tokens
+        if self.reasoning_effort is not None:
+            request_kwargs["reasoning_effort"] = self.reasoning_effort
+        if response_format == "json_object":
+            request_kwargs["response_format"] = {"type": "json_object"}
+
+        response = await self.client.chat.completions.create(**request_kwargs)
         content = response.choices[0].message.content
         logger.debug("OpenAI summarize response: %s", response)
         return content or "", response
@@ -115,12 +126,18 @@ async def vision(
         }
         messages.append(user_message)
 
-        response = await self.client.chat.completions.create(
-            model=self.chat_model,
-            messages=messages,
-            temperature=1,
-            max_tokens=max_tokens,
-        )
+        # Build request kwargs with optional reasoning_effort
+        request_kwargs: dict[str, Any] = {
+            "model": self.chat_model,
+            "messages": messages,
+            "temperature": 1,
+        }
+        if max_tokens is not None:
+            request_kwargs["max_tokens"] = max_tokens
+        if self.reasoning_effort is not None:
+            request_kwargs["reasoning_effort"] = self.reasoning_effort
+
+        response = await self.client.chat.completions.create(**request_kwargs)
         content = response.choices[0].message.content
         logger.debug("OpenAI vision response: %s", response)
         return content or "", response
9 changes: 8 additions & 1 deletion src/memu/llm/wrapper.py
@@ -250,18 +250,25 @@ async def summarize(
         *,
         max_tokens: int | None = None,
         system_prompt: str | None = None,
+        response_format: str | None = None,
     ) -> Any:
         request_view = _build_text_request_view(
             "summarize",
             text,
             metadata={
                 "system_prompt_chars": len(system_prompt or ""),
                 "max_tokens": max_tokens,
+                "response_format": response_format,
             },
         )
 
         async def _call() -> Any:
-            return await self._client.summarize(text, max_tokens=max_tokens, system_prompt=system_prompt)
+            return await self._client.summarize(
+                text,
+                max_tokens=max_tokens,
+                system_prompt=system_prompt,
+                response_format=response_format,
+            )
 
         return await self._invoke(
             kind="summarize",
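Putting the pieces together, a hedged end-to-end sketch of how a caller might exercise the new parameters; the class name `OpenAISDKClient` and its import path are assumptions (the diff shows the module but not the class declaration), and the endpoint, credential, and model names are placeholders:

```python
import asyncio
import json

from memu.llm.openai_sdk import OpenAISDKClient  # assumed class name and import path


async def main() -> None:
    client = OpenAISDKClient(
        base_url="https://api.openai.com/v1",  # placeholder endpoint
        api_key="YOUR_API_KEY",                # placeholder credential
        chat_model="o4-mini",                  # placeholder reasoning-capable model
        embed_model="text-embedding-3-small",  # placeholder embedding model
        reasoning_effort="medium",             # overrides the "high" default from LLMConfig
    )
    content, raw_response = await client.summarize(
        "alice: meeting moved to 9am\nbob: ok, I'll bring the slides",
        system_prompt="Return a JSON object with a 'speakers' array.",
        response_format="json_object",
    )
    # json_object mode asks the model to emit valid JSON, so parsing the content directly is reasonable.
    data = json.loads(content)
    print(data, raw_response.usage)


asyncio.run(main())
```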