Skip to content

Commit d5d578d

Browse files
committed
feat: add method commit_model
1 parent 4336a1e commit d5d578d

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

memstate/storage.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ def validate(self, typename: str, payload: dict[str, Any]) -> dict[str, Any]:
3030
except ValidationError as e:
3131
raise ValidationFailed(str(e))
3232

33+
def get_type_by_model(self, model_class: type[BaseModel]) -> str | None:
34+
"""
35+
Reverse lookup: finds the registered type name for a given Pydantic class.
36+
"""
37+
for type_name, cls in self._schemas.items():
38+
if cls == model_class:
39+
return type_name
40+
return None
41+
3342

3443
class Constraint:
3544
def __init__(self, singleton_key: str | None = None, immutable: bool = False) -> None:
@@ -121,6 +130,30 @@ def commit(
121130

122131
raise e
123132

133+
def commit_model(
134+
self,
135+
model: BaseModel,
136+
session_id: str | None = None,
137+
ephemeral: bool = False,
138+
actor: str | None = None,
139+
reason: str | None = None,
140+
) -> str:
141+
"""
142+
Commit a Pydantic model instance directly.
143+
Auto-detects the schema type from the registry.
144+
"""
145+
schema_type = self._schema_registry.get_type_by_model(model.__class__)
146+
147+
if not schema_type:
148+
raise MemoryStoreError(
149+
f"Model class '{model.__class__.__name__}' is not registered. "
150+
f"Please call memory.register_schema('your_type_name', {model.__class__.__name__}) first."
151+
)
152+
153+
fact = Fact(type=schema_type, payload=model.model_dump())
154+
155+
return self.commit(fact, session_id=session_id, ephemeral=ephemeral, actor=actor, reason=reason)
156+
124157
def update(self, fact_id: str, patch: dict[str, Any], actor: str | None = None, reason: str | None = None) -> str:
125158
with self._lock:
126159
existing = self.storage.load(fact_id)

tests/test_storage.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,27 @@ def test_ephemeral_session_discard(memory):
122122

123123
remaining = memory.query(filters={"session_id": session_id})
124124
assert len(remaining) == 0
125+
126+
127+
def test_commit_model_success(memory):
128+
schema_name = "user_v1"
129+
memory.register_schema(schema_name, User)
130+
131+
user = User(name="Survivor", age=50)
132+
133+
fact_id = memory.commit_model(user, actor="system", session_id="session_1")
134+
135+
saved_fact = memory.storage.load(fact_id)
136+
137+
assert saved_fact is not None
138+
assert saved_fact["type"] == schema_name
139+
assert saved_fact["payload"] == {"name": "Survivor", "age": 50}
140+
141+
assert saved_fact["session_id"] == "session_1"
142+
143+
144+
def test_commit_model_raises_on_unregistered(memory):
145+
unknown = User(name="Survivor", age=50)
146+
147+
with pytest.raises(MemoryStoreError, match="is not registered"):
148+
memory.commit_model(unknown)

0 commit comments

Comments
 (0)