@@ -189,29 +189,80 @@ def test_batch_classification(self):
189189 response_json = response .json ()
190190 results = response_json .get ("results" , [])
191191
192+ # Extract actual categories from results
193+ actual_categories = []
194+ correct_classifications = 0
195+
196+ for i , result in enumerate (results ):
197+ if isinstance (result , dict ):
198+ actual_category = result .get ("category" , "unknown" )
199+ else :
200+ actual_category = "unknown"
201+
202+ actual_categories .append (actual_category )
203+
204+ if (
205+ i < len (expected_categories )
206+ and actual_category == expected_categories [i ]
207+ ):
208+ correct_classifications += 1
209+
210+ # Calculate accuracy
211+ accuracy = (
212+ (correct_classifications / len (expected_categories )) * 100
213+ if expected_categories
214+ else 0
215+ )
216+
192217 self .print_response_info (
193218 response ,
194219 {
195220 "Total Texts" : len (texts ),
196221 "Results Count" : len (results ),
197222 "Processing Time (ms)" : response_json .get ("processing_time_ms" , 0 ),
223+ "Accuracy" : f"{ accuracy :.1f} % ({ correct_classifications } /{ len (expected_categories )} )" ,
198224 },
199225 )
200226
201- passed = response .status_code == 200 and len (results ) == len (texts )
227+ # Print detailed classification results
228+ print ("\n 📊 Detailed Classification Results:" )
229+ for i , (text , expected , actual ) in enumerate (
230+ zip (texts , expected_categories , actual_categories )
231+ ):
232+ status = "✅" if expected == actual else "❌"
233+ print (f" { i + 1 } . { status } Expected: { expected :<15} | Actual: { actual :<15} " )
234+ print (f" Text: { text [:60 ]} ..." )
235+
236+ # Check basic requirements first
237+ basic_checks_passed = response .status_code == 200 and len (results ) == len (texts )
238+
239+ # Check classification accuracy (should be high for a working system)
240+ accuracy_threshold = 75.0 # Expect at least 75% accuracy
241+ accuracy_passed = accuracy >= accuracy_threshold
242+
243+ overall_passed = basic_checks_passed and accuracy_passed
202244
203245 self .print_test_result (
204- passed = passed ,
246+ passed = overall_passed ,
205247 message = (
206- f"Successfully classified { len (results )} texts"
207- if passed
208- else f"Batch classification failed or returned wrong count "
248+ f"Successfully classified { len (results )} texts with { accuracy :.1f } % accuracy "
249+ if overall_passed
250+ else f"Batch classification issues: Basic checks: { basic_checks_passed } , Accuracy: { accuracy :.1f } % (threshold: { accuracy_threshold } %) "
209251 ),
210252 )
211253
254+ # Basic checks
212255 self .assertEqual (response .status_code , 200 , "Batch request failed" )
213256 self .assertEqual (len (results ), len (texts ), "Result count mismatch" )
214257
258+ # NEW: Validate classification accuracy
259+ self .assertGreaterEqual (
260+ accuracy ,
261+ accuracy_threshold ,
262+ f"Classification accuracy too low: { accuracy :.1f} % < { accuracy_threshold } %. "
263+ f"Expected: { expected_categories } , Actual: { actual_categories } " ,
264+ )
265+
215266
# Script entry point: when this module is executed directly (not imported), hand control to unittest's command-line runner.
216267if __name__ == "__main__" :
217268 unittest .main ()
0 commit comments