From 834a209b71fce12706762c983285c1bebe252bb4 Mon Sep 17 00:00:00 2001 From: Yossi Ovadia Date: Thu, 2 Oct 2025 09:50:42 -0700 Subject: [PATCH 1/2] feat: improve batch classification test to validate accuracy Previously, the batch classification test only validated HTTP status and result count, but never checked if the classifications were correct. The expected_categories variable was created but never used for validation. Changes: - Extract actual categories from batch classification results - Compare against expected categories and calculate accuracy percentage - Add detailed output showing each classification result - Assert that accuracy meets 75% threshold - Maintain backward compatibility with existing HTTP/count checks This improved test now properly catches classification accuracy issues and will fail when the classification system returns incorrect results, exposing problems that were previously hidden. Related to issue #318: Batch Classification API Returns Incorrect Categories Signed-off-by: Yossi Ovadia --- e2e-tests/03-classification-api-test.py | 52 ++++++++++++++++++++++--- 1 file changed, 47 insertions(+), 5 deletions(-) diff --git a/e2e-tests/03-classification-api-test.py b/e2e-tests/03-classification-api-test.py index 804ddde9..47334a49 100755 --- a/e2e-tests/03-classification-api-test.py +++ b/e2e-tests/03-classification-api-test.py @@ -189,29 +189,71 @@ def test_batch_classification(self): response_json = response.json() results = response_json.get("results", []) + # Extract actual categories from results + actual_categories = [] + correct_classifications = 0 + + for i, result in enumerate(results): + if isinstance(result, dict): + actual_category = result.get("category", "unknown") + else: + actual_category = "unknown" + + actual_categories.append(actual_category) + + if i < len(expected_categories) and actual_category == expected_categories[i]: + correct_classifications += 1 + + # Calculate accuracy + accuracy = (correct_classifications / 
len(expected_categories)) * 100 if expected_categories else 0 + self.print_response_info( response, { "Total Texts": len(texts), "Results Count": len(results), "Processing Time (ms)": response_json.get("processing_time_ms", 0), + "Accuracy": f"{accuracy:.1f}% ({correct_classifications}/{len(expected_categories)})", }, ) - passed = response.status_code == 200 and len(results) == len(texts) + # Print detailed classification results + print("\n📊 Detailed Classification Results:") + for i, (text, expected, actual) in enumerate(zip(texts, expected_categories, actual_categories)): + status = "✅" if expected == actual else "❌" + print(f" {i+1}. {status} Expected: {expected:<15} | Actual: {actual:<15}") + print(f" Text: {text[:60]}...") + + # Check basic requirements first + basic_checks_passed = response.status_code == 200 and len(results) == len(texts) + + # Check classification accuracy (should be high for a working system) + accuracy_threshold = 75.0 # Expect at least 75% accuracy + accuracy_passed = accuracy >= accuracy_threshold + + overall_passed = basic_checks_passed and accuracy_passed self.print_test_result( - passed=passed, + passed=overall_passed, message=( - f"Successfully classified {len(results)} texts" - if passed - else f"Batch classification failed or returned wrong count" + f"Successfully classified {len(results)} texts with {accuracy:.1f}% accuracy" + if overall_passed + else f"Batch classification issues: Basic checks: {basic_checks_passed}, Accuracy: {accuracy:.1f}% (threshold: {accuracy_threshold}%)" ), ) + # Basic checks self.assertEqual(response.status_code, 200, "Batch request failed") self.assertEqual(len(results), len(texts), "Result count mismatch") + # NEW: Validate classification accuracy + self.assertGreaterEqual( + accuracy, + accuracy_threshold, + f"Classification accuracy too low: {accuracy:.1f}% < {accuracy_threshold}%. 
" + f"Expected: {expected_categories}, Actual: {actual_categories}" + ) + if __name__ == "__main__": unittest.main() From 3c9d3e51c57ad699df053b30e689739c5bdc9444 Mon Sep 17 00:00:00 2001 From: Yossi Ovadia Date: Thu, 2 Oct 2025 09:54:42 -0700 Subject: [PATCH 2/2] style: apply black formatting to classification test Automatic formatting applied by black pre-commit hook. Signed-off-by: Yossi Ovadia --- e2e-tests/03-classification-api-test.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/e2e-tests/03-classification-api-test.py b/e2e-tests/03-classification-api-test.py index 47334a49..eb930974 100755 --- a/e2e-tests/03-classification-api-test.py +++ b/e2e-tests/03-classification-api-test.py @@ -201,11 +201,18 @@ def test_batch_classification(self): actual_categories.append(actual_category) - if i < len(expected_categories) and actual_category == expected_categories[i]: + if ( + i < len(expected_categories) + and actual_category == expected_categories[i] + ): correct_classifications += 1 # Calculate accuracy - accuracy = (correct_classifications / len(expected_categories)) * 100 if expected_categories else 0 + accuracy = ( + (correct_classifications / len(expected_categories)) * 100 + if expected_categories + else 0 + ) self.print_response_info( response, @@ -219,7 +226,9 @@ def test_batch_classification(self): # Print detailed classification results print("\n📊 Detailed Classification Results:") - for i, (text, expected, actual) in enumerate(zip(texts, expected_categories, actual_categories)): + for i, (text, expected, actual) in enumerate( + zip(texts, expected_categories, actual_categories) + ): status = "✅" if expected == actual else "❌" print(f" {i+1}. {status} Expected: {expected:<15} | Actual: {actual:<15}") print(f" Text: {text[:60]}...") @@ -251,7 +260,7 @@ def test_batch_classification(self): accuracy, accuracy_threshold, f"Classification accuracy too low: {accuracy:.1f}% < {accuracy_threshold}%. 
" - f"Expected: {expected_categories}, Actual: {actual_categories}" + f"Expected: {expected_categories}, Actual: {actual_categories}", )