import os
import shutil
from pathlib import Path

import pytest
from openai import OpenAI
from pydantic import BaseModel

import pandasai as pai
from pandasai import DataFrame
from pandasai.helpers.path import find_project_root

# Read the API key from an environment variable
JUDGE_OPENAI_API_KEY = os.getenv("JUDGE_OPENAI_API_KEY", None)
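# These tests call the OpenAI API directly; if the key is missing, the whole test
# class below is skipped via the pytest.mark.skipif guard.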


class Evaluation(BaseModel):
    """Structured verdict returned by the judge model."""

    score: int
    justification: str


@pytest.mark.skipif(
    JUDGE_OPENAI_API_KEY is None,
    reason="JUDGE_OPENAI_API_KEY not set, skipping tests",
)
class TestAgentLLMJudge:
    root_dir = find_project_root()
    heart_stroke_path = os.path.join(root_dir, "examples", "data", "heart.csv")
    loans_path = os.path.join(root_dir, "examples", "data", "loans_payments.csv")

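    # Question banks used by the parametrized tests: loans-only, heart-stroke-only,
    # and combined questions that require both datasets.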
    loans_questions = [
        "What is the total number of payments?",
        "What is the average payment amount?",
        "How many unique loan IDs are there?",
        "What is the most common payment amount?",
        "What is the total amount of payments?",
        "What is the median payment amount?",
        "How many payments are above $1000?",
        "What is the minimum and maximum payment?",
        "Show me a monthly trend of payments",
        "Show me the distribution of payment amounts",
        "Show me the top 10 payment amounts",
        "Give me a summary of payment statistics",
        "Show me payments above $1000",
    ]

    heart_strokes_questions = [
        "What is the total number of patients in the dataset?",
        "How many people had a stroke?",
        "What is the average age of patients?",
        "What percentage of patients have hypertension?",
        "What is the average BMI?",
        "How many smokers are in the dataset?",
        "What is the gender distribution?",
        "Is there a correlation between age and stroke occurrence?",
        "Show me the age distribution of patients.",
        "What is the most common work type?",
        "Give me a breakdown of stroke occurrences.",
        "Show me hypertension statistics.",
        "Give me smoking statistics summary.",
        "Show me the distribution of work types.",
    ]

    combined_questions = [
        "Compare payment patterns between age groups.",
        "Show relationship between payments and health conditions.",
        "Analyze payment differences between hypertension groups.",
        "Calculate average payments by health condition.",
        "Show payment distribution across age groups.",
    ]

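    # Class-level list so scores accumulate across all tests in the class; the mean
    # is reported by test_average_score at the end.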
    evaluation_scores = []

    @pytest.fixture(autouse=True)
    def setup(self):
        """Set up shared resources for the test class."""

        self.client = OpenAI(api_key=JUDGE_OPENAI_API_KEY)

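        # Prompt template for the judge model; each test fills in {question},
        # {code}, and {context} before sending it.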
        self.evaluation_prompt = (
            "You are an AI evaluation expert tasked with assessing the quality of a code snippet provided as a response.\n"
            "The question was: {question}\n"
            "The AI provided the following code:\n"
            "{code}\n\n"
            "Here is the context summary of the data:\n"
            "{context}\n\n"
            "Evaluate the code based on the following criteria:\n"
            "- Correctness: Does the code achieve the intended goal or answer the question accurately?\n"
            "- Efficiency: Is the code optimized and does it avoid unnecessary computations or steps?\n"
            "- Clarity: Is the code written in a clear and understandable way?\n"
            "- Robustness: Does the code handle potential edge cases or errors gracefully?\n"
            "- Best Practices: Does the code follow standard coding practices and conventions?\n"
            "The code should only use the function execute_sql_query(sql_query: str) -> pd.DataFrame to connect to the database and get the data.\n"
            "The code should declare the result variable as a dictionary with the following structure:\n"
            "'type': 'string', 'value': f'The highest salary is 2.' or 'type': 'number', 'value': 125 or 'type': 'dataframe', 'value': pd.DataFrame() or 'type': 'plot', 'value': 'temp_chart.png'\n"
            "Provide an integer score from 1 to 10 and a brief justification.\n"  # assumed 1-10 scale, matching the score > 5 assertions
        )

    def test_judge_setup(self):
        """Test evaluation setup with OpenAI."""
        question = "How many unique loan IDs are there?"

        df = pai.read_csv(str(self.loans_path))
        df_context = DataFrame.serialize_dataframe(df)

        response = df.chat(question)

        prompt = self.evaluation_prompt.format(
            context=df_context, question=question, code=response.last_code_executed
        )

        completion = self.client.beta.chat.completions.parse(
            model="gpt-4.1-mini",
            messages=[{"role": "user", "content": prompt}],
            response_format=Evaluation,
        )

        evaluation_response: Evaluation = completion.choices[0].message.parsed

        self.evaluation_scores.append(evaluation_response.score)

        assert evaluation_response.score > 5, evaluation_response.justification

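    # Each parametrized test below follows the same pattern: ask the agent the
    # question, send the generated code to the judge model, record the score, and
    # require a score above 5.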
    @pytest.mark.parametrize("question", loans_questions)
    def test_loans_questions(self, question):
        """Test multiple loan-related questions."""

        df = pai.read_csv(str(self.loans_path))
        df_context = DataFrame.serialize_dataframe(df)

        response = df.chat(question)

        prompt = self.evaluation_prompt.format(
            context=df_context, question=question, code=response.last_code_executed
        )

        completion = self.client.beta.chat.completions.parse(
            model="gpt-4.1-mini",
            messages=[{"role": "user", "content": prompt}],
            response_format=Evaluation,
        )

        evaluation_response: Evaluation = completion.choices[0].message.parsed

        self.evaluation_scores.append(evaluation_response.score)

        assert evaluation_response.score > 5, evaluation_response.justification

    @pytest.mark.parametrize("question", heart_strokes_questions)
    def test_heart_strokes_questions(self, question):
        """Test multiple heart-stroke-related questions."""

        df = pai.read_csv(str(self.heart_stroke_path))
        df_context = DataFrame.serialize_dataframe(df)

        response = df.chat(question)

        prompt = self.evaluation_prompt.format(
            context=df_context, question=question, code=response.last_code_executed
        )

        completion = self.client.beta.chat.completions.parse(
            model="gpt-4.1-mini",
            messages=[{"role": "user", "content": prompt}],
            response_format=Evaluation,
        )

        evaluation_response: Evaluation = completion.choices[0].message.parsed

        self.evaluation_scores.append(evaluation_response.score)

        assert evaluation_response.score > 5, evaluation_response.justification

    @pytest.mark.parametrize("question", combined_questions)
    def test_combined_questions_with_type(self, question):
        """Test combined questions that span both the heart stroke and loans datasets."""

        heart_stroke = pai.read_csv(str(self.heart_stroke_path))
        loans = pai.read_csv(str(self.loans_path))

        df_context = f"{DataFrame.serialize_dataframe(heart_stroke)}\n{DataFrame.serialize_dataframe(loans)}"

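        # Both dataframes are passed so the agent can answer questions spanning the two datasets.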
        response = pai.chat(question, heart_stroke, loans)

        prompt = self.evaluation_prompt.format(
            context=df_context, question=question, code=response.last_code_executed
        )

        completion = self.client.beta.chat.completions.parse(
            model="gpt-4.1-mini",
            messages=[{"role": "user", "content": prompt}],
            response_format=Evaluation,
        )

        evaluation_response: Evaluation = completion.choices[0].message.parsed

        self.evaluation_scores.append(evaluation_response.score)

        assert evaluation_response.score > 5, evaluation_response.justification

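    # Aggregates the scores recorded above (assumes pytest's default in-definition-order
    # execution within the class) and persists the mean to a file at the project root.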
    def test_average_score(self):
        if self.evaluation_scores:
            average_score = sum(self.evaluation_scores) / len(self.evaluation_scores)
            file_path = Path(self.root_dir) / "test_agent_llm_judge.txt"
            with open(file_path, "w") as f:
                f.write(f"{average_score}")
            assert (
                average_score >= 5
            ), f"Average score should be at least 5, got {average_score}"