4
4
"""
5
5
6
6
from datetime import datetime , timezone
7
- from typing import Literal
7
+ from typing import Dict , List , Literal , cast
8
8
9
9
from langchain .chat_models import init_chat_model
10
10
from langchain_core .messages import AIMessage
20
20
# Define the function that calls the model
21
21
22
22
23
- async def call_model (state : State , config : RunnableConfig ):
23
+ async def call_model (
24
+ state : State , config : RunnableConfig
25
+ ) -> Dict [str , List [AIMessage ]]:
24
26
"""Call the LLM powering our "agent".
25
27
26
28
This function prepares the prompt, initializes the model, and processes the response.
@@ -36,22 +38,26 @@ async def call_model(state: State, config: RunnableConfig):
36
38
37
39
# Create a prompt template. Customize this to change the agent's behavior.
38
40
prompt = ChatPromptTemplate .from_messages (
39
- [("system" , configuration [ " system_prompt" ] ), ("placeholder" , "{messages}" )]
41
+ [("system" , configuration . system_prompt ), ("placeholder" , "{messages}" )]
40
42
)
41
43
42
44
# Initialize the model with tool binding. Change the model or add more tools here.
43
- model = init_chat_model (configuration [ " model_name" ] ).bind_tools (TOOLS )
45
+ model = init_chat_model (configuration . model_name ).bind_tools (TOOLS )
44
46
45
47
# Prepare the input for the model, including the current system time
46
48
message_value = await prompt .ainvoke (
47
- {** state , "system_time" : datetime .now (tz = timezone .utc ).isoformat ()}, config
49
+ {
50
+ "messages" : state .messages ,
51
+ "system_time" : datetime .now (tz = timezone .utc ).isoformat (),
52
+ },
53
+ config ,
48
54
)
49
55
50
56
# Get the model's response
51
- response : AIMessage = await model .ainvoke (message_value , config )
57
+ response = cast ( AIMessage , await model .ainvoke (message_value , config ) )
52
58
53
59
# Handle the case when it's the last step and the model still wants to use a tool
54
- if state [ " is_last_step" ] and response .tool_calls :
60
+ if state . is_last_step and response .tool_calls :
55
61
return {
56
62
"messages" : [
57
63
AIMessage (
@@ -89,14 +95,16 @@ def route_model_output(state: State) -> Literal["__end__", "tools"]:
89
95
Returns:
90
96
str: The name of the next node to call ("__end__" or "tools").
91
97
"""
92
- messages = state ["messages" ]
93
- last_message = messages [- 1 ]
94
- # If there is no function call, then we finish
98
+ last_message = state .messages [- 1 ]
99
+ if not isinstance (last_message , AIMessage ):
100
+ raise ValueError (
101
+ f"Expected AIMessage in output edges, but got { type (last_message ).__name__ } "
102
+ )
103
+ # If there is no tool call, then we finish
95
104
if not last_message .tool_calls :
96
105
return "__end__"
97
- # Otherwise if there are tools called, we continue
98
- else :
99
- return "tools"
106
+ # Otherwise we execute the requested actions
107
+ return "tools"
100
108
101
109
102
110
# Add a conditional edge to determine the next step after `call_model`
0 commit comments