Skip to content

Commit 23408cc

Browse files
authored
Schema Update
1 parent 5a82f8a commit 23408cc

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

src/react_agent/graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from datetime import datetime, timezone
77
from typing import Dict, List, Literal, cast
88

9-
from langchain.chat_models import init_chat_model
109
from langchain_core.messages import AIMessage
1110
from langchain_core.prompts import ChatPromptTemplate
1211
from langchain_core.runnables import RunnableConfig
@@ -16,6 +15,7 @@
1615
from react_agent.configuration import Configuration
1716
from react_agent.state import InputState, State
1817
from react_agent.tools import TOOLS
18+
from react_agent.utils import load_chat_model
1919

2020
# Define the function that calls the model
2121

@@ -42,7 +42,7 @@ async def call_model(
4242
)
4343

4444
# Initialize the model with tool binding. Change the model or add more tools here.
45-
model = init_chat_model(configuration.model_name).bind_tools(TOOLS)
45+
model = load_chat_model(configuration.model_name).bind_tools(TOOLS)
4646

4747
# Prepare the input for the model, including the current system time
4848
message_value = await prompt.ainvoke(

src/react_agent/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Utility & helper functions."""
22

3+
from langchain.chat_models import init_chat_model
4+
from langchain_core.language_models import BaseChatModel
35
from langchain_core.messages import BaseMessage
46

57

@@ -13,3 +15,13 @@ def get_message_text(msg: BaseMessage) -> str:
1315
else:
1416
txts = [c if isinstance(c, str) else (c.get("text") or "") for c in content]
1517
return "".join(txts).strip()
18+
19+
20+
def load_chat_model(fully_specified_name: str) -> BaseChatModel:
21+
"""Load a chat model from a fully specified name.
22+
23+
Args:
24+
fully_specified_name (str): String in the format 'provider/model'.
25+
"""
26+
provider, model = fully_specified_name.split("/", maxsplit=1)
27+
return init_chat_model(model, model_provider=provider)

0 commit comments

Comments
 (0)