Skip to content

Commit 2c40d38

Browse files
isahers1Erick Friis
authored andcommitted
core: allow passing message dicts into ChatPromptTemplate (#29363)
Co-authored-by: Erick Friis <[email protected]>
1 parent 177a24c commit 2c40d38

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

libs/core/langchain_core/prompts/chat.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,7 @@ def pretty_print(self) -> None:
828828
Union[str, list[dict], list[object]],
829829
],
830830
str,
831+
dict,
831832
]
832833

833834

@@ -1461,7 +1462,15 @@ def _convert_to_message(
14611462
_message = _create_template_from_message_type(
14621463
"human", message, template_format=template_format
14631464
)
1464-
elif isinstance(message, tuple):
1465+
elif isinstance(message, (tuple, dict)):
1466+
if isinstance(message, dict):
1467+
if set(message.keys()) != {"content", "role"}:
1468+
msg = (
1469+
"Expected dict to have exact keys 'role' and 'content'."
1470+
f" Got: {message}"
1471+
)
1472+
raise ValueError(msg)
1473+
message = (message["role"], message["content"])
14651474
if len(message) != 2:
14661475
msg = f"Expected 2-tuple of (role, template), got {message}"
14671476
raise ValueError(msg)

libs/core/tests/unit_tests/prompts/test_chat.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,41 @@ def test_chat_prompt_message_placeholder_tuple() -> None:
824824
assert optional_prompt.format_messages() == []
825825

826826

827+
def test_chat_prompt_message_placeholder_dict() -> None:
828+
prompt = ChatPromptTemplate([{"role": "placeholder", "content": "{convo}"}])
829+
assert prompt.format_messages(convo=[("user", "foo")]) == [
830+
HumanMessage(content="foo")
831+
]
832+
833+
assert prompt.format_messages() == []
834+
835+
# Is optional = True
836+
optional_prompt = ChatPromptTemplate(
837+
[{"role": "placeholder", "content": ["{convo}", False]}]
838+
)
839+
assert optional_prompt.format_messages(convo=[("user", "foo")]) == [
840+
HumanMessage(content="foo")
841+
]
842+
with pytest.raises(KeyError):
843+
assert optional_prompt.format_messages() == []
844+
845+
846+
def test_chat_prompt_message_dict() -> None:
847+
prompt = ChatPromptTemplate(
848+
[{"role": "system", "content": "foo"}, {"role": "user", "content": "bar"}]
849+
)
850+
assert prompt.format_messages() == [
851+
SystemMessage(content="foo"),
852+
HumanMessage(content="bar"),
853+
]
854+
855+
with pytest.raises(ValueError):
856+
ChatPromptTemplate([{"role": "system", "content": False}])
857+
858+
with pytest.raises(ValueError):
859+
ChatPromptTemplate([{"role": "foo", "content": "foo"}])
860+
861+
827862
async def test_messages_prompt_accepts_list() -> None:
828863
prompt = ChatPromptTemplate([MessagesPlaceholder("history")])
829864
value = prompt.invoke([("user", "Hi there")]) # type: ignore

0 commit comments

Comments
 (0)