55from  jinja2  import  BaseLoader , Environment 
66
77from  continuous_eval .llms  import  LLMFactory 
8+ from  continuous_eval .metrics .base  import  (
9+     Arg ,
10+     Field ,
11+     MetricPrompt ,
12+     response_type ,
13+ )
814from  continuous_eval .metrics .base .llm  import  LLMMetric 
9- from  continuous_eval .metrics .base .metric  import  Arg , Field 
10- from  continuous_eval .metrics .base .prompt  import  MetricPrompt 
11- from  continuous_eval .metrics .base .response_type  import  JSON 
15+ from  continuous_eval .metrics .base .probabilistic  import  ProbabilisticMetric 
1216
1317_CWD  =  Path (__file__ ).parent 
1418
@@ -49,7 +53,7 @@ def __init__(
4953        self .prompt  =  MetricPrompt (
5054            sys_prompt ,
5155            user_prompt ,
52-             response_format = JSON (
56+             response_format = response_type . JSON (
5357                {k : v .type  for  k , v  in  response_format .items ()}
5458            ),
5559        )
@@ -60,3 +64,52 @@ def __init__(
6064    @property  
6165    def  help (self ):
6266        return  self ._criteria 
67+ 
68+ 
69+ class  ProbabilisticCustomMetric (ProbabilisticMetric ):
70+     def  __init__ (
71+         self ,
72+         name : str ,
73+         criteria : str ,
74+         rubric : str ,
75+         arguments : Dict [str , Arg ],
76+         response_format : response_type .ResponseFormatBaseType ,
77+         examples : Optional [List [Example ]] =  None ,
78+         temperature : float  =  1.0 ,
79+         model : str  =  LLMFactory .default (),
80+     ):
81+         if  not  isinstance (
82+             response_format , response_type .ResponseFormatBaseType 
83+         ):
84+             raise  ValueError ("response_format must be a ResponseFormatBaseType" )
85+         if  isinstance (response_format , response_type .JSON ):
86+             raise  ValueError (
87+                 "Probabilistic metrics do not support JSON response format, use CustomMetric instead" 
88+             )
89+         with  open (_CWD  /  "custom_metric_sys_probabilistic.jinja2" ) as  f :
90+             raw_system_prompt  =  f .read ()
91+         with  open (_CWD  /  "custom_metric_user.jinja2" ) as  f :
92+             raw_user_prompt  =  f .read ()
93+         env  =  Environment (loader = BaseLoader ())
94+         sys_prompt_template  =  env .from_string (raw_system_prompt )
95+         user_prompt_template  =  env .from_string (raw_user_prompt )
96+         sys_prompt  =  sys_prompt_template .render (
97+             criteria = criteria ,
98+             rubric = rubric ,
99+             examples = examples ,
100+             response_format = response_format ,
101+         )
102+         user_prompt  =  user_prompt_template .render (arguments = arguments )
103+         self ._criteria  =  criteria 
104+         self .prompt  =  MetricPrompt (
105+             sys_prompt ,
106+             user_prompt ,
107+             response_format = response_format ,
108+         )
109+         super ().__init__ (
110+             name = name , prompt = self .prompt , temperature = temperature , model = model 
111+         )
112+ 
113+     @property  
114+     def  help (self ):
115+         return  self ._criteria 
0 commit comments