
Commit 03e974f

Merge pull request #23 from saichandrapandraju/generation-detection
feat(RHOAIENG-28840): Support '/api/v1/text/generation' detections
2 parents 1f6f916 + a4c376b commit 03e974f

File tree

8 files changed (+497, -115 lines)

detectors/Dockerfile.judge

Lines changed: 0 additions & 1 deletion

@@ -21,7 +21,6 @@ RUN echo "$CACHEBUST"
 COPY ./common /app/detectors/common
 COPY ./llm_judge/app.py /app/detectors/llm_judge/app.py
 COPY ./llm_judge/detector.py /app/detectors/llm_judge/detector.py
-COPY ./llm_judge/scheme.py /app/detectors/llm_judge/scheme.py
 RUN touch /app/detectors/llm_judge/__init__.py

 EXPOSE 8000

detectors/common/scheme.py

Lines changed: 26 additions & 1 deletion

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional

 from pydantic import BaseModel, Field, RootModel

@@ -134,6 +134,7 @@ class ContentAnalysisResponse(BaseModel):
         description="Optional field providing evidences for the provided detection",
         default=None,
     )
+    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional metadata from evaluation")


 class ContentsAnalysisResponse(RootModel):
@@ -145,3 +146,27 @@ class ContentsAnalysisResponse(RootModel):
 class Error(BaseModel):
     code: int
     message: str
+
+class MetricsListResponse(BaseModel):
+    """Response for listing available metrics."""
+    metrics: List[str] = Field(description="List of available metric names")
+    total: int = Field(description="Total number of available metrics")
+
+class GenerationAnalysisHttpRequest(BaseModel):
+    prompt: str = Field(description="Prompt is the user input to the LLM", example="What do you think about the future of AI?")
+    generated_text: str = Field(description="Generated response from the LLM", example="The future of AI is bright but we need to be careful about the risks.")
+    detector_params: Optional[Dict[str, Any]] = Field(
+        default_factory=dict,
+        description="Detector parameters for evaluation (e.g., metric, criteria, etc.)",
+        example={"metric": "safety"}
+    )
+
+class GenerationAnalysisResponse(BaseModel):
+    detection: str = Field(example="safe")
+    detection_type: str = Field(example="llm_judge")
+    score: float = Field(example=0.8)
+    evidences: Optional[List[EvidenceObj]] = Field(
+        description="Optional field providing evidences for the provided detection",
+        default=[],
+    )
+    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional metadata from evaluation")
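
For orientation, a minimal usage sketch of the new request model (an illustration, not part of this commit; it assumes the repo's detectors package is importable and Pydantic v2):

# Illustrative only: build a generation-analysis request as the new schema defines it.
from detectors.common.scheme import GenerationAnalysisHttpRequest

req = GenerationAnalysisHttpRequest(
    prompt="What do you think about the future of AI?",
    generated_text="The future of AI is bright but we need to be careful about the risks.",
    detector_params={"metric": "safety"},  # optional; defaults to an empty dict
)
print(req.model_dump_json(indent=2))  # Pydantic v2 API; on v1 this would be req.json(indent=2)

The matching GenerationAnalysisResponse carries detection, detection_type, score, evidences, and metadata for a single prompt/response pair.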

detectors/llm_judge/app.py

Lines changed: 26 additions & 4 deletions

@@ -6,11 +6,13 @@

 from detectors.common.app import DetectorBaseAPI as FastAPI
 from detectors.llm_judge.detector import LLMJudgeDetector
-from detectors.llm_judge.scheme import (
+from detectors.common.scheme import (
     ContentAnalysisHttpRequest,
     ContentsAnalysisResponse,
     MetricsListResponse,
     Error,
+    GenerationAnalysisHttpRequest,
+    GenerationAnalysisResponse,
 )


@@ -35,23 +37,43 @@ async def lifespan(app: FastAPI):
     "/api/v1/text/contents",
     response_model=ContentsAnalysisResponse,
     description="""LLM-as-Judge detector that evaluates content using various metrics like safety, toxicity, accuracy, helpfulness, etc. \
-    The metric parameter allows you to specify which evaluation criteria to use. \
+    The metric detector_params parameter allows you to specify which evaluation criteria to use. \
     Supports all built-in vllm_judge metrics including safety, accuracy, helpfulness, clarity, and many more.""",
     responses={
         404: {"model": Error, "description": "Resource Not Found"},
         422: {"model": Error, "description": "Validation Error"},
     },
 )
-async def detector_unary_handler(
+async def detector_content_analysis_handler(
     request: ContentAnalysisHttpRequest,
     detector_id: Annotated[str, Header(example="llm_judge_safety")],
 ):
     """Analyze content using LLM-as-Judge evaluation."""
     detector: LLMJudgeDetector = app.get_detector()
     if not detector:
         raise HTTPException(status_code=503, detail="Detector not found")
-    return ContentsAnalysisResponse(root=await detector.run(request))
+    return ContentsAnalysisResponse(root=await detector.analyze_content(request))

+@app.post(
+    "/api/v1/text/generation",
+    response_model=GenerationAnalysisResponse,
+    description="""Analyze a single generation using the specified metric. \
+    The metric detector_params parameter allows you to specify which evaluation criteria to use. \
+    Supports all built-in vllm_judge metrics including safety, accuracy, helpfulness, clarity, and many more.""",
+    responses={
+        404: {"model": Error, "description": "Resource Not Found"},
+        422: {"model": Error, "description": "Validation Error"},
+    },
+)
+async def detector_generation_analysis_handler(
+    request: GenerationAnalysisHttpRequest,
+    detector_id: Annotated[str, Header(example="llm_judge_safety")],
+):
+    """Analyze a single generation using LLM-as-Judge evaluation."""
+    detector: LLMJudgeDetector = app.get_detector()
+    if not detector:
+        raise HTTPException(status_code=503, detail="Detector not found")
+    return await detector.analyze_generation(request)

 @app.get(
     "/api/v1/metrics",

detectors/llm_judge/detector.py

Lines changed: 81 additions & 20 deletions

@@ -4,10 +4,12 @@
 from vllm_judge import Judge, EvaluationResult, BUILTIN_METRICS
 from vllm_judge.exceptions import MetricNotFoundError
 from detectors.common.app import logger
-from detectors.llm_judge.scheme import (
+from detectors.common.scheme import (
     ContentAnalysisHttpRequest,
     ContentAnalysisResponse,
     ContentsAnalysisResponse,
+    GenerationAnalysisHttpRequest,
+    GenerationAnalysisResponse,
 )


@@ -40,33 +42,49 @@ def _initialize_judge(self) -> None:
             logger.error(f"Failed to detect model: {e}")
             raise

-    async def evaluate_single_content(self, content: str, params: Dict[str, Any]) -> ContentAnalysisResponse:
+    def _validate_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
         """
-        Evaluate a single piece of content using the specified metric.
-
-        Args:
-            content: Text content to evaluate
-            params: vLLM Judge parameters for the evaluation
-
-        Returns:
-            ContentAnalysisResponse with evaluation results
+        Make sure the params have valid metric/criteria and scale.
         """
         if "metric" not in params:
             if "criteria" not in params:
                 params["metric"] = "safety"  # Default to safety
             elif "scale" not in params:
                 params["scale"] = (0, 1)  # Default to 0-1 scale
-
-        if "metric" in params:
+        else:
             if params["metric"] not in self.available_metrics:
                 raise MetricNotFoundError(
                     f"Metric '{params['metric']}' not found. Available metrics: {', '.join(sorted(self.available_metrics))}"
                 )
             judge_metric = BUILTIN_METRICS[params["metric"]]
             if judge_metric.scale is None:
                 params["scale"] = (0, 1)  # Default to 0-1 scale
+
+        return params
+
+    def _get_score(self, result: EvaluationResult) -> float:
+        """
+        Get the score from the evaluation result.
+        """
+        if isinstance(result.decision, (int, float)) or result.score is not None:
+            return float(result.score if result.score is not None else result.decision)
+        logger.warning(f"Score is not a number: '{result.score}'. Defaulting to 0.0")
+        return 0.0  # FIXME: default to 0 because of non-optional field in schema
+
+    async def evaluate_single_content(self, content: str, params: Dict[str, Any]) -> ContentAnalysisResponse:
+        """
+        Evaluate a single piece of content using the specified metric.
+
+        Args:
+            content: Text content to evaluate
+            params: vLLM Judge parameters for the evaluation
+
+        Returns:
+            ContentAnalysisResponse with evaluation results
+        """
+        params: Dict[str, Any] = self._validate_params(params)

-        evaluation_params = {
+        evaluation_params: Dict[str, Any] = {
             "content": content,
             **params
         }
@@ -76,11 +94,8 @@ async def evaluate_single_content(self, content: str, params: Dict[str, Any]) ->
             **evaluation_params
         )

-        # Convert to response format
-        score = None
-        if isinstance(result.decision, (int, float)) or result.score is not None:
-            # Numeric result
-            score = float(result.score if result.score is not None else result.decision)
+        # Convert to response format.
+        score: float = self._get_score(result)

         return ContentAnalysisResponse(
             start=0,
@@ -93,12 +108,12 @@ async def evaluate_single_content(self, content: str, params: Dict[str, Any]) ->
             metadata={"reasoning": result.reasoning}
         )

-    async def run(self, request: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
+    async def analyze_content(self, request: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
         """
         Run content analysis for each input text.

         Args:
-            request: Input request containing texts and metric to analyze
+            request: Input request containing texts and optional metric to analyze

         Returns:
             ContentsAnalysisResponse: The aggregated response for all input texts
@@ -111,7 +126,53 @@ async def run(self, request: ContentAnalysisHttpRequest) -> ContentsAnalysisResp
             contents_analyses.append([analysis])  # Wrap in list to match schema

         return contents_analyses
+
+    async def evaluate_single_generation(self, prompt: str, generated_text: str, params: Dict[str, Any]) -> GenerationAnalysisResponse:
+        """
+        Evaluate a single generation based on the prompt and generated text.
+
+        Args:
+            prompt: Prompt to the LLM
+            generated_text: Generated text from the LLM
+            params: vLLM Judge parameters for the evaluation
+
+        Returns:
+            GenerationAnalysisResponse: The response for the generation analysis
+        """
+        params: Dict[str, Any] = self._validate_params(params)
+        evaluation_params: Dict[str, Any] = {
+            "input": prompt,
+            "content": generated_text,
+            **params
+        }
+
+        result: EvaluationResult = await self.judge.evaluate(
+            **evaluation_params
+        )
+
+        score: float = self._get_score(result)
+
+        return GenerationAnalysisResponse(
+            detection=str(result.decision),
+            detection_type="llm_judge",
+            score=score,
+            evidences=[],
+            metadata={"reasoning": result.reasoning}
+        )
+
+    async def analyze_generation(self, request: GenerationAnalysisHttpRequest) -> GenerationAnalysisResponse:
+        """
+        Analyze a single generation based on the prompt and generated text.
+
+        Args:
+            request: Input request containing prompt, generated text and optional metric to analyze

+        Returns:
+            GenerationAnalysisResponse: The response for the generation analysis
+        """
+        return await self.evaluate_single_generation(prompt=request.prompt,
+                                                     generated_text=request.generated_text,
+                                                     params=request.detector_params)

     async def close(self):
         """Close the judge client."""

detectors/llm_judge/requirements.txt

Lines changed: 1 addition & 2 deletions

@@ -1,2 +1 @@
-vllm-judge[jinja2]==0.1.6
-pyyaml==6.0.2
+vllm-judge[jinja2]==0.1.8

detectors/llm_judge/scheme.py

Lines changed: 0 additions & 74 deletions

This file was deleted; its models were consolidated into detectors/common/scheme.py (see the import changes above).
