Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions tests/test_custom_tier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
import pytest
from truememory import config
from truememory.vector_search import _resolve_model_name
from truememory.reranker import get_reranker_name_for_tier

def test_custom_tier_env_vars(monkeypatch):
monkeypatch.setenv("TRUEMEMORY_CUSTOM_EMBED_MODEL", "custom-embed")
monkeypatch.setenv("TRUEMEMORY_CUSTOM_RERANKER", "custom-reranker")
monkeypatch.setenv("TRUEMEMORY_CUSTOM_EMBED_DIM", "512")

conf = config.get_tier_config("custom")
assert conf["embed_model"] == "custom-embed"
assert conf["reranker"] == "custom-reranker"
assert conf["embed_dim"] == 512

def test_custom_tier_resolution(monkeypatch):
monkeypatch.setenv("TRUEMEMORY_EMBED_MODEL", "custom")
monkeypatch.setenv("TRUEMEMORY_CUSTOM_EMBED_MODEL", "my-model")
monkeypatch.setenv("TRUEMEMORY_CUSTOM_RERANKER", "my-reranker")

assert _resolve_model_name("custom") == "my-model"
assert get_reranker_name_for_tier("custom") == "my-reranker"

def test_standard_tiers():
assert config.get_tier_config("edge")["embed_model"] == "model2vec"
assert config.get_tier_config("base")["embed_model"] == "qwen3_256"
assert config.get_tier_config("pro")["embed_model"] == "qwen3_256"

def test_fallback_logic():
# Unknown tier falls back to edge
assert config.get_tier_config("unknown")["embed_model"] == "model2vec"

def test_embedding_dim_resolution():
assert config.get_embedding_dim("model2vec") == 256
assert config.get_embedding_dim("minilm") == 384

# Tier name resolution
assert config.get_embedding_dim("edge") == 256
assert config.get_embedding_dim("base") == 256
92 changes: 92 additions & 0 deletions truememory/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
TrueMemory Configuration
========================

Centralized configuration for memory tiers, embedding models, and rerankers.
Allows for easy customization of the memory pipeline via standard tiers or
a 'custom' tier that reads from environment variables.
"""

import os
import logging

log = logging.getLogger(__name__)

# Default tier mappings (v0.4.0 paper-aligned)
DEFAULT_TIERS = {
"edge": {
"embed_model": "model2vec",
"reranker": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"embed_dim": 256,
},
"base": {
"embed_model": "qwen3_256",
"reranker": "Alibaba-NLP/gte-reranker-modernbert-base",
"embed_dim": 256,
},
"pro": {
"embed_model": "qwen3_256",
"reranker": "Alibaba-NLP/gte-reranker-modernbert-base",
"embed_dim": 256,
},
}

# Legacy model dimensions (for backward compatibility or explicit model selection)
MODEL_DIMS = {
"model2vec": 256,
"minilm": 384,
"bge-small": 384,
"qwen3_256": 256,
}


def get_tier_config(tier_name: str | None = None) -> dict:
"""
Resolve the configuration for a given tier name.

If tier_name is 'custom', reads from environment variables:
- TRUEMEMORY_CUSTOM_EMBED_MODEL
- TRUEMEMORY_CUSTOM_RERANKER
- TRUEMEMORY_CUSTOM_EMBED_DIM

