Skip to content

Commit 87fdcd8

Browse files
committed
feat: add postgres backend + improve sqlite backend
1 parent b7d7340 commit 87fdcd8

File tree

9 files changed

+729
-43
lines changed

9 files changed

+729
-43
lines changed

examples/postgres_demo.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from typing import List
2+
3+
from pydantic import BaseModel
4+
from testcontainers.postgres import PostgresContainer
5+
6+
from memstate import Constraint, MemoryStore
7+
from memstate.backends.postgres import PostgresStorage
8+
9+
10+
# --- Data Model ---
11+
class UserProfile(BaseModel):
    """Validated user profile stored as the payload of a MemState fact.

    ``email`` is used as the singleton key when the schema is registered
    (see ``register_schema`` in ``main``): committing another profile with
    the same email updates the existing fact instead of creating a new one.
    """

    email: str  # unique identity; singleton key for upserts
    full_name: str
    role: str
    level: str = "Junior"  # seniority; the demo flips this Junior -> Senior -> (rollback) Junior
    skills: List[str] = []  # NOTE(review): pydantic copies field defaults per instance, so this is safe here
17+
18+
19+
def print_fact(title, fact):
    """Print a human-readable summary of *fact* under *title*.

    A truthy fact dict gets its 'id' and 'payload' echoed on separate
    lines; a falsy fact (None / empty) is rendered as a 'None' placeholder.
    A trailing blank line separates consecutive summaries.
    """
    body = (
        [f" ID: {fact['id']}", f" Payload: {fact['payload']}"]
        if fact
        else [" None"]
    )
    for line in [title, *body]:
        print(line)
    print()
27+
28+
29+
# --- MAIN DEMO ---
def main():
    """End-to-end demo: singleton upsert, JSONB querying, audit log, rollback.

    Spins up a disposable Postgres container via testcontainers (requires a
    running Docker daemon), wires a MemoryStore to the JSONB-backed
    PostgresStorage, and walks through create / singleton-update / query /
    history / rollback.
    """
    # F541 fix applied throughout: literals with no placeholders lose the f-prefix.
    print("🚀 MemState + PostgreSQL (JSONB) Demo\n")

    # 1. Start Postgres in Docker (Automatic)
    print("🐳 Starting Postgres container...")
    with PostgresContainer("postgres:18-alpine") as postgres:

        # Fix driver string for SQLAlchemy (testcontainers returns old format)
        raw_url = postgres.get_connection_url()
        connection_string = raw_url.replace("postgresql+psycopg2://", "postgresql+psycopg://")

        print(f"🔌 Connecting to: {connection_string}")

        # 2. Init Storage & Memory
        pg_storage = PostgresStorage(connection_string)
        memory = MemoryStore(pg_storage)

        # 3. Register Schema with SINGLETON Constraint
        # "email" is the unique key. If we commit a new model with the same email,
        # MemState will UPDATE the existing record instead of creating a duplicate.
        memory.register_schema(typename="user_profile", model=UserProfile, constraint=Constraint(singleton_key="email"))

        # --- SCENARIO START ---

        # Step 4: Create Initial Profile (Junior)
        print("\n1️⃣ Agent creates a Junior profile...")

        profile_v1 = UserProfile(
            email="[email protected]", full_name="Alex Dev", role="Backend", level="Junior", skills=["Python"]
        )

        # Using commit_model (High-Level API)
        # Note: We do NOT pass fact_id. MemState creates a new one.
        fact_id = memory.commit_model(profile_v1, actor="Agent_Smith", reason="Initial onboarding")

        current = pg_storage.load(fact_id)
        print_fact("Current State (Junior):", current)

        # Step 5: Update Profile (Singleton Logic)
        print("2️⃣ Agent finds LinkedIn info. Updating to Senior...")

        profile_v2 = UserProfile(
            email="[email protected]",  # SAME EMAIL triggers Singleton Update
            full_name="Alex Dev",
            role="Tech Lead",
            level="Senior",
            skills=["Python", "Architecture", "Postgres"],
        )

        # We perform a new commit. MemState detects email match and performs UPDATE.
        memory.commit_model(profile_v2, actor="Agent_Smith", reason="LinkedIn data enrichment")

        current = pg_storage.load(fact_id)
        print_fact("Current State (Senior):", current)

        # Step 6: JSONB Querying
        print("3️⃣ Testing Postgres JSONB Querying...")
        print(" Query: SELECT * WHERE payload->>'level' == 'Senior'")

        results = memory.query(
            typename="user_profile", filters={"payload.level": "Senior"}  # MemState converts this to JSONB path
        )

        if len(results) == 1:
            print(f"✅ Found correct user: {results[0]['payload']['full_name']}")
        else:
            print("❌ Query failed!")

        # Step 7: Audit Log (Compliance)
        print("\n4️⃣ Checking Transaction Log (History)...")
        # Assuming you implemented get_tx_log in PostgresStorage
        history = pg_storage.get_tx_log(limit=5)

        for tx in history:
            op = tx.get("op", "UNKNOWN")
            actor = tx.get("actor", "System")
            reason = tx.get("reason", "None")
            print(f" 📜 [{op}] by {actor}: {reason}")

        # Step 8: Rollback
        print("\n5️⃣ Oops! Update was a mistake. Rolling back...")
        memory.rollback(1)

        final = pg_storage.load(fact_id)
        print(f" Restored Level: {final['payload']['level']}")

        if final["payload"]["level"] == "Junior":
            print("\n✨ ACID Rollback successful! Data restored to Junior.")


