Skip to content

Commit 21bceab

Browse files
authored
Add test caching (#5)
1 parent 4d198be commit 21bceab

File tree

5 files changed

+2073
-487
lines changed

5 files changed

+2073
-487
lines changed

.github/workflows/integration-tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,6 @@ jobs:
3939
TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }}
4040
LANGSMITH_API_KEY: ${{ secrets.LANGSMITH_API_KEY }}
4141
LANGSMITH_TRACING: true
42+
LANGSMITH_TEST_CACHE: tests/cassettes
4243
run: |
4344
uv run pytest tests/integration_tests

src/react_agent/graph.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typing import Dict, List, Literal, cast
88

99
from langchain_core.messages import AIMessage
10-
from langchain_core.prompts import ChatPromptTemplate
1110
from langchain_core.runnables import RunnableConfig
1211
from langgraph.graph import StateGraph
1312
from langgraph.prebuilt import ToolNode
@@ -36,25 +35,21 @@ async def call_model(
3635
"""
3736
configuration = Configuration.from_runnable_config(config)
3837

39-
# Create a prompt template. Customize this to change the agent's behavior.
40-
prompt = ChatPromptTemplate.from_messages(
41-
[("system", configuration.system_prompt), ("placeholder", "{messages}")]
42-
)
43-
4438
# Initialize the model with tool binding. Change the model or add more tools here.
4539
model = load_chat_model(configuration.model).bind_tools(TOOLS)
4640

47-
# Prepare the input for the model, including the current system time
48-
message_value = await prompt.ainvoke(
49-
{
50-
"messages": state.messages,
51-
"system_time": datetime.now(tz=timezone.utc).isoformat(),
52-
},
53-
config,
41+
# Format the system prompt. Customize this to change the agent's behavior.
42+
system_message = configuration.system_prompt.format(
43+
system_time=datetime.now(tz=timezone.utc).isoformat()
5444
)
5545

5646
# Get the model's response
57-
response = cast(AIMessage, await model.ainvoke(message_value, config))
47+
response = cast(
48+
AIMessage,
49+
await model.ainvoke(
50+
[{"role": "system", "content": system_message}, *state.messages], config
51+
),
52+
)
5853

5954
# Handle the case when it's the last step and the model still wants to use a tool
6055
if state.is_last_step and response.tool_calls:
@@ -73,15 +68,15 @@ async def call_model(
7368

7469
# Define a new graph
7570

76-
workflow = StateGraph(State, input=InputState, config_schema=Configuration)
71+
builder = StateGraph(State, input=InputState, config_schema=Configuration)
7772

7873
# Define the two nodes we will cycle between
79-
workflow.add_node(call_model)
80-
workflow.add_node("tools", ToolNode(TOOLS))
74+
builder.add_node(call_model)
75+
builder.add_node("tools", ToolNode(TOOLS))
8176

8277
# Set the entrypoint as `call_model`
8378
# This means that this node is the first one called
84-
workflow.add_edge("__start__", "call_model")
79+
builder.add_edge("__start__", "call_model")
8580

8681

8782
def route_model_output(state: State) -> Literal["__end__", "tools"]:
@@ -108,7 +103,7 @@ def route_model_output(state: State) -> Literal["__end__", "tools"]:
108103

109104

110105
# Add a conditional edge to determine the next step after `call_model`
111-
workflow.add_conditional_edges(
106+
builder.add_conditional_edges(
112107
"call_model",
113108
# After call_model finishes running, the next node(s) are scheduled
114109
# based on the output from route_model_output
@@ -117,11 +112,11 @@ def route_model_output(state: State) -> Literal["__end__", "tools"]:
117112

118113
# Add a normal edge from `tools` to `call_model`
119114
# This creates a cycle: after using tools, we always return to the model
120-
workflow.add_edge("tools", "call_model")
115+
builder.add_edge("tools", "call_model")
121116

122-
# Compile the workflow into an executable graph
117+
# Compile the builder into an executable graph
123118
# You can customize this by adding interrupt points for state updates
124-
graph = workflow.compile(
119+
graph = builder.compile(
125120
interrupt_before=[], # Add node names here to update state before they're called
126121
interrupt_after=[], # Add node names here to update state after they're called
127122
)

0 commit comments

Comments
 (0)