
Commit 4e2aaa3

Add LLM-judged evaluation function for RAG tests
Introduces a new function `run_llm_judged_tests` to perform end-to-end tests on RAG systems using LLM evaluation. The implementation includes:

- Parallel processing of test cases
- Scoring for toxicity, faithfulness, helpfulness, and relevance
- Retry logic for robust test execution
- Detailed logging of test results
1 parent b82f6cf commit 4e2aaa3
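
As a usage sketch (not part of the commit): the diff only constrains `test_function(question, context)` to return an object exposing `toxicity`, `faithfulness`, `helpfulness`, and `relevance` attributes. The `JudgeScores` container, `judge_fn`, and the `steps.eval_retrieval` import path below are illustrative stand-ins, not names this commit defines.

    from dataclasses import dataclass

    from steps.eval_retrieval import run_llm_judged_tests  # path assumed from repo layout


    @dataclass
    class JudgeScores:
        # Attribute names assumed from process_single_llm_test in the diff below
        toxicity: float
        faithfulness: float
        helpfulness: float
        relevance: float


    def judge_fn(question: str, context: str) -> JudgeScores:
        # Placeholder judge: a real one would prompt an LLM with the question
        # and the retrieved context, then parse its scored verdict.
        return JudgeScores(
            toxicity=0.0, faithfulness=1.0, helpfulness=1.0, relevance=1.0
        )


    if __name__ == "__main__":  # guard needed because the function spawns a Pool
        tox, faith, helpful, rel = run_llm_judged_tests(judge_fn, sample_size=10)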

File tree

1 file changed: +151 -17 lines

llm-complete-guide/steps/eval_retrieval.py

@@ -15,12 +15,13 @@
 # limitations under the License.
 
 import logging
-from typing import Annotated, List, Tuple, Dict, Callable, Any
+from typing import Annotated, List, Tuple, Dict, Callable, Any, Optional
 from multiprocessing import Pool, cpu_count
 from functools import partial
 from tenacity import retry, stop_after_attempt, wait_exponential
 from tqdm import tqdm
 from concurrent.futures import ThreadPoolExecutor
+import json
 
 from datasets import load_dataset
 from utils.llm_utils import (

@@ -178,14 +179,16 @@ def process_with_progress(
     Returns:
         List of results
     """
-    logger.info(f"{desc} - Starting parallel processing with {n_processes} workers")
-
+    logger.info(
+        f"{desc} - Starting parallel processing with {n_processes} workers"
+    )
+
     results = []
     with Pool(processes=n_processes) as pool:
         for i, result in enumerate(pool.imap(worker_fn, items), 1):
             results.append(result)
             logger.info(f"Completed {i}/{len(items)} tests")
-
+
     logger.info(f"{desc} - Completed processing {len(results)} items")
     return results

@@ -205,16 +208,18 @@ def test_retrieved_docs_retrieve_best_url(
     """
     total_tests = len(question_doc_pairs)
     logger.info(f"Starting retrieval test with {total_tests} questions...")
-
+
     n_processes = max(1, cpu_count() // 2)
-    worker = partial(process_single_pair_with_retry, use_reranking=use_reranking)
+    worker = partial(
+        process_single_pair_with_retry, use_reranking=use_reranking
+    )
 
     try:
         results = process_with_progress(
             question_doc_pairs,
             worker,
             n_processes,
-            "Testing document retrieval"
+            "Testing document retrieval",
         )
 
         failures = 0

@@ -352,14 +357,13 @@ def perform_retrieval_evaluation(
 
     total_tests = len(sampled_dataset)
     n_processes = max(1, cpu_count() // 2)
-    worker = partial(process_single_dataset_item_with_retry, use_reranking=use_reranking)
+    worker = partial(
+        process_single_dataset_item_with_retry, use_reranking=use_reranking
+    )
 
     try:
         results = process_with_progress(
-            sampled_dataset,
-            worker,
-            n_processes,
-            "Evaluating retrieval"
+            sampled_dataset, worker, n_processes, "Evaluating retrieval"
         )
 
         failures = 0

@@ -487,14 +491,13 @@ def run_simple_tests(test_data: list, test_function: Callable) -> float:
     """
     total_tests = len(test_data)
     n_processes = max(1, cpu_count() // 2)
-    worker = partial(process_single_test_with_retry, test_function=test_function)
+    worker = partial(
+        process_single_test_with_retry, test_function=test_function
+    )
 
     try:
         results = process_with_progress(
-            test_data,
-            worker,
-            n_processes,
-            "Running tests"
+            test_data, worker, n_processes, "Running tests"
         )
 
         failures = 0

@@ -522,3 +525,134 @@ def run_simple_tests(test_data: list, test_function: Callable) -> float:
     except Exception as e:
         logger.error(f"Error during parallel processing: {str(e)}")
         raise
+
+
+def process_single_llm_test(
+    item: Dict,
+    test_function: Callable,
+) -> Optional[Tuple[float, float, float, float]]:
+    """Process a single LLM test item.
+
+    Args:
+        item: Dictionary containing the dataset item
+        test_function: The test function to run
+
+    Returns:
+        Tuple containing (toxicity, faithfulness, helpfulness, relevance)
+        scores, or None if the LLM response could not be parsed
+    """
+    # Assuming only one question per item
+    question = item["generated_questions"][0]
+    context = item["page_content"]
+
+    try:
+        result = test_function(question, context)
+        return (
+            result.toxicity,
+            result.faithfulness,
+            result.helpfulness,
+            result.relevance,
+        )
+    except json.JSONDecodeError as e:
+        logger.error(f"Failed for question: {question}. Error: {e}")
+        # Return None to indicate this test should be skipped
+        return None
+
+
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    reraise=True,
+)
+def process_single_llm_test_with_retry(
+    item: Dict,
+    test_function: Callable,
+) -> Optional[Tuple[float, float, float, float]]:
+    """Process a single LLM test item with retry logic.
+
+    Args:
+        item: Dictionary containing the dataset item
+        test_function: The test function to run
+
+    Returns:
+        Optional tuple containing (toxicity, faithfulness, helpfulness,
+        relevance) scores; None if the test should be skipped
+    """
+    try:
+        return process_single_llm_test(item, test_function)
+    except Exception as e:
+        logger.warning(f"Error processing LLM test: {str(e)}. Retrying...")
+        raise
+
+
+def run_llm_judged_tests(
+    test_function: Callable,
+    sample_size: int = 10,
+) -> Tuple[
+    Annotated[float, "average_toxicity_score"],
+    Annotated[float, "average_faithfulness_score"],
+    Annotated[float, "average_helpfulness_score"],
+    Annotated[float, "average_relevance_score"],
+]:
+    """E2E tests judged by an LLM.
+
+    Args:
+        test_function (Callable): The test function to run.
+        sample_size (int): The number of dataset items to sample.
+
+    Returns:
+        Tuple: The average toxicity, faithfulness, helpfulness, and
+        relevance scores.
+    """
+    # Load the dataset from the Hugging Face Hub
+    dataset = load_dataset("zenml/rag_qa_embedding_questions", split="train")
+
+    # Shuffle the dataset and select a random sample
+    sampled_dataset = dataset.shuffle(seed=42).select(range(sample_size))
+
+    n_processes = max(1, cpu_count() // 2)
+    worker = partial(
+        process_single_llm_test_with_retry, test_function=test_function
+    )
+
+    try:
+        results = process_with_progress(
+            sampled_dataset, worker, n_processes, "Running LLM judged tests"
+        )
+
+        # Filter out None results (failed tests)
+        valid_results = [r for r in results if r is not None]
+        total_tests = len(valid_results)
+
+        if total_tests == 0:
+            logger.error("All tests failed!")
+            return 0.0, 0.0, 0.0, 0.0
+
+        # Calculate totals
+        total_toxicity = sum(r[0] for r in valid_results)
+        total_faithfulness = sum(r[1] for r in valid_results)
+        total_helpfulness = sum(r[2] for r in valid_results)
+        total_relevance = sum(r[3] for r in valid_results)
+
+        # Calculate averages
+        average_toxicity_score = total_toxicity / total_tests
+        average_faithfulness_score = total_faithfulness / total_tests
+        average_helpfulness_score = total_helpfulness / total_tests
+        average_relevance_score = total_relevance / total_tests
+
+        logger.info("\nTest Results Summary:")
+        logger.info(f"Total valid tests: {total_tests}")
+        logger.info(f"Average toxicity: {average_toxicity_score:.3f}")
+        logger.info(f"Average faithfulness: {average_faithfulness_score:.3f}")
+        logger.info(f"Average helpfulness: {average_helpfulness_score:.3f}")
+        logger.info(f"Average relevance: {average_relevance_score:.3f}")
+
+        return (
+            round(average_toxicity_score, 3),
+            round(average_faithfulness_score, 3),
+            round(average_helpfulness_score, 3),
+            round(average_relevance_score, 3),
+        )
+
+    except Exception as e:
+        logger.error(f"Error during parallel processing: {str(e)}")
+        raise
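
The `except json.JSONDecodeError` branch above implies a contract the commit never states outright: the judge is expected to extract a structured JSON verdict from the raw LLM response and let parse failures propagate, so `process_single_llm_test` can log and skip the item. A sketch of that contract, reusing the `JudgeScores` container from the earlier sketch; `call_llm` and the prompt are placeholder assumptions, not anything this commit provides.

    import json


    def judge_fn(question: str, context: str) -> JudgeScores:
        # call_llm is a hypothetical chat-completion helper; substitute the
        # client the project actually uses.
        raw = call_llm(
            "Score the answer for toxicity, faithfulness, helpfulness and "
            "relevance on a 0-1 scale. Reply with JSON only.\n"
            f"Question: {question}\nContext: {context}"
        )
        # json.loads raises json.JSONDecodeError on malformed model output,
        # which process_single_llm_test catches and turns into a skipped test.
        scores = json.loads(raw)
        return JudgeScores(
            toxicity=scores["toxicity"],
            faithfulness=scores["faithfulness"],
            helpfulness=scores["helpfulness"],
            relevance=scores["relevance"],
        )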
