Skip to content

Commit c85530e

Browse files
authored
Standardise conversation input with OpenAI format (#9)
* Standardise conversation input with OpenAI format * align with backend changes
1 parent 526de43 commit c85530e

File tree

8 files changed

+155
-31
lines changed

8 files changed

+155
-31
lines changed

src/engram/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
RunStatus,
1111
SearchResults,
1212
StringContent,
13-
ToolCallMetadata,
13+
ToolCallCustomInput,
14+
ToolCallFuncInput,
15+
ToolCallInput,
1416
)
1517
from .async_client import AsyncEngramClient
1618
from .client import EngramClient
@@ -43,7 +45,9 @@
4345
"RunStatus",
4446
"SearchResults",
4547
"StringContent",
46-
"ToolCallMetadata",
48+
"ToolCallCustomInput",
49+
"ToolCallFuncInput",
50+
"ToolCallInput",
4751
"ValidationError",
4852
"__version__",
4953
]

src/engram/_models/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
RetrievalConfig,
88
SearchResults,
99
StringContent,
10-
ToolCallMetadata,
10+
ToolCallCustomInput,
11+
ToolCallFuncInput,
12+
ToolCallInput,
1113
)
1214
from .run import CommittedOperation, CommittedOperations, Run, RunStatus
1315

@@ -24,5 +26,7 @@
2426
"RunStatus",
2527
"SearchResults",
2628
"StringContent",
27-
"ToolCallMetadata",
29+
"ToolCallCustomInput",
30+
"ToolCallFuncInput",
31+
"ToolCallInput",
2832
]

src/engram/_models/memory.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,48 @@ class StringContent:
2121

2222

2323
@dataclass(slots=True)
24-
class ToolCallMetadata:
25-
"""Tool call metadata."""
24+
class ToolCallFuncInput:
25+
"""The function details of an OpenAI-format function tool call."""
2626

2727
name: str
28+
arguments: str
29+
30+
31+
@dataclass(slots=True)
32+
class ToolCallCustomInput:
33+
"""The details of an OpenAI-format custom tool call."""
34+
35+
name: str
36+
input: str
37+
38+
39+
@dataclass(slots=True)
40+
class ToolCallInput:
41+
"""A single tool call in OpenAI Chat Completions format.
42+
43+
Set either `function` or `custom` depending on the tool type.
44+
"""
45+
2846
id: str
47+
type: str = "function"
48+
function: ToolCallFuncInput | None = None
49+
custom: ToolCallCustomInput | None = None
2950

3051

3152
@dataclass(slots=True)
3253
class MessageContent:
33-
"""A message in a conversation."""
54+
"""A message in a conversation using the OpenAI Chat Completions format.
3455
35-
role: Literal["user", "assistant", "system"]
36-
content: str
56+
- 'tool' role (tool results) is mapped to 'user' by the server.
57+
- 'developer' role is mapped to 'system' by the server.
58+
"""
59+
60+
role: Literal["user", "assistant", "system", "tool", "developer"]
61+
content: str = ""
3762
created_at: str | None = None
38-
tool_call_metadata: ToolCallMetadata | None = None
63+
tool_call_id: str | None = None
64+
name: str | None = None
65+
tool_calls: list[ToolCallInput] | None = None
3966

4067

4168
@dataclass(slots=True)

src/engram/_serialization/_builders.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,19 @@
88
PreExtractedContent,
99
RetrievalConfig,
1010
StringContent,
11+
ToolCallInput,
1112
)
1213

1314

