Commit 5780912

refactor: update LangGraphAgent methods to return Command objects and streamline graph node handling
1 parent: 595aaf3

File tree: 1 file changed (+21, -36)

1 file changed

+21
-36
lines changed

app/core/langgraph/graph.py

Lines changed: 21 additions & 36 deletions
@@ -4,7 +4,6 @@
     Any,
     AsyncGenerator,
     Dict,
-    Literal,
     Optional,
 )
 from urllib.parse import quote_plus
@@ -22,7 +21,10 @@
     END,
     StateGraph,
 )
-from langgraph.graph.state import CompiledStateGraph
+from langgraph.graph.state import (
+    Command,
+    CompiledStateGraph,
+)
 from langgraph.types import StateSnapshot
 from openai import OpenAIError
 from psycopg_pool import AsyncConnectionPool
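
A note on the new import block: current LangGraph documentation exposes `Command` from `langgraph.types` rather than `langgraph.graph.state`, so whether the import above resolves may depend on the installed release (re-exports have moved between versions). A hedged alternative for recent releases:

    from langgraph.graph.state import CompiledStateGraph
    from langgraph.types import Command  # documented location in recent langgraph releases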
@@ -126,14 +128,14 @@ async def _get_connection_pool(self) -> AsyncConnectionPool:
                 raise e
         return self._connection_pool
 
-    async def _chat(self, state: GraphState) -> dict:
+    async def _chat(self, state: GraphState) -> Command:
         """Process the chat state and generate a response.
 
         Args:
             state (GraphState): The current state of the conversation.
 
         Returns:
-            dict: Updated state with new messages.
+            Command: Command object with updated state and next node to execute.
         """
         messages = prepare_messages(state.messages, self.llm, SYSTEM_PROMPT)
 
@@ -145,15 +147,22 @@ async def _chat(self, state: GraphState) -> dict:
         for attempt in range(max_retries):
             try:
                 with llm_inference_duration_seconds.labels(model=self.llm.model_name).time():
-                    generated_state = {"messages": [await self.llm.ainvoke(dump_messages(messages))]}
+                    response_message = await self.llm.ainvoke(dump_messages(messages))
                 logger.info(
                     "llm_response_generated",
                     session_id=state.session_id,
                     llm_calls_num=llm_calls_num + 1,
                     model=settings.LLM_MODEL,
                     environment=settings.ENVIRONMENT.value,
                 )
-                return generated_state
+
+                # Determine next node based on whether there are tool calls
+                if response_message.tool_calls:
+                    goto = "tool_call"
+                else:
+                    goto = END
+
+                return Command(update={"messages": [response_message]}, goto=goto)
             except OpenAIError as e:
                 logger.error(
                     "llm_call_failed",
@@ -178,14 +187,14 @@ async def _chat(self, state: GraphState) -> dict:
         raise Exception(f"Failed to get a response from the LLM after {max_retries} attempts")
 
     # Define our tool node
-    async def _tool_call(self, state: GraphState) -> GraphState:
+    async def _tool_call(self, state: GraphState) -> Command:
         """Process tool calls from the last message.
 
         Args:
             state: The current agent state containing messages and tool calls.
 
         Returns:
-            Dict with updated messages containing tool responses.
+            Command: Command object with updated messages and routing back to chat.
         """
         outputs = []
         for tool_call in state.messages[-1].tool_calls:
@@ -197,25 +206,7 @@ async def _tool_call(self, state: GraphState) -> GraphState:
                     tool_call_id=tool_call["id"],
                 )
             )
-        return {"messages": outputs}
-
-    def _should_continue(self, state: GraphState) -> Literal["end", "continue"]:
-        """Determine if the agent should continue or end based on the last message.
-
-        Args:
-            state: The current agent state containing messages.
-
-        Returns:
-            Literal["end", "continue"]: "end" if there are no tool calls, "continue" otherwise.
-        """
-        messages = state.messages
-        last_message = messages[-1]
-        # If there is no function call, then we finish
-        if not last_message.tool_calls:
-            return "end"
-        # Otherwise if there is, we continue
-        else:
-            return "continue"
+        return Command(update={"messages": outputs}, goto="chat")
 
     async def create_graph(self) -> Optional[CompiledStateGraph]:
         """Create and configure the LangGraph workflow.
@@ -226,14 +217,8 @@ async def create_graph(self) -> Optional[CompiledStateGraph]:
         if self._graph is None:
             try:
                 graph_builder = StateGraph(GraphState)
-                graph_builder.add_node("chat", self._chat)
-                graph_builder.add_node("tool_call", self._tool_call)
-                graph_builder.add_conditional_edges(
-                    "chat",
-                    self._should_continue,
-                    {"continue": "tool_call", "end": END},
-                )
-                graph_builder.add_edge("tool_call", "chat")
+                graph_builder.add_node("chat", self._chat, ends=["tool_call", END])
+                graph_builder.add_node("tool_call", self._tool_call, ends=["chat"])
                 graph_builder.set_entry_point("chat")
                 graph_builder.set_finish_point("chat")
 
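For readers unfamiliar with Command-based routing: once nodes return `Command(goto=...)`, the static `add_conditional_edges`/`add_edge` wiring becomes unnecessary, and the `ends=[...]` keyword here declares each node's possible destinations (recent LangGraph docs spell this `destinations=`, and annotating the return type as `Command[Literal[...]]` serves the same purpose; which spelling your version accepts is an assumption). A self-contained toy sketch of the resulting two-node loop, with a plain-list state standing in for the repo's GraphState:

    import operator
    from typing import Annotated, Literal, TypedDict

    from langgraph.graph import END, START, StateGraph
    from langgraph.types import Command


    class State(TypedDict):
        messages: Annotated[list, operator.add]  # toy stand-in for GraphState


    def chat(state: State) -> Command[Literal["tool_call", "__end__"]]:
        # Pretend the model requests a tool once, then finishes.
        wants_tool = not any(m.startswith("tool:") for m in state["messages"])
        goto = "tool_call" if wants_tool else END
        return Command(update={"messages": ["ai: step"]}, goto=goto)


    def tool_call(state: State) -> Command[Literal["chat"]]:
        # The tool node always hands control back to chat, as in the commit.
        return Command(update={"messages": ["tool: result"]}, goto="chat")


    builder = StateGraph(State)
    builder.add_edge(START, "chat")
    builder.add_node("chat", chat)
    builder.add_node("tool_call", tool_call)
    graph = builder.compile()

    print(graph.invoke({"messages": ["user: hi"]}))
    # Runs chat -> tool_call -> chat -> END
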
@@ -293,7 +278,7 @@ async def get_response(
                 "user_id": user_id,
                 "session_id": session_id,
                 "environment": settings.ENVIRONMENT.value,
-                "debug": False,
+                "debug": settings.DEBUG,
             },
         }
         try:
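
The dict being patched here is presumably the run config handed to the compiled graph; the change stops hard-coding `debug` and sources it from application settings instead. A hedged sketch of the call shape (`graph`, `settings`, `user_id`, `session_id`, and `messages` are assumed from the surrounding repo code):

    config = {
        "configurable": {"thread_id": session_id},
        "metadata": {
            "user_id": user_id,
            "session_id": session_id,
            "environment": settings.ENVIRONMENT.value,
            "debug": settings.DEBUG,  # was hard-coded to False
        },
    }
    response = await graph.ainvoke({"messages": messages}, config)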
