Skip to content

Commit ac34be9

Browse files
yossiovadiaAias00
authored andcommitted
Fix/improve batch classification test (#319)
* feat: improve batch classification test to validate accuracy Previously, the batch classification test only validated HTTP status and result count, but never checked if the classifications were correct. The expected_categories variable was created but never used for validation. Changes: - Extract actual categories from batch classification results - Compare against expected categories and calculate accuracy percentage - Add detailed output showing each classification result - Assert that accuracy meets 75% threshold - Maintain backward compatibility with existing HTTP/count checks This improved test now properly catches classification accuracy issues and will fail when the classification system returns incorrect results, exposing problems that were previously hidden. Related to issue #318: Batch Classification API Returns Incorrect Categories Signed-off-by: Yossi Ovadia <[email protected]> * style: apply black formatting to classification test Automatic formatting applied by black pre-commit hook. Signed-off-by: Yossi Ovadia <[email protected]> --------- Signed-off-by: Yossi Ovadia <[email protected]> Signed-off-by: liuhy <[email protected]>
1 parent 55896dd commit ac34be9

File tree

1 file changed

+56
-5
lines changed

1 file changed

+56
-5
lines changed

e2e-tests/03-classification-api-test.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,29 +189,80 @@ def test_batch_classification(self):
189189
response_json = response.json()
190190
results = response_json.get("results", [])
191191

192+
# Extract actual categories from results
193+
actual_categories = []
194+
correct_classifications = 0
195+
196+
for i, result in enumerate(results):
197+
if isinstance(result, dict):
198+
actual_category = result.get("category", "unknown")
199+
else:
200+
actual_category = "unknown"
201+
202+
actual_categories.append(actual_category)
203+
204+
if (
205+
i < len(expected_categories)
206+
and actual_category == expected_categories[i]
207+
):
208+
correct_classifications += 1
209+
210+
# Calculate accuracy
211+
accuracy = (
212+
(correct_classifications / len(expected_categories)) * 100
213+
if expected_categories
214+
else 0
215+
)
216+
192217
self.print_response_info(
193218
response,
194219
{
195220
"Total Texts": len(texts),
196221
"Results Count": len(results),
197222
"Processing Time (ms)": response_json.get("processing_time_ms", 0),
223+
"Accuracy": f"{accuracy:.1f}% ({correct_classifications}/{len(expected_categories)})",
198224
},
199225
)
200226

201-
passed = response.status_code == 200 and len(results) == len(texts)
227+
# Print detailed classification results
228+
print("\n📊 Detailed Classification Results:")
229+
for i, (text, expected, actual) in enumerate(
230+
zip(texts, expected_categories, actual_categories)
231+
):
232+
status = "✅" if expected == actual else "❌"
233+
print(f" {i+1}. {status} Expected: {expected:<15} | Actual: {actual:<15}")
234+
print(f" Text: {text[:60]}...")
235+
236+
# Check basic requirements first
237+
basic_checks_passed = response.status_code == 200 and len(results) == len(texts)
238+
239+
# Check classification accuracy (should be high for a working system)
240+
accuracy_threshold = 75.0 # Expect at least 75% accuracy
241+
accuracy_passed = accuracy >= accuracy_threshold
242+
243+
overall_passed = basic_checks_passed and accuracy_passed
202244

203245
self.print_test_result(
204-
passed=passed,
246+
passed=overall_passed,
205247
message=(
206-
f"Successfully classified {len(results)} texts"
207-
if passed
208-
else f"Batch classification failed or returned wrong count"
248+
f"Successfully classified {len(results)} texts with {accuracy:.1f}% accuracy"
249+
if overall_passed
250+
else f"Batch classification issues: Basic checks: {basic_checks_passed}, Accuracy: {accuracy:.1f}% (threshold: {accuracy_threshold}%)"
209251
),
210252
)
211253

254+
# Basic checks
212255
self.assertEqual(response.status_code, 200, "Batch request failed")
213256
self.assertEqual(len(results), len(texts), "Result count mismatch")
214257

258+
# NEW: Validate classification accuracy
259+
self.assertGreaterEqual(
260+
accuracy,
261+
accuracy_threshold,
262+
f"Classification accuracy too low: {accuracy:.1f}% < {accuracy_threshold}%. "
263+
f"Expected: {expected_categories}, Actual: {actual_categories}",
264+
)
265+
215266

216267
if __name__ == "__main__":
217268
unittest.main()

0 commit comments

Comments
 (0)