44 Any ,
55 AsyncGenerator ,
66 Dict ,
7- Literal ,
87 Optional ,
98)
109from urllib .parse import quote_plus
2221 END ,
2322 StateGraph ,
2423)
25- from langgraph .graph .state import CompiledStateGraph
24+ from langgraph .graph .state import (
25+ Command ,
26+ CompiledStateGraph ,
27+ )
2628from langgraph .types import StateSnapshot
2729from openai import OpenAIError
2830from psycopg_pool import AsyncConnectionPool
@@ -126,14 +128,14 @@ async def _get_connection_pool(self) -> AsyncConnectionPool:
126128 raise e
127129 return self ._connection_pool
128130
129- async def _chat (self , state : GraphState ) -> dict :
131+ async def _chat (self , state : GraphState ) -> Command :
130132 """Process the chat state and generate a response.
131133
132134 Args:
133135 state (GraphState): The current state of the conversation.
134136
135137 Returns:
136- dict: Updated state with new messages .
138+ Command: Command object with updated state and next node to execute .
137139 """
138140 messages = prepare_messages (state .messages , self .llm , SYSTEM_PROMPT )
139141
@@ -145,15 +147,22 @@ async def _chat(self, state: GraphState) -> dict:
145147 for attempt in range (max_retries ):
146148 try :
147149 with llm_inference_duration_seconds .labels (model = self .llm .model_name ).time ():
148- generated_state = { "messages" : [ await self .llm .ainvoke (dump_messages (messages ))]}
150+ response_message = await self .llm .ainvoke (dump_messages (messages ))
149151 logger .info (
150152 "llm_response_generated" ,
151153 session_id = state .session_id ,
152154 llm_calls_num = llm_calls_num + 1 ,
153155 model = settings .LLM_MODEL ,
154156 environment = settings .ENVIRONMENT .value ,
155157 )
156- return generated_state
158+
159+ # Determine next node based on whether there are tool calls
160+ if response_message .tool_calls :
161+ goto = "tool_call"
162+ else :
163+ goto = END
164+
165+ return Command (update = {"messages" : [response_message ]}, goto = goto )
157166 except OpenAIError as e :
158167 logger .error (
159168 "llm_call_failed" ,
@@ -178,14 +187,14 @@ async def _chat(self, state: GraphState) -> dict:
178187 raise Exception (f"Failed to get a response from the LLM after { max_retries } attempts" )
179188
180189 # Define our tool node
181- async def _tool_call (self , state : GraphState ) -> GraphState :
190+ async def _tool_call (self , state : GraphState ) -> Command :
182191 """Process tool calls from the last message.
183192
184193 Args:
185194 state: The current agent state containing messages and tool calls.
186195
187196 Returns:
188- Dict with updated messages containing tool responses .
197+ Command: Command object with updated messages and routing back to chat .
189198 """
190199 outputs = []
191200 for tool_call in state .messages [- 1 ].tool_calls :
@@ -197,25 +206,7 @@ async def _tool_call(self, state: GraphState) -> GraphState:
197206 tool_call_id = tool_call ["id" ],
198207 )
199208 )
200- return {"messages" : outputs }
201-
202- def _should_continue (self , state : GraphState ) -> Literal ["end" , "continue" ]:
203- """Determine if the agent should continue or end based on the last message.
204-
205- Args:
206- state: The current agent state containing messages.
207-
208- Returns:
209- Literal["end", "continue"]: "end" if there are no tool calls, "continue" otherwise.
210- """
211- messages = state .messages
212- last_message = messages [- 1 ]
213- # If there is no function call, then we finish
214- if not last_message .tool_calls :
215- return "end"
216- # Otherwise if there is, we continue
217- else :
218- return "continue"
209+ return Command (update = {"messages" : outputs }, goto = "chat" )
219210
220211 async def create_graph (self ) -> Optional [CompiledStateGraph ]:
221212 """Create and configure the LangGraph workflow.
@@ -226,14 +217,8 @@ async def create_graph(self) -> Optional[CompiledStateGraph]:
226217 if self ._graph is None :
227218 try :
228219 graph_builder = StateGraph (GraphState )
229- graph_builder .add_node ("chat" , self ._chat )
230- graph_builder .add_node ("tool_call" , self ._tool_call )
231- graph_builder .add_conditional_edges (
232- "chat" ,
233- self ._should_continue ,
234- {"continue" : "tool_call" , "end" : END },
235- )
236- graph_builder .add_edge ("tool_call" , "chat" )
220+ graph_builder .add_node ("chat" , self ._chat , ends = ["tool_call" , END ])
221+ graph_builder .add_node ("tool_call" , self ._tool_call , ends = ["chat" ])
237222 graph_builder .set_entry_point ("chat" )
238223 graph_builder .set_finish_point ("chat" )
239224
@@ -293,7 +278,7 @@ async def get_response(
293278 "user_id" : user_id ,
294279 "session_id" : session_id ,
295280 "environment" : settings .ENVIRONMENT .value ,
296- "debug" : False ,
281+ "debug" : settings . DEBUG ,
297282 },
298283 }
299284 try :
0 commit comments