
Commit 8961e05
Merge pull request #1717 from nehcneb/main
Add support for new GPT models from OpenAI
2 parents b19815b + 65333f4

3 files changed: +222 -6 lines changed

extensions/llms/openai/pandasai_openai/openai.py

Lines changed: 8 additions & 2 deletions
@@ -16,7 +16,7 @@ class OpenAI(BaseOpenAI):
 
     An API call to OpenAI API is sent and response is recorded and returned.
     The default chat model is **gpt-3.5-turbo**.
-    The list of supported Chat models includes ["gpt-4", "gpt-4-0613", "gpt-4-32k",
+    The list of supported Chat models includes ["gpt-4.1", "gpt-4.1-mini", "gpt-4.1-nano", "gpt-4o", "gpt-4o-mini", "gpt-4", "gpt-4-0613", "gpt-4-32k",
     "gpt-4-32k-0613", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613",
     "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-instruct"].
     The list of supported Completion models includes "gpt-3.5-turbo-instruct" and

@@ -41,10 +41,16 @@ class OpenAI(BaseOpenAI):
         "gpt-4o-2024-05-13",
         "gpt-4o-mini",
         "gpt-4o-mini-2024-07-18",
+        "gpt-4.1",
+        "gpt-4.1-2025-04-14",
+        "gpt-4.1-mini",
+        "gpt-4.1-mini-2025-04-14",
+        "gpt-4.1-nano",
+        "gpt-4.1-nano-2025-04-14"
     ]
     _supported_completion_models = ["gpt-3.5-turbo-instruct"]
 
-    model: str = "gpt-4o-mini"
+    model: str = "gpt-4.1-mini"
 
     def __init__(
         self,
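
For context, a minimal usage sketch of the new models (an assumption on my part: that the extension still exposes OpenAI with an api_token parameter and is registered via pai.config.set, as in other pandasai v3 examples; the token value is a placeholder):

import pandasai as pai
from pandasai_openai import OpenAI

# Hypothetical API key; after this commit, omitting `model` defaults to "gpt-4.1-mini".
llm = OpenAI(api_token="sk-...", model="gpt-4.1")
pai.config.set({"llm": llm})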
Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
import os
import shutil
from pathlib import Path

import pytest
from openai import OpenAI
from pydantic import BaseModel

import pandasai as pai
from pandasai import DataFrame
from pandasai.helpers.path import find_project_root

# Read the API key from an environment variable
JUDGE_OPENAI_API_KEY = os.getenv("JUDGE_OPENAI_API_KEY", None)


class Evaluation(BaseModel):
    score: int
    justification: str


@pytest.mark.skipif(
    JUDGE_OPENAI_API_KEY is None,
    reason="JUDGE_OPENAI_API_KEY key not set, skipping tests",
)
class TestAgentLLMJudge:
    root_dir = find_project_root()
    heart_stroke_path = os.path.join(root_dir, "examples", "data", "heart.csv")
    loans_path = os.path.join(root_dir, "examples", "data", "loans_payments.csv")

    loans_questions = [
        "What is the total number of payments?",
        "What is the average payment amount?",
        "How many unique loan IDs are there?",
        "What is the most common payment amount?",
        "What is the total amount of payments?",
        "What is the median payment amount?",
        "How many payments are above $1000?",
        "What is the minimum and maximum payment?",
        "Show me a monthly trend of payments",
        "Show me the distribution of payment amounts",
        "Show me the top 10 payment amounts",
        "Give me a summary of payment statistics",
        "Show me payments above $1000",
    ]

    heart_strokes_questions = [
        "What is the total number of patients in the dataset?",
        "How many people had a stroke?",
        "What is the average age of patients?",
        "What percentage of patients have hypertension?",
        "What is the average BMI?",
        "How many smokers are in the dataset?",
        "What is the gender distribution?",
        "Is there a correlation between age and stroke occurrence?",
        "Show me the age distribution of patients.",
        "What is the most common work type?",
        "Give me a breakdown of stroke occurrences.",
        "Show me hypertension statistics.",
        "Give me smoking statistics summary.",
        "Show me the distribution of work types.",
    ]

    combined_questions = [
        "Compare payment patterns between age groups.",
        "Show relationship between payments and health conditions.",
        "Analyze payment differences between hypertension groups.",
        "Calculate average payments by health condition.",
        "Show payment distribution across age groups.",
    ]

    evaluation_scores = []

    @pytest.fixture(autouse=True)
    def setup(self):
        """Set up shared resources for the test class."""
        self.client = OpenAI(api_key=JUDGE_OPENAI_API_KEY)

        self.evaluation_prompt = (
            "You are an AI evaluation expert tasked with assessing the quality of a code snippet provided as a response.\n"
            "The question was: {question}\n"
            "The AI provided the following code:\n"
            "{code}\n\n"
            "Here is the context summary of the data:\n"
            "{context}\n\n"
            "Evaluate the code based on the following criteria:\n"
            "- Correctness: Does the code achieve the intended goal or answer the question accurately?\n"
            "- Efficiency: Is the code optimized and avoids unnecessary computations or steps?\n"
            "- Clarity: Is the code written in a clear and understandable way?\n"
            "- Robustness: Does the code handle potential edge cases or errors gracefully?\n"
            "- Best Practices: Does the code follow standard coding practices and conventions?\n"
            "The code should only use the function execute_sql_query(sql_query: str) -> pd.DataFrame to connect to the database and get the data.\n"
            "The code should declare the result variable as a dictionary with the following structure:\n"
            "'type': 'string', 'value': f'The highest salary is 2.' or 'type': 'number', 'value': 125 or 'type': 'dataframe', 'value': pd.DataFrame() or 'type': 'plot', 'value': 'temp_chart.png'\n"
        )

    def test_judge_setup(self):
        """Test evaluation setup with OpenAI."""
        question = "How many unique loan IDs are there?"

        df = pai.read_csv(str(self.loans_path))
        df_context = DataFrame.serialize_dataframe(df)

        response = df.chat(question)

        prompt = self.evaluation_prompt.format(
            context=df_context, question=question, code=response.last_code_executed
        )

        completion = self.client.beta.chat.completions.parse(
            model="gpt-4.1-mini",
            messages=[{"role": "user", "content": prompt}],
            response_format=Evaluation,
        )

        evaluation_response: Evaluation = completion.choices[0].message.parsed

        self.evaluation_scores.append(evaluation_response.score)

        assert evaluation_response.score > 5, evaluation_response.justification

    @pytest.mark.parametrize("question", loans_questions)
    def test_loans_questions(self, question):
        """Test multiple loan-related questions."""
        df = pai.read_csv(str(self.loans_path))
        df_context = DataFrame.serialize_dataframe(df)

        response = df.chat(question)

        prompt = self.evaluation_prompt.format(
            context=df_context, question=question, code=response.last_code_executed
        )

        completion = self.client.beta.chat.completions.parse(
            model="gpt-4.1-mini",
            messages=[{"role": "user", "content": prompt}],
            response_format=Evaluation,
        )

        evaluation_response: Evaluation = completion.choices[0].message.parsed

        self.evaluation_scores.append(evaluation_response.score)

        assert evaluation_response.score > 5, evaluation_response.justification

    @pytest.mark.parametrize("question", heart_strokes_questions)
    def test_heart_strokes_questions(self, question):
        """Test multiple heart-stroke-related questions."""
        self.df = pai.read_csv(str(self.heart_stroke_path))
        df_context = DataFrame.serialize_dataframe(self.df)

        response = self.df.chat(question)

        prompt = self.evaluation_prompt.format(
            context=df_context, question=question, code=response.last_code_executed
        )

        completion = self.client.beta.chat.completions.parse(
            model="gpt-4.1-mini",
            messages=[{"role": "user", "content": prompt}],
            response_format=Evaluation,
        )

        evaluation_response: Evaluation = completion.choices[0].message.parsed

        self.evaluation_scores.append(evaluation_response.score)

        assert evaluation_response.score > 5, evaluation_response.justification

    @pytest.mark.parametrize("question", combined_questions)
    def test_combined_questions_with_type(self, question):
        """
        Test questions that span both the heart stroke and loans datasets.
        """
        heart_stroke = pai.read_csv(str(self.heart_stroke_path))
        loans = pai.read_csv(str(self.loans_path))

        df_context = f"{DataFrame.serialize_dataframe(heart_stroke)}\n{DataFrame.serialize_dataframe(loans)}"

        response = pai.chat(question, heart_stroke, loans)

        prompt = self.evaluation_prompt.format(
            context=df_context, question=question, code=response.last_code_executed
        )

        completion = self.client.beta.chat.completions.parse(
            model="gpt-4.1-mini",
            messages=[{"role": "user", "content": prompt}],
            response_format=Evaluation,
        )

        evaluation_response: Evaluation = completion.choices[0].message.parsed

        self.evaluation_scores.append(evaluation_response.score)

        assert evaluation_response.score > 5, evaluation_response.justification

    def test_average_score(self):
        if self.evaluation_scores:
            average_score = sum(self.evaluation_scores) / len(self.evaluation_scores)
            file_path = Path(self.root_dir) / "test_agent_llm_judge.txt"
            with open(file_path, "w") as f:
                f.write(f"{average_score}")
            assert (
                average_score >= 5
            ), f"Average score should be at least 5, got {average_score}"

tests/unit_tests/agent/test_agent_llm_judge.py

Lines changed: 4 additions & 4 deletions
@@ -109,7 +109,7 @@ def test_judge_setup(self):
         )
 
         completion = self.client.beta.chat.completions.parse(
-            model="gpt-4o-mini",
+            model="gpt-4.1-mini",
             messages=[{"role": "user", "content": prompt}],
             response_format=Evaluation,
         )

@@ -134,7 +134,7 @@ def test_loans_questions(self, question):
         )
 
         completion = self.client.beta.chat.completions.parse(
-            model="gpt-4o-mini",
+            model="gpt-4.1-mini",
             messages=[{"role": "user", "content": prompt}],
             response_format=Evaluation,
         )

@@ -159,7 +159,7 @@ def test_heart_strokes_questions(self, question):
         )
 
         completion = self.client.beta.chat.completions.parse(
-            model="gpt-4o-mini",
+            model="gpt-4.1-mini",
             messages=[{"role": "user", "content": prompt}],
             response_format=Evaluation,
         )

@@ -188,7 +188,7 @@ def test_combined_questions_with_type(self, question):
         )
 
         completion = self.client.beta.chat.completions.parse(
-            model="gpt-4o-mini",
+            model="gpt-4.1-mini",
             messages=[{"role": "user", "content": prompt}],
             response_format=Evaluation,
         )
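
Note that because of the skipif guard on the class, this suite is a no-op unless the judge key is exported, e.g. JUDGE_OPENAI_API_KEY=sk-... pytest tests/unit_tests/agent/test_agent_llm_judge.py (the key value is a placeholder).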