15+
def _serialize_tool_call(tc: ToolCallInput) -> dict[str, Any]:
16+
out: dict[str, Any] = {"id": tc.id, "type": tc.type}
17+
if tc.function is not None:
18+
out["function"] = {"name": tc.function.name, "arguments": tc.function.arguments}
19+
if tc.custom is not None:
20+
out["custom"] = {"name": tc.custom.name, "input": tc.custom.input}
21+
return out
22+
23+
1424
def _serialize_content(content: AddContent) -> dict[str, Any]:
1525
"""Build the content envelope with the type discriminator."""
1626
if isinstance(content, str):
@@ -39,11 +49,12 @@ def _serialize_conversation_content(content: ConversationContent) -> dict[str, A
3949
m: dict[str, Any] = {"role": msg.role, "content": msg.content}
4050
if msg.created_at is not None:
4151
m["created_at"] = msg.created_at
42-
if msg.tool_call_metadata is not None:
43-
m["tool_call_metadata"] = {
44-
"name": msg.tool_call_metadata.name,
45-
"id": msg.tool_call_metadata.id,
46-
}
52+
if msg.tool_call_id is not None:
53+
m["tool_call_id"] = msg.tool_call_id
54+
if msg.name is not None:
55+
m["name"] = msg.name
56+
if msg.tool_calls is not None:
57+
m["tool_calls"] = [_serialize_tool_call(tc) for tc in msg.tool_calls]
4758
messages.append(m)
4859
conversation: dict[str, Any] = {"messages": messages}
4960
if content.metadata is not None:

tests/test_client_async.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
PreExtractedContent,
1212
RetrievalConfig,
1313
StringContent,
14-
ToolCallMetadata,
14+
ToolCallFuncInput,
15+
ToolCallInput,
1516
)
1617
from engram.async_client import DEFAULT_BASE_URL, AsyncEngramClient
1718
from engram.errors import APIError, AuthenticationError, ValidationError
@@ -202,8 +203,11 @@ def handler(request: httpx.Request) -> httpx.Response:
202203
MessageContent(role="user", content="hi"),
203204
MessageContent(
204205
role="assistant",
205-
content="using tool",
206-
tool_call_metadata=ToolCallMetadata(name="search", id="tc1"),
206+
tool_calls=[
207+
ToolCallInput(
208+
id="tc1", function=ToolCallFuncInput(name="search", arguments="{}")
209+
)
210+
],
207211
),
208212
],
209213
metadata={"session_id": "s1"},
@@ -214,7 +218,9 @@ def handler(request: httpx.Request) -> httpx.Response:
214218
assert body["content"]["type"] == "conversation"
215219
conv = body["content"]["conversation"]
216220
assert conv["metadata"] == {"session_id": "s1"}
217-
assert conv["messages"][1]["tool_call_metadata"] == {"name": "search", "id": "tc1"}
221+
assert conv["messages"][1]["tool_calls"] == [
222+
{"id": "tc1", "type": "function", "function": {"name": "search", "arguments": "{}"}}
223+
]
218224
assert body["conversation_id"] == "c1"
219225

220226

tests/test_client_sync.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
PreExtractedContent,
1212
RetrievalConfig,
1313
StringContent,
14-
ToolCallMetadata,
14+
ToolCallFuncInput,
15+
ToolCallInput,
1516
)
1617
from engram.client import DEFAULT_BASE_URL, EngramClient
1718
from engram.errors import APIError, AuthenticationError, ValidationError
@@ -217,8 +218,11 @@ def handler(request: httpx.Request) -> httpx.Response:
217218
MessageContent(role="user", content="hi"),
218219
MessageContent(
219220
role="assistant",
220-
content="using tool",
221-
tool_call_metadata=ToolCallMetadata(name="search", id="tc1"),
221+
tool_calls=[
222+
ToolCallInput(
223+
id="tc1", function=ToolCallFuncInput(name="search", arguments="{}")
224+
)
225+
],
222226
),
223227
],
224228
metadata={"session_id": "s1"},
@@ -229,7 +233,9 @@ def handler(request: httpx.Request) -> httpx.Response:
229233
assert body["content"]["type"] == "conversation"
230234
conv = body["content"]["conversation"]
231235
assert conv["metadata"] == {"session_id": "s1"}
232-
assert conv["messages"][1]["tool_call_metadata"] == {"name": "search", "id": "tc1"}
236+
assert conv["messages"][1]["tool_calls"] == [
237+
{"id": "tc1", "type": "function", "function": {"name": "search", "arguments": "{}"}}
238+
]
233239
assert body["conversation_id"] == "c1"
234240

235241

tests/test_imports.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ def test_public_imports() -> None:
1919
RunStatus,
2020
SearchResults,
2121
StringContent,
22-
ToolCallMetadata,
22+
ToolCallCustomInput,
23+
ToolCallFuncInput,
24+
ToolCallInput,
2325
ValidationError,
2426
)
2527

