Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 56 additions & 5 deletions e2e-tests/03-classification-api-test.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,29 +189,80 @@ def test_batch_classification(self):
response_json = response.json()
results = response_json.get("results", [])

# Extract actual categories from results
actual_categories = []
correct_classifications = 0

for i, result in enumerate(results):
if isinstance(result, dict):
actual_category = result.get("category", "unknown")
else:
actual_category = "unknown"

actual_categories.append(actual_category)

if (
i < len(expected_categories)
and actual_category == expected_categories[i]
):
correct_classifications += 1

# Calculate accuracy
accuracy = (
(correct_classifications / len(expected_categories)) * 100
if expected_categories
else 0
)

self.print_response_info(
response,
{
"Total Texts": len(texts),
"Results Count": len(results),
"Processing Time (ms)": response_json.get("processing_time_ms", 0),
"Accuracy": f"{accuracy:.1f}% ({correct_classifications}/{len(expected_categories)})",
},
)

passed = response.status_code == 200 and len(results) == len(texts)
# Print detailed classification results
print("\n📊 Detailed Classification Results:")
for i, (text, expected, actual) in enumerate(
zip(texts, expected_categories, actual_categories)
):
status = "✅" if expected == actual else "❌"
print(f" {i+1}. {status} Expected: {expected:<15} | Actual: {actual:<15}")
print(f" Text: {text[:60]}...")

# Check basic requirements first
basic_checks_passed = response.status_code == 200 and len(results) == len(texts)

# Check classification accuracy (should be high for a working system)
accuracy_threshold = 75.0 # Expect at least 75% accuracy
accuracy_passed = accuracy >= accuracy_threshold

overall_passed = basic_checks_passed and accuracy_passed

self.print_test_result(
passed=passed,
passed=overall_passed,
message=(
f"Successfully classified {len(results)} texts"
if passed
else f"Batch classification failed or returned wrong count"
f"Successfully classified {len(results)} texts with {accuracy:.1f}% accuracy"
if overall_passed
else f"Batch classification issues: Basic checks: {basic_checks_passed}, Accuracy: {accuracy:.1f}% (threshold: {accuracy_threshold}%)"
),
)

# Basic checks
self.assertEqual(response.status_code, 200, "Batch request failed")
self.assertEqual(len(results), len(texts), "Result count mismatch")

# NEW: Validate classification accuracy
self.assertGreaterEqual(
accuracy,
accuracy_threshold,
f"Classification accuracy too low: {accuracy:.1f}% < {accuracy_threshold}%. "
f"Expected: {expected_categories}, Actual: {actual_categories}",
)


if __name__ == "__main__":
unittest.main()
Loading