22
33import asyncio
44import json
5- import os
65import re
76from collections .abc import Sequence
87from typing import Any , cast
98
9+ from pydantic import BaseModel
10+
1011from memu .app .settings import AppSettings
1112from memu .llm .http_client import HTTPLLMClient
1213from memu .memory .repo import InMemoryStore
@@ -27,9 +28,9 @@ def __init__(self, settings: AppSettings):
2728 self .store = InMemoryStore ()
2829 backend = (settings .llm_client_backend or "httpx" ).lower ()
2930 self .openai : Any
30- client_kwargs = {
31+ client_kwargs : dict [ str , Any ] = {
3132 "base_url" : settings .openai_base ,
32- "api_key" : os . getenv ( settings .openai_api_key_env , "" ) ,
33+ "api_key" : settings .openai_api_key ,
3334 "chat_model" : settings .chat_model ,
3435 "embed_model" : settings .embed_model ,
3536 }
@@ -39,9 +40,9 @@ def __init__(self, settings: AppSettings):
3940 self .openai = OpenAISDKClient (** client_kwargs )
4041 elif backend == "httpx" :
4142 self .openai = HTTPLLMClient (
42- ** client_kwargs ,
4343 provider = self .settings .llm_http_provider ,
4444 endpoint_overrides = self .settings .llm_http_endpoints ,
45+ ** client_kwargs ,
4546 )
4647 else :
4748 msg = f"Unknown llm_client_backend '{ settings .llm_client_backend } '"
@@ -89,9 +90,9 @@ async def memorize(self, *, resource_url: str, modality: str, summary_prompt: st
8990 await self ._update_category_summaries (category_memory_updates )
9091
9192 return {
92- "resource" : res . model_dump ( ),
93- "items" : [item . model_dump ( ) for item in items ],
94- "categories" : [self .store .categories [c ]. model_dump ( ) for c in cat_ids ],
93+ "resource" : self . _model_dump_without_embeddings ( res ),
94+ "items" : [self . _model_dump_without_embeddings ( item ) for item in items ],
95+ "categories" : [self ._model_dump_without_embeddings ( self . store .categories [c ]) for c in cat_ids ],
9596 "relations" : [r .model_dump () for r in rels ],
9697 }
9798
@@ -110,7 +111,7 @@ async def _create_resource_with_caption(
110111 caption_text = caption .strip ()
111112 if caption_text :
112113 res .caption = caption_text
113- res .caption_embedding = (await self .openai .embed ([caption_text ]))[0 ]
114+ res .embedding = (await self .openai .embed ([caption_text ]))[0 ]
114115 return res
115116
116117 def _resolve_memory_types (self ) -> list [MemoryType ]:
@@ -365,6 +366,11 @@ def _extract_json_blob(raw: str) -> str:
365366 def _escape_prompt_value (value : str ) -> str :
366367 return value .replace ("{" , "{{" ).replace ("}" , "}}" )
367368
369+ def _model_dump_without_embeddings (self , obj : BaseModel ) -> dict [str , Any ]:
370+ data = obj .model_dump ()
371+ data .pop ("embedding" , None )
372+ return data
373+
368374 async def retrieve (self , query : str , * , top_k : int = 5 ) -> dict [str , Any ]:
369375 qvec = (await self .openai .embed ([query ]))[0 ]
370376 response : dict [str , list [dict [str , Any ]]] = {"resources" : [], "items" : [], "categories" : []}
@@ -413,7 +419,7 @@ def _materialize_hits(self, hits: Sequence[tuple[str, float]], pool: dict[str, A
413419 obj = pool .get (_id )
414420 if not obj :
415421 continue
416- data = obj . model_dump ( )
422+ data = self . _model_dump_without_embeddings ( obj )
417423 data ["score" ] = float (score )
418424 out .append (data )
419425 return out
@@ -450,8 +456,8 @@ def _format_resource_content(self, hits: list[tuple[str, float]]) -> str:
450456 def _resource_caption_corpus (self ) -> list [tuple [str , list [float ]]]:
451457 corpus : list [tuple [str , list [float ]]] = []
452458 for rid , res in self .store .resources .items ():
453- if res .caption_embedding :
454- corpus .append ((rid , res .caption_embedding ))
459+ if res .embedding :
460+ corpus .append ((rid , res .embedding ))
455461 return corpus
456462
457463 async def _judge_retrieval_sufficient (self , query : str , content : str ) -> bool :
0 commit comments