@@ -17,7 +17,7 @@
 import json
 import logging
 import warnings
-from typing import Annotated, Callable, Tuple
+from typing import Annotated, Callable, List, Tuple
 
 # Suppress the specific FutureWarning about clean_up_tokenization_spaces
 warnings.filterwarnings(
@@ -31,7 +31,7 @@
 from litellm import completion
 from pydantic import BaseModel, conint
 from structures import TestResult
-from utils.llm_utils import process_input_with_retrieval
+from utils.llm_utils import get_completion_from_messages, process_input_with_retrieval
 from utils.openai_utils import get_openai_api_key
 from zenml import step
 
@@ -70,7 +70,7 @@
 
 
 def test_content_for_bad_words(
-    item: dict, n_items_retrieved: int = 5
+    item: dict, n_items_retrieved: int = 5, tracing_tags: List[str] = []
 ) -> TestResult:
     """
     Test if responses contain bad words.
@@ -85,7 +85,7 @@ def test_content_for_bad_words(
     question = item["question"]
     bad_words = item["bad_words"]
     response = process_input_with_retrieval(
-        question, n_items_retrieved=n_items_retrieved
+        question, n_items_retrieved=n_items_retrieved, tracing_tags=tracing_tags
     )
     for word in bad_words:
         if word in response:
@@ -99,7 +99,7 @@ def test_content_for_bad_words(
 
 
 def test_response_starts_with_bad_words(
-    item: dict, n_items_retrieved: int = 5
+    item: dict, n_items_retrieved: int = 5, tracing_tags: List[str] = []
 ) -> TestResult:
     """
     Test if responses improperly start with bad words.
@@ -114,7 +114,7 @@ def test_response_starts_with_bad_words(
     question = item["question"]
     bad_words = item["bad_words"]
     response = process_input_with_retrieval(
-        question, n_items_retrieved=n_items_retrieved
+        question, n_items_retrieved=n_items_retrieved, tracing_tags=tracing_tags
     )
     for word in bad_words:
         if response.lower().startswith(word.lower()):
@@ -128,7 +128,7 @@ def test_response_starts_with_bad_words(
 
 
 def test_content_contains_good_words(
-    item: dict, n_items_retrieved: int = 5
+    item: dict, n_items_retrieved: int = 5, tracing_tags: List[str] = []
 ) -> TestResult:
     """
     Test if responses properly contain good words.
@@ -143,7 +143,7 @@ def test_content_contains_good_words(
     question = item["question"]
     good_words = item["good_words"]
     response = process_input_with_retrieval(
-        question, n_items_retrieved=n_items_retrieved
+        question, n_items_retrieved=n_items_retrieved, tracing_tags=tracing_tags
     )
     for word in good_words:
         if word not in response:
@@ -179,6 +179,7 @@ def llm_judged_test_e2e(
     question: str,
     context: str,
     n_items_retrieved: int = 5,
+    tracing_tags: List[str] = []
 ) -> LLMJudgedTestResult:
     """E2E tests judged by an LLM.
 
@@ -191,7 +192,7 @@ def llm_judged_test_e2e(
     """
     logging.debug("Starting LLM judged test...")
     response = process_input_with_retrieval(
-        question, n_items_retrieved=n_items_retrieved
+        question, n_items_retrieved=n_items_retrieved, tracing_tags=tracing_tags
     )
     logging.debug("Input processed with retrieval.")
     prompt = f"""
@@ -217,13 +218,12 @@ def llm_judged_test_e2e(
     }}
     """
     logging.debug("Prompt created.")
-    response = completion(
-        model="gpt-4-turbo",
+    json_output = get_completion_from_messages(
         messages=[{"content": prompt, "role": "user"}],
-        api_key=get_openai_api_key(),
-    )
+        model="gpt-4-turbo",
+        tracing_tags=tracing_tags,
+    ).strip()
 
-    json_output = response["choices"][0]["message"]["content"].strip()
     logging.info("Received response from model.")
     logging.debug(json_output)
     try:
@@ -234,8 +234,9 @@ def llm_judged_test_e2e(
 
 
 def run_llm_judged_tests(
-    test_function: Callable,
+    test_function: Callable[[str, str, int, List[str]], LLMJudgedTestResult],
     sample_size: int = 10,
+    tracing_tags: List[str] = []
 ) -> Tuple[
     Annotated[float, "average_toxicity_score"],
     Annotated[float, "average_faithfulness_score"],
@@ -248,6 +249,7 @@ def run_llm_judged_tests(
         test_data (list): The test data.
         test_function (function): The test function to run.
         sample_size (int): The sample size to run the tests on.
+        tracing_tags (list): Tracing tags passed to Langfuse.
 
     Returns:
         Tuple: The average toxicity, faithfulness, helpfulness, and relevance scores.
@@ -270,7 +272,7 @@ def run_llm_judged_tests(
         context = item["page_content"]
 
         try:
-            result = test_function(question, context)
+            result = test_function(question=question, context=context, tracing_tags=tracing_tags)
         except json.JSONDecodeError as e:
             logging.error(f"Failed for question: {question}. Error: {e}")
             total_tests -= 1
@@ -296,7 +298,11 @@ def run_llm_judged_tests(
     )
 
 
-def run_simple_tests(test_data: list, test_function: Callable) -> float:
+def run_simple_tests(
+    test_data: list,
+    test_function: Callable,
+    tracing_tags: List[str] = []
+) -> float:
     """
     Run tests for bad answers.
 
@@ -310,7 +316,7 @@ def run_simple_tests(test_data: list, test_function: Callable) -> float:
     failures = 0
     total_tests = len(test_data)
     for item in test_data:
-        test_result = test_function(item)
+        test_result = test_function(item, tracing_tags=tracing_tags)
         if not test_result.success:
             logging.error(
                 f"Test failed for question: '{test_result.question}'. Found word: '{test_result.keyword}'. Response: '{test_result.response}'"
@@ -324,29 +330,31 @@ def run_simple_tests(test_data: list, test_function: Callable) -> float:
 
 
 @step
-def e2e_evaluation() -> Tuple[
+def e2e_evaluation(
+    tracing_tags: List[str] = []
+) -> Tuple[
     Annotated[float, "failure_rate_bad_answers"],
     Annotated[float, "failure_rate_bad_immediate_responses"],
     Annotated[float, "failure_rate_good_responses"],
 ]:
     """Executes the end-to-end evaluation step."""
     logging.info("Testing bad answers...")
     failure_rate_bad_answers = run_simple_tests(
-        bad_answers, test_content_for_bad_words
+        bad_answers, test_content_for_bad_words, tracing_tags=tracing_tags
     )
     logging.info(f"Bad answers failure rate: {failure_rate_bad_answers}%")
 
     logging.info("Testing bad immediate responses...")
     failure_rate_bad_immediate_responses = run_simple_tests(
-        bad_immediate_responses, test_response_starts_with_bad_words
+        bad_immediate_responses, test_response_starts_with_bad_words, tracing_tags=tracing_tags
     )
     logging.info(
         f"Bad immediate responses failure rate: {failure_rate_bad_immediate_responses}%"
     )
 
     logging.info("Testing good responses...")
     failure_rate_good_responses = run_simple_tests(
-        good_responses, test_content_contains_good_words
+        good_responses, test_content_contains_good_words, tracing_tags=tracing_tags
     )
     logging.info(
         f"Good responses failure rate: {failure_rate_good_responses}%"
@@ -359,7 +367,9 @@ def e2e_evaluation() -> Tuple[
 
 
 @step
-def e2e_evaluation_llm_judged() -> Tuple[
+def e2e_evaluation_llm_judged(
+    tracing_tags: List[str] = []
+) -> Tuple[
     Annotated[float, "average_toxicity_score"],
     Annotated[float, "average_faithfulness_score"],
     Annotated[float, "average_helpfulness_score"],
@@ -376,7 +386,7 @@ def e2e_evaluation_llm_judged() -> Tuple[
         average_faithfulness_score,
         average_helpfulness_score,
         average_relevance_score,
-    ) = run_llm_judged_tests(llm_judged_test_e2e)
+    ) = run_llm_judged_tests(llm_judged_test_e2e, tracing_tags=tracing_tags)
     return (
         average_toxicity_score,
         average_faithfulness_score,
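
The commit imports get_completion_from_messages from utils.llm_utils but its body is not part of this diff. A minimal sketch of what such a wrapper could look like, assuming litellm's Langfuse callback is enabled and reads tags from the request metadata (the signature, default model, and metadata key are assumptions, not code from this commit):

import litellm
from litellm import completion
from typing import List

# Route litellm calls through the Langfuse callback so tags land on the trace.
litellm.success_callback = ["langfuse"]


def get_completion_from_messages(
    messages: list,
    model: str = "gpt-4-turbo",
    tracing_tags: List[str] = [],
) -> str:
    # Hypothetical helper: forwards tracing tags via metadata for Langfuse.
    response = completion(
        model=model,
        messages=messages,
        metadata={"tags": tracing_tags},
    )
    return response["choices"][0]["message"]["content"]

With both steps now accepting tracing_tags, a pipeline can thread the same tags through every evaluation run, for example (pipeline name and tag values are illustrative):

from zenml import pipeline


@pipeline
def llm_eval_pipeline():
    # Each step forwards the tags down to the LLM calls it makes.
    e2e_evaluation(tracing_tags=["evaluation", "e2e"])
    e2e_evaluation_llm_judged(tracing_tags=["evaluation", "llm-judged"])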