Skip to content

Commit 834a209

Browse files
committed
feat: improve batch classification test to validate accuracy
Previously, the batch classification test only validated the HTTP status and result count, but never checked whether the classifications were correct. The expected_categories variable was created but never used for validation.

Changes:
- Extract actual categories from batch classification results
- Compare against expected categories and calculate accuracy percentage
- Add detailed output showing each classification result
- Assert that accuracy meets the 75% threshold
- Maintain backward compatibility with existing HTTP/count checks

The improved test now properly catches classification accuracy issues and fails when the classification system returns incorrect results, exposing problems that were previously hidden.

Related to issue #318: Batch Classification API Returns Incorrect Categories

Signed-off-by: Yossi Ovadia <[email protected]>
1 parent 077b8d0 commit 834a209

File tree

1 file changed

+47
-5
lines changed

1 file changed

+47
-5
lines changed

e2e-tests/03-classification-api-test.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,29 +189,71 @@ def test_batch_classification(self):
189189
response_json = response.json()
190190
results = response_json.get("results", [])
191191

192+
# Extract actual categories from results
193+
actual_categories = []
194+
correct_classifications = 0
195+
196+
for i, result in enumerate(results):
197+
if isinstance(result, dict):
198+
actual_category = result.get("category", "unknown")
199+
else:
200+
actual_category = "unknown"
201+
202+
actual_categories.append(actual_category)
203+
204+
if i < len(expected_categories) and actual_category == expected_categories[i]:
205+
correct_classifications += 1
206+
207+
# Calculate accuracy
208+
accuracy = (correct_classifications / len(expected_categories)) * 100 if expected_categories else 0
209+
192210
self.print_response_info(
193211
response,
194212
{
195213
"Total Texts": len(texts),
196214
"Results Count": len(results),
197215
"Processing Time (ms)": response_json.get("processing_time_ms", 0),
216+
"Accuracy": f"{accuracy:.1f}% ({correct_classifications}/{len(expected_categories)})",
198217
},
199218
)
200219

201-
passed = response.status_code == 200 and len(results) == len(texts)
220+
# Print detailed classification results
221+
print("\n📊 Detailed Classification Results:")
222+
for i, (text, expected, actual) in enumerate(zip(texts, expected_categories, actual_categories)):
223+
status = "✅" if expected == actual else "❌"
224+
print(f" {i+1}. {status} Expected: {expected:<15} | Actual: {actual:<15}")
225+
print(f" Text: {text[:60]}...")
226+
227+
# Check basic requirements first
228+
basic_checks_passed = response.status_code == 200 and len(results) == len(texts)
229+
230+
# Check classification accuracy (should be high for a working system)
231+
accuracy_threshold = 75.0 # Expect at least 75% accuracy
232+
accuracy_passed = accuracy >= accuracy_threshold
233+
234+
overall_passed = basic_checks_passed and accuracy_passed
202235

203236
self.print_test_result(
204-
passed=passed,
237+
passed=overall_passed,
205238
message=(
206-
f"Successfully classified {len(results)} texts"
207-
if passed
208-
else f"Batch classification failed or returned wrong count"
239+
f"Successfully classified {len(results)} texts with {accuracy:.1f}% accuracy"
240+
if overall_passed
241+
else f"Batch classification issues: Basic checks: {basic_checks_passed}, Accuracy: {accuracy:.1f}% (threshold: {accuracy_threshold}%)"
209242
),
210243
)
211244

245+
# Basic checks
212246
self.assertEqual(response.status_code, 200, "Batch request failed")
213247
self.assertEqual(len(results), len(texts), "Result count mismatch")
214248

249+
# NEW: Validate classification accuracy
250+
self.assertGreaterEqual(
251+
accuracy,
252+
accuracy_threshold,
253+
f"Classification accuracy too low: {accuracy:.1f}% < {accuracy_threshold}%. "
254+
f"Expected: {expected_categories}, Actual: {actual_categories}"
255+
)
256+
215257

216258
if __name__ == "__main__":
217259
unittest.main()

0 commit comments

Comments
 (0)