Otherwise, returns the configuration from DEFAULT_TIERS.
If tier_name is None, defaults to the TRUEMEMORY_EMBED_MODEL env var or 'edge'.
"""
if tier_name is None:
tier_name = os.environ.get("TRUEMEMORY_EMBED_MODEL", "edge")

lowered = tier_name.lower().strip()

if lowered == "custom":
custom_embed = os.environ.get("TRUEMEMORY_CUSTOM_EMBED_MODEL", "model2vec")
custom_reranker = os.environ.get(
"TRUEMEMORY_CUSTOM_RERANKER", "cross-encoder/ms-marco-MiniLM-L-6-v2"
)
try:
custom_dim = int(os.environ.get("TRUEMEMORY_CUSTOM_EMBED_DIM", "256"))
except ValueError:
log.warning("Invalid TRUEMEMORY_CUSTOM_EMBED_DIM; falling back to 256")
custom_dim = 256

return {
"embed_model": custom_embed,
"reranker": custom_reranker,
"embed_dim": custom_dim,
}

# Return default or fallback to edge
return DEFAULT_TIERS.get(lowered, DEFAULT_TIERS["edge"])


def get_embedding_dim(model_name: str) -> int:
"""Return the dimension for a specific internal model name."""
# Check MODEL_DIMS first, then fallback to config resolution if it's a tier name
if model_name in MODEL_DIMS:
return MODEL_DIMS[model_name]

# If it's a tier name, resolve its config
if model_name in DEFAULT_TIERS or model_name == "custom":
return get_tier_config(model_name)["embed_dim"]

# Final fallback for arbitrary HF IDs
return 256
15 changes: 9 additions & 6 deletions truememory/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,10 @@ def truememory_stats() -> str:
" Same models as Base plus HyDE query expansion.\n"
" Requires an API key (Anthropic / OpenRouter / OpenAI) for the HyDE LLM call.\n"
"\n"
"Which would you like: Edge, Base, or Pro?"
" Custom — Inject your own models via environment variables.\n"
" See documentation for TRUEMEMORY_CUSTOM_* vars.\n"
"\n"
"Which would you like: Edge, Base, Pro, or Custom?"
)
stats["has_api_key"] = bool(
os.environ.get("ANTHROPIC_API_KEY")
Expand Down Expand Up @@ -691,8 +694,8 @@ def truememory_configure(
"""
global _memory
tier = tier.lower().strip()
if tier not in ("edge", "base", "pro"):
return json.dumps({"error": "tier must be 'edge', 'base', or 'pro'"})
if tier not in ("edge", "base", "pro", "custom"):
return json.dumps({"error": "tier must be 'edge', 'base', 'pro', or 'custom'"})

# Validate API key + provider pairing
if api_key and not api_provider:
Expand All @@ -706,9 +709,9 @@ def truememory_configure(
"error": "api_provider must be one of: anthropic, openrouter, openai",
})

# Check Base / Pro dependencies before committing (both need sentence-transformers
# for the Qwen3 embedder + gte-reranker).
if tier in ("base", "pro"):
# Check Base / Pro / Custom dependencies before committing (these typically need
# sentence-transformers for the Qwen3/custom embedder + gte-reranker).
if tier in ("base", "pro", "custom"):
try:
import sentence_transformers # noqa: F401
except ImportError:
Expand Down
24 changes: 8 additions & 16 deletions truememory/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import threading
from typing import TYPE_CHECKING

from truememory import config

log = logging.getLogger(__name__)

if TYPE_CHECKING:
Expand All @@ -43,10 +45,7 @@
_lock = threading.Lock()
_inference_lock = threading.Lock() # Protects concurrent model.predict() calls

# ---------------------------------------------------------------------------
# Tier-aware reranker resolution (v0.4.0 paper §2.0)
# ---------------------------------------------------------------------------
#
# v0.4.0 paper §2.0 models:
# Edge uses the lightweight MiniLM cross-encoder (22M params, CPU-friendly).
# Base and Pro use gte-reranker-modernbert-base (149M, GPU recommended) —
# required to reach the 91.5% / 91.8% LoCoMo targets for those tiers.
Expand All @@ -55,24 +54,17 @@
# get_current_reranker_name() call (from TRUEMEMORY_EMBED_MODEL env var or
# ~/.truememory/config.json), and can be updated at runtime via
# set_active_tier() — the MCP server calls this on truememory_configure.
_TIER_RERANKERS = {
"edge": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"base": "Alibaba-NLP/gte-reranker-modernbert-base",
"pro": "Alibaba-NLP/gte-reranker-modernbert-base",
}

_active_tier: str | None = None # None = not yet resolved; resolved lazily


def get_reranker_name_for_tier(tier: str) -> str:
"""Pure mapping from tier name ("edge" / "base" / "pro") to reranker HF ID.
"""Pure mapping from tier name ("edge" / "base" / "pro" / "custom") to reranker HF ID.

