Commit d1eca31

address key sourcery comments
1 parent: f730d28

4 files changed: +333 −153 lines changed


src/vllm_judge/judge.py

Lines changed: 156 additions & 73 deletions
@@ -103,108 +103,191 @@ async def evaluate(
             MetricNotFoundError: If metric name not found
             ParseError: If unable to parse model response
         """
+        # Resolve metric if string
+        resolved_metric = self._resolve_metric(metric)
+
+        # Handle model-specific metrics early
+        if isinstance(resolved_metric, ModelSpecificMetric):
+            return await self._evaluate_model_specific_metric(
+                resolved_metric, content, sampling_params
+            )
+
+        # Process normal evaluation
+        evaluation_params = self._prepare_evaluation_params(
+            resolved_metric, criteria, rubric, scale, examples,
+            system_prompt, template_engine
+        )
+
+        # Process templates
+        processed_params = self._process_templates(
+            evaluation_params, template_vars, input, context
+        )
+
+        # Build and execute evaluation
+        return await self._execute_evaluation(
+            content, processed_params, sampling_params, **kwargs
+        )
+
+    def _resolve_metric(self, metric: Union[Metric, str, None]) -> Optional[Metric]:
+        """Resolve metric string to Metric object."""
         if metric and isinstance(metric, str):
-            metric: Metric = self.get_metric(metric)
-
-        # Handle model-specific metrics
-        if isinstance(metric, ModelSpecificMetric):
-            if isinstance(content, dict):
-                raise InvalidInputError("Model-specific metrics only support string and list of dicts as content for now")
-
-            if isinstance(content, list) and len(content) == 0:
-                raise InvalidInputError("Conversation content cannot be an empty list.")
-
-            is_conversation = (
-                isinstance(content, list) and
-                all(isinstance(msg, dict) and "role" in msg and "content" in msg for msg in content)
+            return self.get_metric(metric)
+        return metric
+
+    async def _evaluate_model_specific_metric(
+        self,
+        metric: ModelSpecificMetric,
+        content: Union[str, List[Dict[str, str]]],
+        sampling_params: Optional[Dict[str, Any]]
+    ) -> EvaluationResult:
+        """Handle evaluation for model-specific metrics."""
+        # Validate content for model-specific metrics
+        if isinstance(content, dict):
+            raise InvalidInputError(
+                "Model-specific metrics only support string and list of dicts as content for now"
             )
-            if isinstance(content, list) and not is_conversation:
-                raise InvalidInputError("Invalid content structure for conversation. Please provide a list of dicts with role and content fields.")
-
-
-            # Skip ALL our formatting
-            if is_conversation:
-                messages = content
-            else:
-                messages = [{"role": "user", "content": content}]
+
+        if isinstance(content, list) and len(content) == 0:
+            raise InvalidInputError("Conversation content cannot be an empty list.")
+
+        # Validate conversation format
+        is_conversation = (
+            isinstance(content, list) and
+            all(isinstance(msg, dict) and "role" in msg and "content" in msg for msg in content)
+        )
+        if isinstance(content, list) and not is_conversation:
+            raise InvalidInputError(
+                "Invalid content structure for conversation. "
+                "Please provide a list of dicts with role and content fields."
+            )
+
+        # Prepare messages
+        if is_conversation:
+            messages = content
+        else:
+            messages = [{"role": "user", "content": content}]

-            # logger.info(f"Evaluating model-specific metric {metric.name}.")
-            logger.info(f"We assume you're using {metric.model_pattern} type model. If not, please do not use this metric and use a normal metric instead.")
-
-            # vLLM applies model's chat template automatically
-            llm_response:str = await self._call_model(messages, sampling_params, return_choices=False)
-
-            # Use metric's parser
-            return metric.parser_func(llm_response)
+        logger.info(
+            f"We assume you're using {metric.model_pattern} type model. "
+            f"If not, please do not use this metric and use a normal metric instead."
+        )

-        # Handle normal metrics
-        # Handle metric parameter
-        metric_template_vars = {}
+        # Get model response and parse
+        llm_response = await self._call_model(messages, sampling_params, return_choices=False)
+        return metric.parser_func(llm_response)
+
+    def _prepare_evaluation_params(
+        self,
+        metric: Optional[Metric],
+        criteria: Optional[str],
+        rubric: Union[str, Dict[Union[int, float], str], None],
+        scale: Optional[Tuple[int, int]],
+        examples: Optional[List[Dict[str, Any]]],
+        system_prompt: Optional[str],
+        template_engine: Union[str, TemplateEngine]
+    ) -> Dict[str, Any]:
+        """Prepare evaluation parameters, merging metric defaults with user overrides."""
+        params = {
+            "criteria": criteria,
+            "rubric": rubric,
+            "scale": scale,
+            "examples": examples,
+            "system_prompt": system_prompt,
+            "template_engine": template_engine,
+            "metric_template_vars": {}
+        }

         if metric:
             # Use metric defaults but allow overrides
-            criteria = criteria or metric.criteria
-            rubric = rubric or metric.rubric
-            scale = scale or metric.scale
-            examples = examples or metric.examples
-            system_prompt = system_prompt or metric.system_prompt
-            metric_template_vars = metric.template_vars
+            params["criteria"] = criteria or metric.criteria
+            params["rubric"] = rubric or metric.rubric
+            params["scale"] = scale or metric.scale
+            params["examples"] = examples or metric.examples
+            params["system_prompt"] = system_prompt or metric.system_prompt
+            params["metric_template_vars"] = metric.template_vars
             if metric.template_engine:
-                template_engine = metric.template_engine
+                params["template_engine"] = metric.template_engine

-        # Validate inputs
-        if not criteria:
+        # Validate required parameters
+        if not params["criteria"]:
             raise InvalidInputError("Either 'criteria' or 'metric' must be provided")

+        return params
+
+    def _process_templates(
+        self,
+        params: Dict[str, Any],
+        template_vars: Optional[Dict[str, Any]],
+        input_text: Optional[str],
+        context: Optional[str]
+    ) -> Dict[str, Any]:
+        """Process all template variables and return processed parameters."""
         # Determine template engine
-        engine = TemplateEngine(template_engine)
+        engine = TemplateEngine(params["template_engine"])

         # Merge template variables (metric defaults + user provided)
-        all_template_vars = {**metric_template_vars, **(template_vars or {})}
-        # Add input to template variables if provided
-        if input:
-            all_template_vars["input"] = input
+        all_template_vars = {**params["metric_template_vars"], **(template_vars or {})}
+        if input_text:
+            all_template_vars["input"] = input_text

-        # Process templates
-        criteria = TemplateProcessor.apply_template(
-            criteria, all_template_vars, engine, strict=True
-        )
-        rubric = TemplateProcessor.apply_template(
-            rubric, all_template_vars, engine, strict=True
-        )
-        system_prompt = TemplateProcessor.apply_template(
-            system_prompt, all_template_vars, engine, strict=True
-        )
-        context = TemplateProcessor.apply_template(
+        # Process templates for all relevant fields
+        template_fields = ["criteria", "rubric", "system_prompt"]
+        processed = {}
+
+        for field in template_fields:
+            processed[field] = TemplateProcessor.apply_template(
+                params[field], all_template_vars, engine, strict=True
+            )
+
+        # Process additional fields
+        processed["context"] = TemplateProcessor.apply_template(
             context, all_template_vars, engine, strict=True
         )
-        input = TemplateProcessor.apply_template(
-            input, all_template_vars, engine, strict=True
+        processed["input"] = TemplateProcessor.apply_template(
+            input_text, all_template_vars, engine, strict=True
         )

+        # Copy other parameters
+        processed.update({
+            "scale": params["scale"],
+            "examples": params["examples"],
+            "template_vars": all_template_vars,
+            "template_engine": engine
+        })
+
+        return processed
+
+    async def _execute_evaluation(
+        self,
+        content: Union[str, Dict[str, str], List[Dict[str, str]]],
+        params: Dict[str, Any],
+        sampling_params: Optional[Dict[str, Any]],
+        **kwargs
+    ) -> EvaluationResult:
+        """Execute the evaluation with processed parameters."""
         # Build messages
         messages = PromptBuilder.build_messages(
             content=content,
-            input=input,
-            criteria=criteria,
-            rubric=rubric,
-            scale=scale,
-            examples=examples,
-            system_prompt=system_prompt,
-            context=context,
+            input=params["input"],
+            criteria=params["criteria"],
+            rubric=params["rubric"],
+            scale=params["scale"],
+            examples=params["examples"],
+            system_prompt=params["system_prompt"],
+            context=params["context"],
             **kwargs
         )

-        # Get LLM response. We don't need choices for now.
-        llm_response:str = await self._call_model(messages, sampling_params, return_choices=False)
-
+        # Get LLM response
+        llm_response = await self._call_model(messages, sampling_params, return_choices=False)
+
         # Parse response
         result = self._parse_response(llm_response)

         # Add template info to metadata if used
-        if all_template_vars:
-            result.metadata["template_vars"] = all_template_vars
-            result.metadata["template_engine"] = engine.value
+        if params["template_vars"]:
+            result.metadata["template_vars"] = params["template_vars"]
+            result.metadata["template_engine"] = params["template_engine"].value

         return result
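Below is a minimal usage sketch of the refactored evaluate() flow, for orientation only. It assumes a Judge client importable from vllm_judge, a constructor that takes a vLLM-compatible base_url, a format-string style default template engine (the {subject} placeholder), and an EvaluationResult of which only the metadata dict shown in this diff is relied on; everything not visible in the diff above is an assumption, not the library's confirmed API.

# Hypothetical usage sketch; the import path, constructor arguments, and
# the {placeholder} template syntax are assumptions not confirmed by this diff.
import asyncio

from vllm_judge import Judge  # assumed import path


async def main() -> None:
    judge = Judge(base_url="http://localhost:8000/v1")  # assumed constructor

    # Criteria-only call: _prepare_evaluation_params checks that 'criteria'
    # or 'metric' is provided, _process_templates substitutes {subject}, and
    # _execute_evaluation builds the prompt, calls the model, and parses it.
    result = await judge.evaluate(
        content="The capital of France is Paris.",
        criteria="factual accuracy of the {subject}",
        scale=(1, 5),
        template_vars={"subject": "answer"},
    )

    # Template info is recorded in metadata when template vars were used.
    print(result.metadata.get("template_vars"))
    print(result.metadata.get("template_engine"))


if __name__ == "__main__":
    asyncio.run(main())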