@@ -41,7 +43,9 @@ def test_public_imports() -> None:
4143
assert isinstance(ConversationContent, type)
4244
assert isinstance(MessageContent, type)
4345
assert isinstance(StringContent, type)
44-
assert isinstance(ToolCallMetadata, type)
46+
assert isinstance(ToolCallCustomInput, type)
47+
assert isinstance(ToolCallFuncInput, type)
48+
assert isinstance(ToolCallInput, type)
4549

4650
expected_exports = {
4751
"APIError",
@@ -62,7 +66,9 @@ def test_public_imports() -> None:
6266
"RunStatus",
6367
"SearchResults",
6468
"StringContent",
65-
"ToolCallMetadata",
69+
"ToolCallCustomInput",
70+
"ToolCallFuncInput",
71+
"ToolCallInput",
6672
"ValidationError",
6773
"__version__",
6874
}

tests/test_serialization.py

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
PreExtractedContent,
55
RetrievalConfig,
66
StringContent,
7-
ToolCallMetadata,
7+
ToolCallCustomInput,
8+
ToolCallFuncInput,
9+
ToolCallInput,
810
)
911
from engram._serialization import (
1012
build_add_body,
@@ -160,12 +162,15 @@ def test_build_add_body_conversation_content_with_message_timestamps() -> None:
160162
assert "tool_call_metadata" not in msg
161163

162164

163-
def test_build_add_body_conversation_content_with_tool_call_metadata() -> None:
165+
def test_build_add_body_conversation_content_with_tool_calls() -> None:
164166
messages = [
165167
MessageContent(
166168
role="assistant",
167-
content="using tool",
168-
tool_call_metadata=ToolCallMetadata(name="search", id="tc1"),
169+
tool_calls=[
170+
ToolCallInput(
171+
id="tc1", function=ToolCallFuncInput(name="search", arguments='{"q":"x"}')
172+
)
173+
],
169174
)
170175
]
171176
body = build_add_body(
@@ -175,7 +180,62 @@ def test_build_add_body_conversation_content_with_tool_call_metadata() -> None:
175180
group=None,
176181
)
177182
msg = body["content"]["conversation"]["messages"][0]
178-
assert msg["tool_call_metadata"] == {"name": "search", "id": "tc1"}
183+
assert msg["tool_calls"] == [
184+
{"id": "tc1", "type": "function", "function": {"name": "search", "arguments": '{"q":"x"}'}}
185+
]
186+
187+
188+
def test_build_add_body_conversation_content_with_custom_tool_calls() -> None:
189+
messages = [
190+
MessageContent(
191+
role="assistant",
192+
tool_calls=[
193+
ToolCallInput(
194+
id="tc2",
195+
type="custom",
196+
custom=ToolCallCustomInput(name="my_tool", input="some input"),
197+
)
198+
],
199+
)
200+
]
201+
body = build_add_body(
202+
ConversationContent(messages=messages),
203+
user_id=None,
204+
conversation_id=None,
205+
group=None,
206+
)
207+
msg = body["content"]["conversation"]["messages"][0]
208+
assert msg["tool_calls"] == [
209+
{"id": "tc2", "type": "custom", "custom": {"name": "my_tool", "input": "some input"}}
210+
]
211+
212+
213+
def test_build_add_body_conversation_content_with_tool_role() -> None:
214+
messages = [MessageContent(role="tool", content="result", tool_call_id="tc1", name="search")]
215+
body = build_add_body(
216+
ConversationContent(messages=messages),
217+
user_id=None,
218+
conversation_id=None,
219+
group=None,
220+
)
221+
msg = body["content"]["conversation"]["messages"][0]
222+
assert msg["role"] == "tool"
223+
assert msg["tool_call_id"] == "tc1"
224+
assert msg["name"] == "search"
225+
assert msg["content"] == "result"
226+
227+
228+
def test_build_add_body_conversation_content_with_developer_role() -> None:
229+
messages = [MessageContent(role="developer", content="You are a helpful assistant.")]
230+
body = build_add_body(
231+
ConversationContent(messages=messages),
232+
user_id=None,
233+
conversation_id=None,
234+
group=None,
235+
)
236+
msg = body["content"]["conversation"]["messages"][0]
237+
assert msg["role"] == "developer"
238+
assert msg["content"] == "You are a helpful assistant."
179239

180240

181241
# ── build_memory_params ─────────────────────────────────────────────────

0 commit comments

Comments
 (0)