From e2f3e2d5bcf0ccccc8304a7686121817198df0dc Mon Sep 17 00:00:00 2001 From: "Michal.Kozal" Date: Sun, 3 May 2026 11:07:06 +0200 Subject: [PATCH] feat(custom-tier): centralize tier configs and add ability to have your own preset --- tests/test_custom_tier.py | 40 ++++++++++++++++ truememory/config.py | 92 +++++++++++++++++++++++++++++++++++++ truememory/mcp_server.py | 15 +++--- truememory/reranker.py | 24 ++++------ truememory/vector_search.py | 44 +++++++++--------- 5 files changed, 171 insertions(+), 44 deletions(-) create mode 100644 tests/test_custom_tier.py create mode 100644 truememory/config.py diff --git a/tests/test_custom_tier.py b/tests/test_custom_tier.py new file mode 100644 index 0000000..871febb --- /dev/null +++ b/tests/test_custom_tier.py @@ -0,0 +1,40 @@ +import os +import pytest +from truememory import config +from truememory.vector_search import _resolve_model_name +from truememory.reranker import get_reranker_name_for_tier + +def test_custom_tier_env_vars(monkeypatch): + monkeypatch.setenv("TRUEMEMORY_CUSTOM_EMBED_MODEL", "custom-embed") + monkeypatch.setenv("TRUEMEMORY_CUSTOM_RERANKER", "custom-reranker") + monkeypatch.setenv("TRUEMEMORY_CUSTOM_EMBED_DIM", "512") + + conf = config.get_tier_config("custom") + assert conf["embed_model"] == "custom-embed" + assert conf["reranker"] == "custom-reranker" + assert conf["embed_dim"] == 512 + +def test_custom_tier_resolution(monkeypatch): + monkeypatch.setenv("TRUEMEMORY_EMBED_MODEL", "custom") + monkeypatch.setenv("TRUEMEMORY_CUSTOM_EMBED_MODEL", "my-model") + monkeypatch.setenv("TRUEMEMORY_CUSTOM_RERANKER", "my-reranker") + + assert _resolve_model_name("custom") == "my-model" + assert get_reranker_name_for_tier("custom") == "my-reranker" + +def test_standard_tiers(): + assert config.get_tier_config("edge")["embed_model"] == "model2vec" + assert config.get_tier_config("base")["embed_model"] == "qwen3_256" + assert config.get_tier_config("pro")["embed_model"] == "qwen3_256" + +def test_fallback_logic(): + # Unknown tier falls back to edge + assert config.get_tier_config("unknown")["embed_model"] == "model2vec" + +def test_embedding_dim_resolution(): + assert config.get_embedding_dim("model2vec") == 256 + assert config.get_embedding_dim("minilm") == 384 + + # Tier name resolution + assert config.get_embedding_dim("edge") == 256 + assert config.get_embedding_dim("base") == 256 diff --git a/truememory/config.py b/truememory/config.py new file mode 100644 index 0000000..190bebb --- /dev/null +++ b/truememory/config.py @@ -0,0 +1,92 @@ +""" +TrueMemory Configuration +======================== + +Centralized configuration for memory tiers, embedding models, and rerankers. +Allows for easy customization of the memory pipeline via standard tiers or +a 'custom' tier that reads from environment variables. +""" + +import os +import logging + +log = logging.getLogger(__name__) + +# Default tier mappings (v0.4.0 paper-aligned) +DEFAULT_TIERS = { + "edge": { + "embed_model": "model2vec", + "reranker": "cross-encoder/ms-marco-MiniLM-L-6-v2", + "embed_dim": 256, + }, + "base": { + "embed_model": "qwen3_256", + "reranker": "Alibaba-NLP/gte-reranker-modernbert-base", + "embed_dim": 256, + }, + "pro": { + "embed_model": "qwen3_256", + "reranker": "Alibaba-NLP/gte-reranker-modernbert-base", + "embed_dim": 256, + }, +} + +# Legacy model dimensions (for backward compatibility or explicit model selection) +MODEL_DIMS = { + "model2vec": 256, + "minilm": 384, + "bge-small": 384, + "qwen3_256": 256, +} + + +def get_tier_config(tier_name: str | None = None) -> dict: + """ + Resolve the configuration for a given tier name. + + If tier_name is 'custom', reads from environment variables: + - TRUEMEMORY_CUSTOM_EMBED_MODEL + - TRUEMEMORY_CUSTOM_RERANKER + - TRUEMEMORY_CUSTOM_EMBED_DIM + + Otherwise, returns the configuration from DEFAULT_TIERS. + If tier_name is None, defaults to the TRUEMEMORY_EMBED_MODEL env var or 'edge'. + """ + if tier_name is None: + tier_name = os.environ.get("TRUEMEMORY_EMBED_MODEL", "edge") + + lowered = tier_name.lower().strip() + + if lowered == "custom": + custom_embed = os.environ.get("TRUEMEMORY_CUSTOM_EMBED_MODEL", "model2vec") + custom_reranker = os.environ.get( + "TRUEMEMORY_CUSTOM_RERANKER", "cross-encoder/ms-marco-MiniLM-L-6-v2" + ) + try: + custom_dim = int(os.environ.get("TRUEMEMORY_CUSTOM_EMBED_DIM", "256")) + except ValueError: + log.warning("Invalid TRUEMEMORY_CUSTOM_EMBED_DIM; falling back to 256") + custom_dim = 256 + + return { + "embed_model": custom_embed, + "reranker": custom_reranker, + "embed_dim": custom_dim, + } + + # Return default or fallback to edge + return DEFAULT_TIERS.get(lowered, DEFAULT_TIERS["edge"]) + + +def get_embedding_dim(model_name: str) -> int: + """Return the dimension for a specific internal model name.""" + # Check MODEL_DIMS first, then fallback to config resolution if it's a tier name + if model_name in MODEL_DIMS: + return MODEL_DIMS[model_name] + + # If it's a tier name, resolve its config + if model_name in DEFAULT_TIERS or model_name == "custom": + return get_tier_config(model_name)["embed_dim"] + + # Final fallback for arbitrary HF IDs + return 256 diff --git a/truememory/mcp_server.py b/truememory/mcp_server.py index 7091a9d..dff4087 100644 --- a/truememory/mcp_server.py +++ b/truememory/mcp_server.py @@ -659,7 +659,10 @@ def truememory_stats() -> str: " Same models as Base plus HyDE query expansion.\n" " Requires an API key (Anthropic / OpenRouter / OpenAI) for the HyDE LLM call.\n" "\n" - "Which would you like: Edge, Base, or Pro?" + " Custom — Inject your own models via environment variables.\n" + " See documentation for TRUEMEMORY_CUSTOM_* vars.\n" + "\n" + "Which would you like: Edge, Base, Pro, or Custom?" ) stats["has_api_key"] = bool( os.environ.get("ANTHROPIC_API_KEY") @@ -691,8 +694,8 @@ def truememory_configure( """ global _memory tier = tier.lower().strip() - if tier not in ("edge", "base", "pro"): - return json.dumps({"error": "tier must be 'edge', 'base', or 'pro'"}) + if tier not in ("edge", "base", "pro", "custom"): + return json.dumps({"error": "tier must be 'edge', 'base', 'pro', or 'custom'"}) # Validate API key + provider pairing if api_key and not api_provider: @@ -706,9 +709,9 @@ def truememory_configure( "error": "api_provider must be one of: anthropic, openrouter, openai", }) - # Check Base / Pro dependencies before committing (both need sentence-transformers - # for the Qwen3 embedder + gte-reranker). - if tier in ("base", "pro"): + # Check Base / Pro / Custom dependencies before committing (these typically need + # sentence-transformers for the Qwen3/custom embedder + gte-reranker). + if tier in ("base", "pro", "custom"): try: import sentence_transformers # noqa: F401 except ImportError: diff --git a/truememory/reranker.py b/truememory/reranker.py index c34075c..8d97604 100644 --- a/truememory/reranker.py +++ b/truememory/reranker.py @@ -29,6 +29,8 @@ import threading from typing import TYPE_CHECKING +from truememory import config + log = logging.getLogger(__name__) if TYPE_CHECKING: @@ -43,10 +45,7 @@ _lock = threading.Lock() _inference_lock = threading.Lock() # Protects concurrent model.predict() calls -# --------------------------------------------------------------------------- -# Tier-aware reranker resolution (v0.4.0 paper §2.0) -# --------------------------------------------------------------------------- -# +# v0.4.0 paper §2.0 models: # Edge uses the lightweight MiniLM cross-encoder (22M params, CPU-friendly). # Base and Pro use gte-reranker-modernbert-base (149M, GPU recommended) — # required to reach the 91.5% / 91.8% LoCoMo targets for those tiers. @@ -55,24 +54,17 @@ # get_current_reranker_name() call (from TRUEMEMORY_EMBED_MODEL env var or # ~/.truememory/config.json), and can be updated at runtime via # set_active_tier() — the MCP server calls this on truememory_configure. -_TIER_RERANKERS = { - "edge": "cross-encoder/ms-marco-MiniLM-L-6-v2", - "base": "Alibaba-NLP/gte-reranker-modernbert-base", - "pro": "Alibaba-NLP/gte-reranker-modernbert-base", -} _active_tier: str | None = None # None = not yet resolved; resolved lazily def get_reranker_name_for_tier(tier: str) -> str: - """Pure mapping from tier name ("edge" / "base" / "pro") to reranker HF ID. + """Pure mapping from tier name ("edge" / "base" / "pro" / "custom") to reranker HF ID. Case-insensitive. Unknown or empty tier names fall back to the Edge default (MiniLM). Does not load any model — use ``get_reranker`` for that. """ - if not tier: - return _TIER_RERANKERS["edge"] - return _TIER_RERANKERS.get(tier.lower(), _TIER_RERANKERS["edge"]) + return config.get_tier_config(tier)["reranker"] def _resolve_tier_from_env_and_config() -> str: @@ -85,7 +77,7 @@ def _resolve_tier_from_env_and_config() -> str: """ import os env = os.environ.get("TRUEMEMORY_EMBED_MODEL", "").strip().lower() - if env in ("edge", "base", "pro"): + if env in ("edge", "base", "pro", "custom"): return env try: from pathlib import Path @@ -94,7 +86,7 @@ def _resolve_tier_from_env_and_config() -> str: if cfg_path.exists(): data = json.loads(cfg_path.read_text()) tier = (data.get("tier") or "").strip().lower() - if tier in ("edge", "base", "pro"): + if tier in ("edge", "base", "pro", "custom"): return tier except (json.JSONDecodeError, OSError) as e: # Hunter F04 duplicate: previously a bare `except Exception: pass` @@ -124,7 +116,7 @@ def set_active_tier(tier: str) -> None: _active_tier = "edge" return t = tier.strip().lower() - _active_tier = t if t in ("edge", "base", "pro") else "edge" + _active_tier = t if t in ("edge", "base", "pro", "custom") else "edge" def get_current_reranker_name() -> str: diff --git a/truememory/vector_search.py b/truememory/vector_search.py index 1b2a02e..06a2db6 100644 --- a/truememory/vector_search.py +++ b/truememory/vector_search.py @@ -48,6 +48,8 @@ import numpy as np +from truememory import config + if TYPE_CHECKING: pass @@ -68,46 +70,38 @@ class TrueMemoryMigrationError(Exception): # Singleton model loader # --------------------------------------------------------------------------- -# Public tier names → internal model identifiers (v0.4.0 paper-aligned Edge/Base/Pro) -_TIER_ALIASES = { - "edge": "model2vec", - "base": "qwen3_256", - "pro": "qwen3_256", -} - -_MODEL_DIMS = { - "model2vec": 256, - "minilm": 384, - "bge-small": 384, - "qwen3_256": 256, -} - # v0.4.0 breaking change: the old internal name "qwen3" (1024d native) is gone. # Callers must migrate to "pro" (tier alias) or "qwen3_256" (internal name). _REMOVED_MODELS = {"qwen3"} def _resolve_model_name(name: str) -> str: - """Resolve a public tier name (edge/base/pro) or internal model name. + """Resolve a public tier name (edge/base/pro/custom) or internal model name. Raises: ValueError: if *name* refers to a model removed in v0.4.0. """ - lowered = name.lower() + lowered = name.lower().strip() if lowered in _REMOVED_MODELS: raise ValueError( f"Embedding model {name!r} was removed in TrueMemory v0.4.0. " f"Migrate to 'pro' (tier alias) or 'qwen3_256' (internal name) — " f"the paper-aligned Qwen3-Embedding-0.6B @ 256d Matryoshka config." ) - return _TIER_ALIASES.get(lowered, name) + + # Check if it's a known tier (including 'custom') + if lowered in config.DEFAULT_TIERS or lowered == "custom": + return config.get_tier_config(lowered)["embed_model"] + + # Otherwise assume it's already an internal name or HF ID + return name _raw_env = os.environ.get("TRUEMEMORY_EMBED_MODEL", "edge") EMBEDDING_MODEL = _resolve_model_name(_raw_env) _model = None -_embedding_dim: int = _MODEL_DIMS.get(EMBEDDING_MODEL, 256) +_embedding_dim: int = config.get_embedding_dim(EMBEDDING_MODEL) _lock = threading.Lock() @@ -126,7 +120,7 @@ def set_embedding_model(name: str) -> None: def get_embedding_dim(name: str | None = None) -> int: """Return the embedding dimension for a given model name.""" name = _resolve_model_name(name) if name else EMBEDDING_MODEL - return _MODEL_DIMS.get(name, 256) + return config.get_embedding_dim(name) def get_model(): @@ -158,9 +152,15 @@ def get_model(): ) _embedding_dim = 256 else: - from model2vec import StaticModel - _model = StaticModel.from_pretrained("minishlab/potion-base-8M", force_download=False) - _embedding_dim = 256 + # Assume it's a HuggingFace ID + try: + from sentence_transformers import SentenceTransformer + _model = SentenceTransformer(resolved) + except Exception as e: + logger.warning("Failed to load model %r as SentenceTransformer: %s. Falling back to Model2Vec.", resolved, e) + from model2vec import StaticModel + _model = StaticModel.from_pretrained("minishlab/potion-base-8M", force_download=False) + _embedding_dim = 256 return _model