1515# limitations under the License.
1616
1717import logging
18- from typing import Annotated , List , Tuple , Dict , Callable , Any
18+ from typing import Annotated , List , Tuple , Dict , Callable , Any , Optional
1919from multiprocessing import Pool , cpu_count
2020from functools import partial
2121from tenacity import retry , stop_after_attempt , wait_exponential
2222from tqdm import tqdm
2323from concurrent .futures import ThreadPoolExecutor
24+ import json
2425
2526from datasets import load_dataset
2627from utils .llm_utils import (
@@ -178,14 +179,16 @@ def process_with_progress(
178179 Returns:
179180 List of results
180181 """
181- logger .info (f"{ desc } - Starting parallel processing with { n_processes } workers" )
182-
182+ logger .info (
183+ f"{ desc } - Starting parallel processing with { n_processes } workers"
184+ )
185+
183186 results = []
184187 with Pool (processes = n_processes ) as pool :
185188 for i , result in enumerate (pool .imap (worker_fn , items ), 1 ):
186189 results .append (result )
187190 logger .info (f"Completed { i } /{ len (items )} tests" )
188-
191+
189192 logger .info (f"{ desc } - Completed processing { len (results )} items" )
190193 return results
191194
@@ -205,16 +208,18 @@ def test_retrieved_docs_retrieve_best_url(
205208 """
206209 total_tests = len (question_doc_pairs )
207210 logger .info (f"Starting retrieval test with { total_tests } questions..." )
208-
211+
209212 n_processes = max (1 , cpu_count () // 2 )
210- worker = partial (process_single_pair_with_retry , use_reranking = use_reranking )
213+ worker = partial (
214+ process_single_pair_with_retry , use_reranking = use_reranking
215+ )
211216
212217 try :
213218 results = process_with_progress (
214219 question_doc_pairs ,
215220 worker ,
216221 n_processes ,
217- "Testing document retrieval"
222+ "Testing document retrieval" ,
218223 )
219224
220225 failures = 0
@@ -352,14 +357,13 @@ def perform_retrieval_evaluation(
352357
353358 total_tests = len (sampled_dataset )
354359 n_processes = max (1 , cpu_count () // 2 )
355- worker = partial (process_single_dataset_item_with_retry , use_reranking = use_reranking )
360+ worker = partial (
361+ process_single_dataset_item_with_retry , use_reranking = use_reranking
362+ )
356363
357364 try :
358365 results = process_with_progress (
359- sampled_dataset ,
360- worker ,
361- n_processes ,
362- "Evaluating retrieval"
366+ sampled_dataset , worker , n_processes , "Evaluating retrieval"
363367 )
364368
365369 failures = 0
@@ -487,14 +491,13 @@ def run_simple_tests(test_data: list, test_function: Callable) -> float:
487491 """
488492 total_tests = len (test_data )
489493 n_processes = max (1 , cpu_count () // 2 )
490- worker = partial (process_single_test_with_retry , test_function = test_function )
494+ worker = partial (
495+ process_single_test_with_retry , test_function = test_function
496+ )
491497
492498 try :
493499 results = process_with_progress (
494- test_data ,
495- worker ,
496- n_processes ,
497- "Running tests"
500+ test_data , worker , n_processes , "Running tests"
498501 )
499502
500503 failures = 0
@@ -522,3 +525,134 @@ def run_simple_tests(test_data: list, test_function: Callable) -> float:
522525 except Exception as e :
523526 logger .error (f"Error during parallel processing: { str (e )} " )
524527 raise
528+
529+
def process_single_llm_test(
    item: Dict,
    test_function: Callable,
) -> Optional[Tuple[float, float, float, float]]:
    """Run one LLM-judged test on a single dataset item.

    Invokes ``test_function`` on the item's first generated question,
    using the item's page content as the retrieval context.

    Args:
        item: Dataset item; must provide a non-empty
            ``generated_questions`` list and a ``page_content`` string.
        test_function: Callable invoked as ``test_function(question,
            context)``; its result must expose ``toxicity``,
            ``faithfulness``, ``helpfulness`` and ``relevance``
            attributes.

    Returns:
        Tuple of (toxicity, faithfulness, helpfulness, relevance)
        scores, or None when the LLM response could not be parsed as
        JSON (the item is skipped rather than counted as a failure).
    """
    # Assuming only one question per item
    question = item["generated_questions"][0]
    context = item["page_content"]

    try:
        result = test_function(question, context)
        return (
            result.toxicity,
            result.faithfulness,
            result.helpfulness,
            result.relevance,
        )
    except json.JSONDecodeError as e:
        logger.error(f"Failed for question: {question}. Error: {e}")
        # Return None to indicate this test should be skipped
        return None
559+
560+
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    reraise=True,
)
def process_single_llm_test_with_retry(
    item: Dict,
    test_function: Callable,
) -> Optional[Tuple[float, float, float, float]]:
    """Retry-wrapped variant of ``process_single_llm_test``.

    Delegates to ``process_single_llm_test`` and retries up to three
    times with exponential backoff (4-10s) on any raised exception,
    re-raising after the final attempt.

    Args:
        item: Dictionary containing the dataset item.
        test_function: The test function to run.

    Returns:
        Optional tuple containing (toxicity, faithfulness, helpfulness,
        relevance) scores; None when the test should be skipped.
    """
    try:
        scores = process_single_llm_test(item, test_function)
    except Exception as err:
        # Surface the failure, then let tenacity schedule the retry.
        logger.warning(f"Error processing LLM test: {str(err)}. Retrying...")
        raise
    return scores
585+
586+
def run_llm_judged_tests(
    test_function: Callable,
    sample_size: int = 10,
) -> Tuple[
    Annotated[float, "average_toxicity_score"],
    Annotated[float, "average_faithfulness_score"],
    Annotated[float, "average_helpfulness_score"],
    Annotated[float, "average_relevance_score"],
]:
    """E2E tests judged by an LLM.

    Samples items from the ``zenml/rag_qa_embedding_questions`` dataset
    and scores each with ``test_function`` in parallel, averaging the
    per-item scores.

    Args:
        test_function: The test function to run; called as
            ``test_function(question, context)`` and expected to return
            an object with ``toxicity``, ``faithfulness``,
            ``helpfulness`` and ``relevance`` attributes.
        sample_size: Number of dataset items to sample and evaluate.

    Returns:
        Tuple of the average toxicity, faithfulness, helpfulness, and
        relevance scores, each rounded to 3 decimal places. All zeros
        when every sampled test failed.

    Raises:
        Exception: Re-raises any error from the parallel processing run
            after logging it.
    """
    # Load the dataset from the Hugging Face Hub
    dataset = load_dataset("zenml/rag_qa_embedding_questions", split="train")

    # Shuffle the dataset and select a random sample; fixed seed keeps
    # the sample reproducible across runs
    sampled_dataset = dataset.shuffle(seed=42).select(range(sample_size))

    n_processes = max(1, cpu_count() // 2)
    worker = partial(
        process_single_llm_test_with_retry, test_function=test_function
    )

    try:
        results = process_with_progress(
            sampled_dataset, worker, n_processes, "Running LLM judged tests"
        )

        # Filter out None results (items skipped on JSON parse errors)
        valid_results = [r for r in results if r is not None]
        total_tests = len(valid_results)

        if total_tests == 0:
            logger.error("All tests failed!")
            return 0.0, 0.0, 0.0, 0.0

        # Column-wise averages over the
        # (toxicity, faithfulness, helpfulness, relevance) tuples
        (
            average_toxicity_score,
            average_faithfulness_score,
            average_helpfulness_score,
            average_relevance_score,
        ) = (sum(column) / total_tests for column in zip(*valid_results))

        logger.info("\nTest Results Summary:")
        logger.info(f"Total valid tests: {total_tests}")
        logger.info(f"Average toxicity: {average_toxicity_score:.3f}")
        logger.info(f"Average faithfulness: {average_faithfulness_score:.3f}")
        logger.info(f"Average helpfulness: {average_helpfulness_score:.3f}")
        logger.info(f"Average relevance: {average_relevance_score:.3f}")

        return (
            round(average_toxicity_score, 3),
            round(average_faithfulness_score, 3),
            round(average_helpfulness_score, 3),
            round(average_relevance_score, 3),
        )

    except Exception as e:
        logger.error(f"Error during parallel processing: {str(e)}")
        raise
0 commit comments