if __name__ == "__main__":
    main()

memstate/backends/postgres.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from typing import Any

try:
    from sqlalchemy import (
        Column,
        ColumnElement,
        Integer,
        MetaData,
        String,
        Table,
        create_engine,
        delete,
        desc,
        func,
        select,
    )
    from sqlalchemy.dialects.postgresql import JSONB
    from sqlalchemy.dialects.postgresql import insert as pg_insert
    from sqlalchemy.engine import Engine
except ImportError as err:
    # Fix: the previous hint said `pip install postgres[binary]`, which names an
    # unrelated PyPI package. The optional extra is "postgres" (see pyproject.toml).
    raise ImportError(
        "SQLAlchemy and psycopg are required for the Postgres backend. "
        "Run `pip install memstate[postgres]` (or `pip install sqlalchemy psycopg[binary]`)."
    ) from err

from memstate.backends.base import StorageBackend
24+
25+
26+
class PostgresStorage(StorageBackend):
    """StorageBackend implementation backed by PostgreSQL with JSONB documents.

    Facts live in ``<prefix>_facts`` (text PK ``id`` + JSONB ``doc``); the
    transaction log lives in ``<prefix>_log`` (auto-increment ``seq`` + JSONB
    ``entry``). Tables are created on construction if missing.
    """

    def __init__(self, engine_or_url: str | Engine, table_prefix: str = "memstate") -> None:
        """Accept a ready SQLAlchemy Engine or a connection URL string.

        Args:
            engine_or_url: ``postgresql+psycopg://...`` URL or an existing Engine.
            table_prefix: prefix for the two backend tables.
        """
        if isinstance(engine_or_url, str):
            self._engine = create_engine(engine_or_url, future=True)
        else:
            self._engine = engine_or_url

        self._metadata = MetaData()
        self._table_prefix = table_prefix

        # --- Define Tables ---
        self._facts_table = Table(
            f"{table_prefix}_facts",
            self._metadata,
            Column("id", String, primary_key=True),
            Column("doc", JSONB, nullable=False),  # JSONB so document fields are indexable/queryable
        )

        self._log_table = Table(
            f"{table_prefix}_log",
            self._metadata,
            Column("seq", Integer, primary_key=True, autoincrement=True),
            Column("entry", JSONB, nullable=False),
        )

        with self._engine.begin() as conn:
            self._metadata.create_all(conn)

    def load(self, id: str) -> dict[str, Any] | None:
        """Return the fact document for *id*, or None if it does not exist."""
        with self._engine.connect() as conn:
            stmt = select(self._facts_table.c.doc).where(self._facts_table.c.id == id)
            row = conn.execute(stmt).first()
            if row:
                return row[0]  # SQLAlchemy deserializes JSONB automatically
            return None

    def save(self, fact_data: dict[str, Any]) -> None:
        """Insert or replace the fact keyed by ``fact_data["id"]``."""
        # Postgres native upsert (INSERT ... ON CONFLICT DO UPDATE), conflict over the PK.
        stmt = pg_insert(self._facts_table).values(id=fact_data["id"], doc=fact_data)
        upsert_stmt = stmt.on_conflict_do_update(index_elements=["id"], set_={"doc": stmt.excluded.doc})

        with self._engine.begin() as conn:
            conn.execute(upsert_stmt)

    def delete(self, id: str) -> None:
        """Remove the fact with the given *id* (no-op when absent)."""
        with self._engine.begin() as conn:
            conn.execute(delete(self._facts_table).where(self._facts_table.c.id == id))

    def query(self, type_filter: str | None = None, json_filters: dict[str, Any] | None = None) -> list[dict[str, Any]]:
        """Return fact docs matching *type_filter* and/or dotted *json_filters*.

        Filter keys are dot paths into the document (e.g. ``"payload.user.id"``);
        each path becomes a JSONB access chain whose leaf is compared against
        ``to_jsonb(value)`` so int/bool/str values all compare with correct types.
        """
        stmt = select(self._facts_table.c.doc)

        # 1. Filter by fact type: doc->>'type' = :type_filter
        if type_filter:
            stmt = stmt.where(self._facts_table.c.doc["type"].astext == type_filter)

        # 2. Dotted JSON path filters
        if json_filters:
            for key, value in json_filters.items():
                # Split the path: payload.role -> ['payload', 'role']
                path_parts = key.split(".")

                # Walk down to the parent of the leaf key, building doc['a']['b']...
                json_col: ColumnElement[Any] = self._facts_table.c.doc
                for part in path_parts[:-1]:
                    json_col = json_col[part]

                # Compare the leaf as JSONB; casting the value with to_jsonb
                # preserves type semantics across int/bool/str.
                stmt = stmt.where(json_col[path_parts[-1]] == func.to_jsonb(value))

        with self._engine.connect() as conn:
            rows = conn.execute(stmt).all()
            return [r[0] for r in rows]

    def append_tx(self, tx_data: dict[str, Any]) -> None:
        """Append one entry to the transaction log."""
        with self._engine.begin() as conn:
            conn.execute(self._log_table.insert().values(entry=tx_data))

    def get_tx_log(self, limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
        """Return log entries newest-first, paginated by *limit*/*offset*."""
        stmt = select(self._log_table.c.entry).order_by(desc(self._log_table.c.seq)).limit(limit).offset(offset)
        with self._engine.connect() as conn:
            rows = conn.execute(stmt).all()
            return [r[0] for r in rows]

    def delete_session(self, session_id: str) -> list[str]:
        """Delete all facts whose doc carries this *session_id*; return their ids.

        Fix: a single atomic ``DELETE ... RETURNING id`` replaces the previous
        SELECT-then-DELETE pair, which ran in two separate connections/
        transactions and could race with concurrent writers. This also mirrors
        the SQLite backend's RETURNING-based implementation.
        """
        stmt = (
            delete(self._facts_table)
            .where(self._facts_table.c.doc["session_id"].astext == session_id)
            .returning(self._facts_table.c.id)
        )
        with self._engine.begin() as conn:
            return [row[0] for row in conn.execute(stmt).all()]

memstate/backends/redis.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,11 @@
66
try:
77
import redis
88
except ImportError:
9-
redis = None # type: ignore[assignment]
9+
raise ImportError("redis package is required. pip install redis")
1010

1111

1212
class RedisStorage(StorageBackend):
1313
def __init__(self, client_or_url: Union[str, "redis.Redis"] = "redis://localhost:6379/0") -> None:
14-
if not redis:
15-
raise ImportError("redis package is required. pip install redis")
16-
1714
self.prefix = "mem:"
1815

1916
if isinstance(client_or_url, str):

memstate/backends/sqlite.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import re
23
import sqlite3
34
import threading
45
from typing import Any
@@ -28,6 +29,7 @@ def _init_db(self) -> None:
2829
with self._lock:
2930
c = self._conn.cursor()
3031
c.execute("PRAGMA journal_mode=WAL;")
32+
c.execute("PRAGMA synchronous=NORMAL;")
3133

3234
c.execute(
3335
"""
@@ -41,6 +43,7 @@ def _init_db(self) -> None:
4143
"""
4244
)
4345
c.execute("CREATE INDEX IF NOT EXISTS idx_facts_type ON facts(type)")
46+
c.execute("CREATE INDEX IF NOT EXISTS idx_facts_session ON facts(json_extract(data, '$.session_id'))")
4447
c.execute(
4548
"""
4649
CREATE TABLE IF NOT EXISTS tx_log
@@ -93,6 +96,8 @@ def query(self, type_filter: str | None = None, json_filters: dict[str, Any] | N
9396

9497
if json_filters:
9598
for key, value in json_filters.items():
99+
if not re.match(r"^[a-zA-Z0-9_.]+$", key):
100+
raise ValueError(f"Invalid characters in filter key: {key}")
96101
query += f" AND json_extract(data, '$.{key}') = ?"
97102
params.append(value)
98103

@@ -125,14 +130,10 @@ def delete_session(self, session_id: str) -> list[str]:
125130
with self._lock:
126131
c = self._conn.cursor()
127132

128-
c.execute("SELECT id FROM facts WHERE json_extract(data, '$.session_id') = ?", (session_id,))
133+
c.execute("DELETE FROM facts WHERE json_extract(data, '$.session_id') = ? RETURNING id", (session_id,))
129134
rows = c.fetchall()
130135
ids = [row["id"] for row in rows]
131-
132-
if ids:
133-
c.execute("DELETE FROM facts WHERE json_extract(data, '$.session_id') = ?", (session_id,))
134-
self._conn.commit()
135-
136+
self._conn.commit()
136137
return ids
137138

138139
def close(self):

memstate/storage.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def validate(self, typename: str, payload: dict[str, Any]) -> dict[str, Any]:
2727
return payload
2828
try:
2929
instance = model_cls.model_validate(payload)
30-
return instance.model_dump()
30+
return instance.model_dump(mode="json")
3131
except ValidationError as e:
3232
raise ValidationFailed(str(e))
3333

@@ -116,7 +116,7 @@ def commit(
116116
op = Operation.COMMIT_EPHEMERAL if ephemeral else Operation.COMMIT
117117

118118
try:
119-
new_state = fact.model_dump()
119+
new_state = fact.model_dump(mode="json")
120120
self.storage.save(new_state)
121121
self._log_tx(op, fact.id, previous_state, new_state, actor, reason)
122122
self._notify_hooks(op, fact.id, fact)
@@ -153,7 +153,9 @@ def commit_model(
153153
f"Please call memory.register_schema('your_type_name', {model.__class__.__name__}) first."
154154
)
155155

156-
fact = Fact(id=fact_id or str(uuid.uuid4()), type=schema_type, payload=model.model_dump(), source=source)
156+
fact = Fact(
157+
id=fact_id or str(uuid.uuid4()), type=schema_type, payload=model.model_dump(mode="json"), source=source
158+
)
157159

158160
return self.commit(fact, session_id=session_id, ephemeral=ephemeral, actor=actor, reason=reason)
159161

@@ -279,4 +281,4 @@ def _log_tx(
279281
actor=actor,
280282
reason=reason,
281283
)
282-
self.storage.append_tx(tx.model_dump())
284+
self.storage.append_tx(tx.model_dump(mode="json"))

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ dependencies = [
4848
redis = ["redis>=7.1.0"]
4949
langgraph = ["langgraph>=1.0.4"]
5050
chromadb = ["chromadb>=1.3.5"]
51+
postgres = ["sqlalchemy>=2.0.0", "psycopg[binary]>=3.3.2"]
5152

5253
[dependency-groups]
5354
dev = [
@@ -59,6 +60,7 @@ dev = [
5960
"pre-commit>=4.4.0",
6061
"pytest>=9.0.1",
6162
"ruff>=0.14.6",
63+
"testcontainers>=4.13.3",
6264
]
6365

6466
[project.urls]

0 commit comments

Comments
 (0)