Case-insensitive. Unknown or empty tier names fall back to the Edge
default (MiniLM). Does not load any model — use ``get_reranker`` for that.
"""
if not tier:
return _TIER_RERANKERS["edge"]
return _TIER_RERANKERS.get(tier.lower(), _TIER_RERANKERS["edge"])
return config.get_tier_config(tier)["reranker"]


def _resolve_tier_from_env_and_config() -> str:
Expand All @@ -85,7 +77,7 @@ def _resolve_tier_from_env_and_config() -> str:
"""
import os
env = os.environ.get("TRUEMEMORY_EMBED_MODEL", "").strip().lower()
if env in ("edge", "base", "pro"):
if env in ("edge", "base", "pro", "custom"):
return env
try:
from pathlib import Path
Expand All @@ -94,7 +86,7 @@ def _resolve_tier_from_env_and_config() -> str:
if cfg_path.exists():
data = json.loads(cfg_path.read_text())
tier = (data.get("tier") or "").strip().lower()
if tier in ("edge", "base", "pro"):
if tier in ("edge", "base", "pro", "custom"):
return tier
except (json.JSONDecodeError, OSError) as e:
# Hunter F04 duplicate: previously a bare `except Exception: pass`
Expand Down Expand Up @@ -124,7 +116,7 @@ def set_active_tier(tier: str) -> None:
_active_tier = "edge"
return
t = tier.strip().lower()
_active_tier = t if t in ("edge", "base", "pro") else "edge"
_active_tier = t if t in ("edge", "base", "pro", "custom") else "edge"


def get_current_reranker_name() -> str:
Expand Down
44 changes: 22 additions & 22 deletions truememory/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@

import numpy as np

from truememory import config

if TYPE_CHECKING:
pass

Expand All @@ -68,46 +70,38 @@ class TrueMemoryMigrationError(Exception):
# Singleton model loader
# ---------------------------------------------------------------------------

# Public tier names → internal model identifiers (v0.4.0 paper-aligned Edge/Base/Pro)
_TIER_ALIASES = {
"edge": "model2vec",
"base": "qwen3_256",
"pro": "qwen3_256",
}

_MODEL_DIMS = {
"model2vec": 256,
"minilm": 384,
"bge-small": 384,
"qwen3_256": 256,
}

# v0.4.0 breaking change: the old internal name "qwen3" (1024d native) is gone.
# Callers must migrate to "pro" (tier alias) or "qwen3_256" (internal name).
_REMOVED_MODELS = {"qwen3"}


def _resolve_model_name(name: str) -> str:
"""Resolve a public tier name (edge/base/pro) or internal model name.
"""Resolve a public tier name (edge/base/pro/custom) or internal model name.

Raises:
ValueError: if *name* refers to a model removed in v0.4.0.
"""
lowered = name.lower()
lowered = name.lower().strip()
if lowered in _REMOVED_MODELS:
raise ValueError(
f"Embedding model {name!r} was removed in TrueMemory v0.4.0. "
f"Migrate to 'pro' (tier alias) or 'qwen3_256' (internal name) — "
f"the paper-aligned Qwen3-Embedding-0.6B @ 256d Matryoshka config."
)
return _TIER_ALIASES.get(lowered, name)

# Check if it's a known tier (including 'custom')
if lowered in config.DEFAULT_TIERS or lowered == "custom":
return config.get_tier_config(lowered)["embed_model"]

# Otherwise assume it's already an internal name or HF ID
return name


_raw_env = os.environ.get("TRUEMEMORY_EMBED_MODEL", "edge")
EMBEDDING_MODEL = _resolve_model_name(_raw_env)

_model = None
_embedding_dim: int = _MODEL_DIMS.get(EMBEDDING_MODEL, 256)
_embedding_dim: int = config.get_embedding_dim(EMBEDDING_MODEL)
_lock = threading.Lock()


Expand All @@ -126,7 +120,7 @@ def set_embedding_model(name: str) -> None:
def get_embedding_dim(name: str | None = None) -> int:
"""Return the embedding dimension for a given model name."""
name = _resolve_model_name(name) if name else EMBEDDING_MODEL
return _MODEL_DIMS.get(name, 256)
return config.get_embedding_dim(name)


def get_model():
Expand Down Expand Up @@ -158,9 +152,15 @@ def get_model():
)
_embedding_dim = 256
else:
from model2vec import StaticModel
_model = StaticModel.from_pretrained("minishlab/potion-base-8M", force_download=False)
_embedding_dim = 256
# Assume it's a HuggingFace ID
try:
from sentence_transformers import SentenceTransformer
_model = SentenceTransformer(resolved)
except Exception as e:
logger.warning("Failed to load model %r as SentenceTransformer: %s. Falling back to Model2Vec.", resolved, e)
from model2vec import StaticModel
_model = StaticModel.from_pretrained("minishlab/potion-base-8M", force_download=False)
_embedding_dim = 256
return _model


Expand Down