-
Notifications
You must be signed in to change notification settings - Fork 472
Expand file tree
/
Copy pathgraph.py
More file actions
397 lines (342 loc) · 14.9 KB
/
graph.py
File metadata and controls
397 lines (342 loc) · 14.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
"""This file contains the LangGraph Agent/workflow and interactions with the LLM."""
from typing import (
Any,
AsyncGenerator,
Dict,
Literal,
Optional,
)
from urllib.parse import quote_plus
from asgiref.sync import sync_to_async
from langchain_core.messages import (
BaseMessage,
ToolMessage,
convert_to_openai_messages,
)
from langchain_openai import ChatOpenAI
from langfuse.langchain import CallbackHandler
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.graph import (
END,
StateGraph,
)
from langgraph.graph.state import CompiledStateGraph
from langgraph.types import StateSnapshot
from openai import OpenAIError
from psycopg_pool import AsyncConnectionPool
from app.core.config import (
Environment,
settings,
)
from app.core.langgraph.tools import tools
from app.core.logging import logger
from app.core.metrics import llm_inference_duration_seconds
from app.core.prompts import SYSTEM_PROMPT
from app.schemas import (
GraphState,
Message,
)
from app.utils import (
dump_messages,
prepare_messages,
)
class LangGraphAgent:
    """Manages the LangGraph Agent/workflow and interactions with the LLM.

    This class handles the creation and management of the LangGraph workflow,
    including LLM interactions, database connections, and response processing.
    The PostgreSQL connection pool and the compiled graph are created lazily on
    first use; in production, failures to create them degrade gracefully
    (returning ``None``) instead of crashing the application.
    """

    def __init__(self):
        """Initialize the LangGraph Agent with necessary components."""
        # Use environment-specific LLM model and sampling parameters.
        self.llm = ChatOpenAI(
            model=settings.LLM_MODEL,
            temperature=settings.DEFAULT_LLM_TEMPERATURE,
            api_key=settings.LLM_API_KEY,
            max_tokens=settings.MAX_TOKENS,
            **self._get_model_kwargs(),
        ).bind_tools(tools)
        self.tools_by_name = {tool.name: tool for tool in tools}
        # Created lazily so constructing the agent performs no I/O.
        self._connection_pool: Optional[AsyncConnectionPool] = None
        self._graph: Optional[CompiledStateGraph] = None
        logger.info("llm_initialized", model=settings.LLM_MODEL, environment=settings.ENVIRONMENT.value)

    def _get_model_kwargs(self) -> Dict[str, Any]:
        """Get environment-specific model kwargs.

        Returns:
            Dict[str, Any]: Additional model arguments based on environment.
        """
        model_kwargs: Dict[str, Any] = {}
        if settings.ENVIRONMENT == Environment.DEVELOPMENT:
            # Development - we can use lower sampling quality for cost savings.
            model_kwargs["top_p"] = 0.8
        elif settings.ENVIRONMENT == Environment.PRODUCTION:
            # Production - use higher quality settings.
            model_kwargs["top_p"] = 0.95
            model_kwargs["presence_penalty"] = 0.1
            model_kwargs["frequency_penalty"] = 0.1
        return model_kwargs

    async def _get_connection_pool(self) -> Optional[AsyncConnectionPool]:
        """Get a PostgreSQL connection pool using environment-specific settings.

        Returns:
            Optional[AsyncConnectionPool]: An open connection pool, or ``None``
            in production when pool creation fails (graceful degradation).

        Raises:
            Exception: If pool creation fails outside of production.
        """
        if self._connection_pool is None:
            try:
                # Configure pool size based on environment.
                max_size = settings.POSTGRES_POOL_SIZE
                # quote_plus guards against credentials containing
                # URL-reserved characters such as '@', ':' or '/'.
                connection_url = (
                    "postgresql://"
                    f"{quote_plus(settings.POSTGRES_USER)}:{quote_plus(settings.POSTGRES_PASSWORD)}"
                    f"@{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/{settings.POSTGRES_DB}"
                )
                self._connection_pool = AsyncConnectionPool(
                    connection_url,
                    open=False,
                    max_size=max_size,
                    kwargs={
                        "autocommit": True,
                        "connect_timeout": 5,
                        # Disable server-side prepared statements (needed for
                        # connection poolers such as pgbouncer).
                        "prepare_threshold": None,
                    },
                )
                await self._connection_pool.open()
                logger.info("connection_pool_created", max_size=max_size, environment=settings.ENVIRONMENT.value)
            except Exception as e:
                # Reset so a later call retries instead of returning a
                # half-built, never-opened pool.
                self._connection_pool = None
                logger.error("connection_pool_creation_failed", error=str(e), environment=settings.ENVIRONMENT.value)
                if settings.ENVIRONMENT == Environment.PRODUCTION:
                    # In production we degrade gracefully rather than crash.
                    logger.warning("continuing_without_connection_pool", environment=settings.ENVIRONMENT.value)
                    return None
                raise e
        return self._connection_pool

    async def _chat(self, state: GraphState) -> dict:
        """Process the chat state and generate a response.

        Retries failed LLM calls up to ``settings.MAX_LLM_CALL_RETRIES`` times;
        in production the second-to-last failure switches to a fallback model
        for the final attempt.

        Args:
            state (GraphState): The current state of the conversation.

        Returns:
            dict: Updated state with new messages.

        Raises:
            Exception: If every retry attempt fails.
        """
        messages = prepare_messages(state.messages, self.llm, SYSTEM_PROMPT)
        # Configure retry attempts based on environment.
        max_retries = settings.MAX_LLM_CALL_RETRIES
        for attempt in range(max_retries):
            try:
                with llm_inference_duration_seconds.labels(model=self.llm.model_name).time():
                    generated_state = {"messages": [await self.llm.ainvoke(dump_messages(messages))]}
                logger.info(
                    "llm_response_generated",
                    session_id=state.session_id,
                    llm_calls_num=attempt + 1,
                    model=settings.LLM_MODEL,
                    environment=settings.ENVIRONMENT.value,
                )
                return generated_state
            except OpenAIError as e:
                logger.error(
                    "llm_call_failed",
                    llm_calls_num=attempt,
                    attempt=attempt + 1,
                    max_retries=max_retries,
                    error=str(e),
                    environment=settings.ENVIRONMENT.value,
                )
                # In production, fall back to a more reliable model for the
                # final attempt.
                if settings.ENVIRONMENT == Environment.PRODUCTION and attempt == max_retries - 2:
                    fallback_model = "gpt-4o"
                    logger.warning(
                        "using_fallback_model", model=fallback_model, environment=settings.ENVIRONMENT.value
                    )
                    # NOTE(review): this mutates the shared LLM instance, so
                    # the fallback model persists for all subsequent requests,
                    # not just this retry — confirm this is intended.
                    self.llm.model_name = fallback_model
        raise Exception(f"Failed to get a response from the LLM after {max_retries} attempts")

    async def _tool_call(self, state: GraphState) -> GraphState:
        """Process tool calls from the last message.

        Args:
            state: The current agent state containing messages and tool calls.

        Returns:
            Dict with updated messages containing tool responses.
        """
        outputs = []
        # The last message is expected to be an AI message carrying tool_calls
        # (guaranteed by the "continue" edge from _should_continue).
        for tool_call in state.messages[-1].tool_calls:
            tool_result = await self.tools_by_name[tool_call["name"]].ainvoke(tool_call["args"])
            outputs.append(
                ToolMessage(
                    content=tool_result,
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": outputs}

    def _should_continue(self, state: GraphState) -> Literal["end", "continue"]:
        """Determine if the agent should continue or end based on the last message.

        Args:
            state: The current agent state containing messages.

        Returns:
            Literal["end", "continue"]: "end" if there are no tool calls, "continue" otherwise.
        """
        last_message = state.messages[-1]
        # No pending tool calls means the LLM produced a final answer.
        return "continue" if last_message.tool_calls else "end"

    async def create_graph(self) -> Optional[CompiledStateGraph]:
        """Create and configure the LangGraph workflow.

        Builds a two-node graph (chat -> optional tool_call -> chat) and
        compiles it with a PostgreSQL checkpointer when a connection pool is
        available.

        Returns:
            Optional[CompiledStateGraph]: The configured LangGraph instance,
            or ``None`` in production if initialization fails.

        Raises:
            Exception: If graph creation fails outside of production.
        """
        if self._graph is None:
            try:
                graph_builder = StateGraph(GraphState)
                graph_builder.add_node("chat", self._chat)
                graph_builder.add_node("tool_call", self._tool_call)
                graph_builder.add_conditional_edges(
                    "chat",
                    self._should_continue,
                    {"continue": "tool_call", "end": END},
                )
                graph_builder.add_edge("tool_call", "chat")
                graph_builder.set_entry_point("chat")
                graph_builder.set_finish_point("chat")

                # Get connection pool (may be None in production if DB unavailable).
                connection_pool = await self._get_connection_pool()
                if connection_pool:
                    checkpointer = AsyncPostgresSaver(connection_pool)
                    await checkpointer.setup()
                else:
                    # In production, proceed without a checkpointer; elsewhere
                    # a missing pool is a hard error.
                    checkpointer = None
                    if settings.ENVIRONMENT != Environment.PRODUCTION:
                        raise Exception("Connection pool initialization failed")

                self._graph = graph_builder.compile(
                    checkpointer=checkpointer, name=f"{settings.PROJECT_NAME} Agent ({settings.ENVIRONMENT.value})"
                )
                logger.info(
                    "graph_created",
                    graph_name=f"{settings.PROJECT_NAME} Agent",
                    environment=settings.ENVIRONMENT.value,
                    has_checkpointer=checkpointer is not None,
                )
            except Exception as e:
                logger.error("graph_creation_failed", error=str(e), environment=settings.ENVIRONMENT.value)
                # In production, we don't want to crash the app.
                if settings.ENVIRONMENT == Environment.PRODUCTION:
                    logger.warning("continuing_without_graph")
                    return None
                raise e
        return self._graph

    async def get_response(
        self,
        messages: list[Message],
        session_id: str,
        user_id: Optional[str] = None,
    ) -> list[dict]:
        """Get a response from the LLM.

        Args:
            messages (list[Message]): The messages to send to the LLM.
            session_id (str): The session ID for Langfuse tracking.
            user_id (Optional[str]): The user ID for Langfuse tracking.

        Returns:
            list[dict]: The response from the LLM.

        Raises:
            Exception: If the graph could not be initialized or invocation fails.
        """
        if self._graph is None:
            self._graph = await self.create_graph()
        # create_graph may return None in production; fail with a clear error
        # instead of an opaque AttributeError below.
        if self._graph is None:
            raise Exception("Agent graph is not available")
        config = {
            "configurable": {"thread_id": session_id},
            "callbacks": [CallbackHandler()],
            "metadata": {
                "user_id": user_id,
                "session_id": session_id,
                "environment": settings.ENVIRONMENT.value,
                "debug": False,
            },
        }
        try:
            response = await self._graph.ainvoke(
                {"messages": dump_messages(messages), "session_id": session_id}, config
            )
            return self.__process_messages(response["messages"])
        except Exception as e:
            logger.error(f"Error getting response: {str(e)}")
            raise e

    async def get_stream_response(
        self, messages: list[Message], session_id: str, user_id: Optional[str] = None
    ) -> AsyncGenerator[str, None]:
        """Get a stream response from the LLM.

        Args:
            messages (list[Message]): The messages to send to the LLM.
            session_id (str): The session ID for the conversation.
            user_id (Optional[str]): The user ID for the conversation.

        Yields:
            str: Tokens of the LLM response.

        Raises:
            Exception: If the graph could not be initialized or streaming fails.
        """
        config = {
            "configurable": {"thread_id": session_id},
            "callbacks": [
                CallbackHandler(
                    environment=settings.ENVIRONMENT.value, debug=False, user_id=user_id, session_id=session_id
                )
            ],
        }
        if self._graph is None:
            self._graph = await self.create_graph()
        # Guard against graceful-degradation mode where no graph exists.
        if self._graph is None:
            raise Exception("Agent graph is not available")
        try:
            async for token, _ in self._graph.astream(
                {"messages": dump_messages(messages), "session_id": session_id}, config, stream_mode="messages"
            ):
                try:
                    yield token.content
                except Exception as token_error:
                    logger.error("Error processing token", error=str(token_error), session_id=session_id)
                    # Continue with next token even if current one fails.
                    continue
        except Exception as stream_error:
            logger.error("Error in stream processing", error=str(stream_error), session_id=session_id)
            raise stream_error

    async def get_chat_history(self, session_id: str) -> list[Message]:
        """Get the chat history for a given thread ID.

        Args:
            session_id (str): The session ID for the conversation.

        Returns:
            list[Message]: The chat history.

        Raises:
            Exception: If the graph could not be initialized.
        """
        if self._graph is None:
            self._graph = await self.create_graph()
        if self._graph is None:
            raise Exception("Agent graph is not available")
        # get_state is synchronous; wrap it so the event loop is not blocked.
        state: StateSnapshot = await sync_to_async(self._graph.get_state)(
            config={"configurable": {"thread_id": session_id}}
        )
        return self.__process_messages(state.values["messages"]) if state.values else []

    def __process_messages(self, messages: list[BaseMessage]) -> list[Message]:
        """Convert LangChain messages to API schema, keeping only chat turns.

        Args:
            messages (list[BaseMessage]): Raw messages from the graph state.

        Returns:
            list[Message]: Assistant/user messages with non-empty content.
        """
        openai_style_messages = convert_to_openai_messages(messages)
        # Keep just assistant and user messages (drops system/tool entries and
        # assistant messages that only carry tool calls).
        return [
            Message(**message)
            for message in openai_style_messages
            if message["role"] in ["assistant", "user"] and message["content"]
        ]

    async def clear_chat_history(self, session_id: str) -> None:
        """Clear all chat history for a given thread ID.

        Args:
            session_id: The ID of the session to clear history for.

        Raises:
            Exception: If there's an error clearing the chat history or the
                connection pool is unavailable.
        """
        try:
            # Make sure the pool is initialized in the current event loop.
            conn_pool = await self._get_connection_pool()
            # In production _get_connection_pool may return None; fail loudly
            # rather than with an AttributeError.
            if conn_pool is None:
                raise Exception("Connection pool is not available")
            # Use a new connection for this specific operation.
            async with conn_pool.connection() as conn:
                # Table names cannot be bound as SQL parameters; they come from
                # trusted application settings, not user input.
                for table in settings.CHECKPOINT_TABLES:
                    try:
                        await conn.execute(f"DELETE FROM {table} WHERE thread_id = %s", (session_id,))
                        logger.info(f"Cleared {table} for session {session_id}")
                    except Exception as e:
                        logger.error(f"Error clearing {table}", error=str(e))
                        raise
        except Exception as e:
            logger.error("Failed to clear chat history", error=str(e))
            raise