diff --git a/e2e-tests/03-classification-api-test.py b/e2e-tests/03-classification-api-test.py index 804ddde9..eb930974 100755 --- a/e2e-tests/03-classification-api-test.py +++ b/e2e-tests/03-classification-api-test.py @@ -189,29 +189,80 @@ def test_batch_classification(self): response_json = response.json() results = response_json.get("results", []) + # Extract actual categories from results + actual_categories = [] + correct_classifications = 0 + + for i, result in enumerate(results): + if isinstance(result, dict): + actual_category = result.get("category", "unknown") + else: + actual_category = "unknown" + + actual_categories.append(actual_category) + + if ( + i < len(expected_categories) + and actual_category == expected_categories[i] + ): + correct_classifications += 1 + + # Calculate accuracy + accuracy = ( + (correct_classifications / len(expected_categories)) * 100 + if expected_categories + else 0 + ) + self.print_response_info( response, { "Total Texts": len(texts), "Results Count": len(results), "Processing Time (ms)": response_json.get("processing_time_ms", 0), + "Accuracy": f"{accuracy:.1f}% ({correct_classifications}/{len(expected_categories)})", }, ) - passed = response.status_code == 200 and len(results) == len(texts) + # Print detailed classification results + print("\nšŸ“Š Detailed Classification Results:") + for i, (text, expected, actual) in enumerate( + zip(texts, expected_categories, actual_categories) + ): + status = "āœ…" if expected == actual else "āŒ" + print(f" {i+1}. {status} Expected: {expected:<15} | Actual: {actual:<15}") + print(f" Text: {text[:60]}...") + + # Check basic requirements first + basic_checks_passed = response.status_code == 200 and len(results) == len(texts) + + # Check classification accuracy (should be high for a working system) + accuracy_threshold = 75.0 # Expect at least 75% accuracy + accuracy_passed = accuracy >= accuracy_threshold + + overall_passed = basic_checks_passed and accuracy_passed self.print_test_result( - passed=passed, + passed=overall_passed, message=( - f"Successfully classified {len(results)} texts" - if passed - else f"Batch classification failed or returned wrong count" + f"Successfully classified {len(results)} texts with {accuracy:.1f}% accuracy" + if overall_passed + else f"Batch classification issues: Basic checks: {basic_checks_passed}, Accuracy: {accuracy:.1f}% (threshold: {accuracy_threshold}%)" ), ) + # Basic checks self.assertEqual(response.status_code, 200, "Batch request failed") self.assertEqual(len(results), len(texts), "Result count mismatch") + # NEW: Validate classification accuracy + self.assertGreaterEqual( + accuracy, + accuracy_threshold, + f"Classification accuracy too low: {accuracy:.1f}% < {accuracy_threshold}%. " + f"Expected: {expected_categories}, Actual: {actual_categories}", + ) + if __name__ == "__main__": unittest.main()