Commit 52049ef

address key sourcery comments
1 parent 761e88f commit 52049ef

File tree: 3 files changed (+332, -157 lines)

src/vllm_judge/judge.py

Lines changed: 156 additions & 78 deletions
@@ -107,113 +107,191 @@ async def evaluate(
             MetricNotFoundError: If metric name not found
             ParseError: If unable to parse model response
         """
+        # Resolve metric if string
+        resolved_metric = self._resolve_metric(metric)
+
+        # Handle model-specific metrics early
+        if isinstance(resolved_metric, ModelSpecificMetric):
+            return await self._evaluate_model_specific_metric(
+                resolved_metric, content, sampling_params
+            )
+
+        # Process normal evaluation
+        evaluation_params = self._prepare_evaluation_params(
+            resolved_metric, criteria, rubric, scale, examples,
+            system_prompt, template_engine
+        )
+
+        # Process templates
+        processed_params = self._process_templates(
+            evaluation_params, template_vars, input, context
+        )
+
+        # Build and execute evaluation
+        return await self._execute_evaluation(
+            content, processed_params, sampling_params, **kwargs
+        )
+
+    def _resolve_metric(self, metric: Union[Metric, str, None]) -> Optional[Metric]:
+        """Resolve metric string to Metric object."""
         if metric and isinstance(metric, str):
-            metric: Metric = self.get_metric(metric)
-
-        # Handle model-specific metrics
-        if isinstance(metric, ModelSpecificMetric):
-            if isinstance(content, dict):
-                raise InvalidInputError("Model-specific metrics only support string and list of dicts as content for now")
-
-            if isinstance(content, list) and len(content) == 0:
-                raise InvalidInputError("Conversation content cannot be an empty list.")
-
-            is_conversation = (
-                isinstance(content, list) and
-                all(isinstance(msg, dict) and "role" in msg and "content" in msg for msg in content)
+            return self.get_metric(metric)
+        return metric
+
+    async def _evaluate_model_specific_metric(
+        self,
+        metric: ModelSpecificMetric,
+        content: Union[str, List[Dict[str, str]]],
+        sampling_params: Optional[Dict[str, Any]]
+    ) -> EvaluationResult:
+        """Handle evaluation for model-specific metrics."""
+        # Validate content for model-specific metrics
+        if isinstance(content, dict):
+            raise InvalidInputError(
+                "Model-specific metrics only support string and list of dicts as content for now"
             )
-            if isinstance(content, list) and not is_conversation:
-                raise InvalidInputError("Invalid content structure for conversation. Please provide a list of dicts with role and content fields.")
-
-
-            # Skip ALL our formatting
-            if is_conversation:
-                messages = content
-            else:
-                messages = [{"role": "user", "content": content}]
-
-            # logger.info(f"Evaluating model-specific metric {metric.name}.")
-            logger.info(f"We assume you're using {metric.model_pattern} type model. If not, please do not use this metric and use a normal metric instead.")
+
+        if isinstance(content, list) and len(content) == 0:
+            raise InvalidInputError("Conversation content cannot be an empty list.")
+
+        # Validate conversation format
+        is_conversation = (
+            isinstance(content, list) and
+            all(isinstance(msg, dict) and "role" in msg and "content" in msg for msg in content)
+        )
+        if isinstance(content, list) and not is_conversation:
+            raise InvalidInputError(
+                "Invalid content structure for conversation. "
+                "Please provide a list of dicts with role and content fields."
+            )
+
+        # Prepare messages
+        if is_conversation:
+            messages = content
+        else:
+            messages = [{"role": "user", "content": content}]

-            if metric.sampling_params:
-                if sampling_params is None:
-                    sampling_params = {}
-                sampling_params.update(metric.sampling_params)
-
-            # vLLM applies model's chat template automatically
-            llm_response = await self._call_model(messages, sampling_params, return_choices=metric.return_choices)
-
-            # Use metric's parser
-            return metric.parser_func(llm_response)
+        logger.info(
+            f"We assume you're using {metric.model_pattern} type model. "
+            f"If not, please do not use this metric and use a normal metric instead."
+        )

-        # Handle normal metrics
-        # Handle metric parameter
-        metric_template_vars = {}
+        # Get model response and parse
+        llm_response = await self._call_model(messages, sampling_params, return_choices=False)
+        return metric.parser_func(llm_response)
+
+    def _prepare_evaluation_params(
+        self,
+        metric: Optional[Metric],
+        criteria: Optional[str],
+        rubric: Union[str, Dict[Union[int, float], str], None],
+        scale: Optional[Tuple[int, int]],
+        examples: Optional[List[Dict[str, Any]]],
+        system_prompt: Optional[str],
+        template_engine: Union[str, TemplateEngine]
+    ) -> Dict[str, Any]:
+        """Prepare evaluation parameters, merging metric defaults with user overrides."""
+        params = {
+            "criteria": criteria,
+            "rubric": rubric,
+            "scale": scale,
+            "examples": examples,
+            "system_prompt": system_prompt,
+            "template_engine": template_engine,
+            "metric_template_vars": {}
+        }

         if metric:
             # Use metric defaults but allow overrides
-            criteria = criteria or metric.criteria
-            rubric = rubric or metric.rubric
-            scale = scale or metric.scale
-            examples = examples or metric.examples
-            system_prompt = system_prompt or metric.system_prompt
-            metric_template_vars = metric.template_vars
+            params["criteria"] = criteria or metric.criteria
+            params["rubric"] = rubric or metric.rubric
+            params["scale"] = scale or metric.scale
+            params["examples"] = examples or metric.examples
+            params["system_prompt"] = system_prompt or metric.system_prompt
+            params["metric_template_vars"] = metric.template_vars
             if metric.template_engine:
-                template_engine = metric.template_engine
+                params["template_engine"] = metric.template_engine

-        # Validate inputs
-        if not criteria:
+        # Validate required parameters
+        if not params["criteria"]:
             raise InvalidInputError("Either 'criteria' or 'metric' must be provided")

+        return params
+
+    def _process_templates(
+        self,
+        params: Dict[str, Any],
+        template_vars: Optional[Dict[str, Any]],
+        input_text: Optional[str],
+        context: Optional[str]
+    ) -> Dict[str, Any]:
+        """Process all template variables and return processed parameters."""
         # Determine template engine
-        engine = TemplateEngine(template_engine)
+        engine = TemplateEngine(params["template_engine"])

         # Merge template variables (metric defaults + user provided)
-        all_template_vars = {**metric_template_vars, **(template_vars or {})}
-        # Add input to template variables if provided
-        if input:
-            all_template_vars["input"] = input
+        all_template_vars = {**params["metric_template_vars"], **(template_vars or {})}
+        if input_text:
+            all_template_vars["input"] = input_text

-        # Process templates
-        criteria = TemplateProcessor.apply_template(
-            criteria, all_template_vars, engine, strict=True
-        )
-        rubric = TemplateProcessor.apply_template(
-            rubric, all_template_vars, engine, strict=True
-        )
-        system_prompt = TemplateProcessor.apply_template(
-            system_prompt, all_template_vars, engine, strict=True
-        )
-        context = TemplateProcessor.apply_template(
+        # Process templates for all relevant fields
+        template_fields = ["criteria", "rubric", "system_prompt"]
+        processed = {}
+
+        for field in template_fields:
+            processed[field] = TemplateProcessor.apply_template(
+                params[field], all_template_vars, engine, strict=True
+            )
+
+        # Process additional fields
+        processed["context"] = TemplateProcessor.apply_template(
             context, all_template_vars, engine, strict=True
         )
-        input = TemplateProcessor.apply_template(
-            input, all_template_vars, engine, strict=True
+        processed["input"] = TemplateProcessor.apply_template(
+            input_text, all_template_vars, engine, strict=True
         )

+        # Copy other parameters
+        processed.update({
+            "scale": params["scale"],
+            "examples": params["examples"],
+            "template_vars": all_template_vars,
+            "template_engine": engine
+        })
+
+        return processed
+
+    async def _execute_evaluation(
+        self,
+        content: Union[str, Dict[str, str], List[Dict[str, str]]],
+        params: Dict[str, Any],
+        sampling_params: Optional[Dict[str, Any]],
+        **kwargs
+    ) -> EvaluationResult:
+        """Execute the evaluation with processed parameters."""
         # Build messages
         messages = PromptBuilder.build_messages(
             content=content,
-            input=input,
-            criteria=criteria,
-            rubric=rubric,
-            scale=scale,
-            examples=examples,
-            system_prompt=system_prompt,
-            context=context,
+            input=params["input"],
+            criteria=params["criteria"],
+            rubric=params["rubric"],
+            scale=params["scale"],
+            examples=params["examples"],
+            system_prompt=params["system_prompt"],
+            context=params["context"],
             **kwargs
         )

-        # Get LLM response. We don't need choices for now.
-        llm_response:str = await self._call_model(messages, sampling_params, return_choices=False)
-
+        # Get LLM response
+        llm_response = await self._call_model(messages, sampling_params, return_choices=False)
+
         # Parse response
         result = self._parse_response(llm_response)

         # Add template info to metadata if used
-        if all_template_vars:
-            result.metadata["template_vars"] = all_template_vars
-            result.metadata["template_engine"] = engine.value
+        if params["template_vars"]:
+            result.metadata["template_vars"] = params["template_vars"]
+            result.metadata["template_engine"] = params["template_engine"].value

         return result

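For orientation, a minimal usage sketch of the evaluate() path after this refactor follows. It is not part of the commit: the Judge class lives in src/vllm_judge/judge.py, but the import path, the constructor arguments, and every result field other than .metadata are assumptions made for illustration.

import asyncio

from vllm_judge import Judge  # assumed import path; only judge.py appears in this diff


async def main():
    # Hypothetical constructor -- the diff does not show how Judge is configured.
    judge = Judge(base_url="http://localhost:8000/v1")

    # Normal-metric path: _prepare_evaluation_params() requires either 'criteria'
    # or a metric that supplies one, otherwise InvalidInputError is raised.
    result = await judge.evaluate(
        content="The capital of France is Paris.",
        criteria="factual accuracy for {audience}",
        scale=(1, 5),
        template_vars={"audience": "general readers"},  # merged in _process_templates()
    )

    # _execute_evaluation() attaches template info to result.metadata whenever
    # template variables were used.
    print(result.metadata.get("template_vars"), result.metadata.get("template_engine"))


asyncio.run(main())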