-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
98 lines (89 loc) · 3.88 KB
/
main.py
File metadata and controls
98 lines (89 loc) · 3.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import json
import httpx
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain.agents import create_agent
from langchain.tools import tool
from langchain_core.messages import HumanMessage
# ----------------------------------------------------------------------
# 1. Load environment variables
# ----------------------------------------------------------------------
load_dotenv()  # pull variables from a local .env file into the process environment
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # not passed explicitly below; presumably read from the env by ChatOpenAI — confirm
VANNA_API_KEY = os.getenv("VANNA_API_KEY")  # sent as the VANNA-API-KEY header to the Vanna chat endpoint
VANNA_AGENT_ID = os.getenv("VANNA_AGENT_ID")  # selects which Vanna agent answers the query
VANNA_USER_EMAIL = os.getenv("VANNA_USER_EMAIL")  # identifies the requesting user in the Vanna payload
OPENWEATHER_KEY = os.getenv("OPENWEATHER_KEY")  # appid query parameter for the OpenWeather API
# ----------------------------------------------------------------------
# 2. Define the Vanna tool
# ----------------------------------------------------------------------
def _final_answer_from_sse(chunks):
    """Parse an SSE text stream; return the text of the last 'final_ai_message'.

    Args:
        chunks: Iterable of text fragments from the response body.

    Returns:
        The final answer string, or None if no final_ai_message event arrived.
    """
    buffer, final_answer = "", None
    for chunk in chunks:
        buffer += chunk
        # SSE events are delimited by a blank line ("\n\n"); a chunk may
        # contain zero, one, or several complete events.
        while "\n\n" in buffer:
            event, buffer = buffer.split("\n\n", 1)
            if not event.startswith("data: "):
                continue
            try:
                data = json.loads(event[6:])
            except json.JSONDecodeError:
                continue  # skip keep-alives / malformed frames
            if data.get("semantic_type") == "final_ai_message":
                final_answer = data.get("text", "")
    return final_answer


@tool("vanna_query")
def vanna_query(message: str) -> str:
    """Ask Vanna AI a question about data or analytics.

    Streams Server-Sent Events from Vanna's chat endpoint and returns the
    text of the event whose semantic_type is "final_ai_message".

    Args:
        message: Natural-language question to forward to the Vanna agent.

    Returns:
        The final answer text, or a diagnostic message when the request
        fails or no final answer is received.
    """
    headers = {
        "Content-Type": "application/json",
        "VANNA-API-KEY": VANNA_API_KEY,
    }
    payload = {
        "message": message,
        "user_email": VANNA_USER_EMAIL,
        "agent_id": VANNA_AGENT_ID,
        "acceptable_responses": ["text", "dataframe"],
    }
    try:
        # SSE events can be far apart, so reads stay unbounded — but the
        # original timeout=None also allowed the *connect* to hang forever.
        timeout = httpx.Timeout(None, connect=30.0)
        with httpx.stream("POST", "https://app.vanna.ai/api/v2/chat_sse",
                          headers=headers, json=payload, timeout=timeout) as r:
            if r.status_code != 200:
                # Surface HTTP errors instead of silently parsing an error
                # body as SSE and reporting "no final answer".
                r.read()
                return f"Vanna API error ({r.status_code}): {r.text}"
            final_answer = _final_answer_from_sse(r.iter_text())
        return final_answer or "No final answer received from Vanna."
    except httpx.HTTPError as e:
        return f"Vanna API request failed: {e}"
# ----------------------------------------------------------------------
# 3. Define the OpenWeather tool
# ----------------------------------------------------------------------
@tool("get_weather")
def get_weather(location: str) -> str:
    """Get the current weather for any city using OpenWeather.

    Args:
        location: City name, e.g. "Paris" or "New York,US".

    Returns:
        A human-readable weather summary, or an error message string on
        any API or transport failure (the tool never raises).
    """
    try:
        # Pass the city via params= so httpx URL-encodes it; the original
        # f-string built a broken URL for names with spaces or non-ASCII
        # characters (e.g. "New York", "São Paulo").
        r = httpx.get(
            "https://api.openweathermap.org/data/2.5/weather",
            params={"q": location, "appid": OPENWEATHER_KEY, "units": "metric"},
            timeout=10,
        )
        if r.status_code != 200:
            return f"Weather API error: {r.text}"
        data = r.json()
        main = data.get("main", {})
        # Guard against an empty "weather" list, not just a missing key.
        weather = data.get("weather") or [{}]
        weather_desc = weather[0].get("description", "unknown")
        temp = main.get("temp")
        feels = main.get("feels_like")
        city = data.get("name", location.title())
        return (
            f"The weather in {city} is currently {weather_desc}, "
            f"with a temperature of {temp}°C (feels like {feels}°C)."
        )
    except Exception as e:
        return f"Weather API request failed: {e}"
# ----------------------------------------------------------------------
# 4. Create the LangChain agent with both tools
# ----------------------------------------------------------------------
# The agent is built at import time so other modules could reuse it.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
agent = create_agent(llm, tools=[vanna_query, get_weather])
# ----------------------------------------------------------------------
# 5. Prompt the user
# ----------------------------------------------------------------------
def _main() -> None:
    """Read one question from stdin, run the agent, print the final answer."""
    query = input("Ask something (e.g., 'What's the weather in Paris?' or 'Who is the largest customer?'): ")
    response = agent.invoke({"messages": [HumanMessage(content=query)]})
    # The agent returns a state dict; the last message carries the answer.
    messages = response.get("messages", [])
    final_msg = messages[-1].content if messages else response
    print("\n✅ Final Answer:\n", final_msg)


if __name__ == "__main__":
    # Guard the interactive prompt so importing this module (e.g. from a
    # test or another script) no longer blocks on input() at import time.
    _main()