diff --git a/detectors/Dockerfile.judge b/detectors/Dockerfile.judge index b33753d..760ca86 100644 --- a/detectors/Dockerfile.judge +++ b/detectors/Dockerfile.judge @@ -21,7 +21,6 @@ RUN echo "$CACHEBUST" COPY ./common /app/detectors/common COPY ./llm_judge/app.py /app/detectors/llm_judge/app.py COPY ./llm_judge/detector.py /app/detectors/llm_judge/detector.py -COPY ./llm_judge/scheme.py /app/detectors/llm_judge/scheme.py RUN touch /app/detectors/llm_judge/__init__.py EXPOSE 8000 diff --git a/detectors/common/scheme.py b/detectors/common/scheme.py index 4ed968b..f40e714 100644 --- a/detectors/common/scheme.py +++ b/detectors/common/scheme.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field, RootModel @@ -134,6 +134,7 @@ class ContentAnalysisResponse(BaseModel): description="Optional field providing evidences for the provided detection", default=None, ) + metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional metadata from evaluation") class ContentsAnalysisResponse(RootModel): @@ -145,3 +146,27 @@ class ContentsAnalysisResponse(RootModel): class Error(BaseModel): code: int message: str + +class MetricsListResponse(BaseModel): + """Response for listing available metrics.""" + metrics: List[str] = Field(description="List of available metric names") + total: int = Field(description="Total number of available metrics") + +class GenerationAnalysisHttpRequest(BaseModel): + prompt: str = Field(description="Prompt is the user input to the LLM", example="What do you think about the future of AI?") + generated_text: str = Field(description="Generated response from the LLM", example="The future of AI is bright but we need to be careful about the risks.") + detector_params: Optional[Dict[str, Any]] = Field( + default_factory=dict, + description="Detector parameters for evaluation (e.g., metric, criteria, etc.)", + example={"metric": "safety"} + ) + +class GenerationAnalysisResponse(BaseModel): + detection: str = Field(example="safe") + detection_type: str = Field(example="llm_judge") + score: float = Field(example=0.8) + evidences: Optional[List[EvidenceObj]] = Field( + description="Optional field providing evidences for the provided detection", + default=[], + ) + metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional metadata from evaluation") \ No newline at end of file diff --git a/detectors/llm_judge/app.py b/detectors/llm_judge/app.py index 1eb7d0c..8c2395c 100644 --- a/detectors/llm_judge/app.py +++ b/detectors/llm_judge/app.py @@ -6,11 +6,13 @@ from detectors.common.app import DetectorBaseAPI as FastAPI from detectors.llm_judge.detector import LLMJudgeDetector -from detectors.llm_judge.scheme import ( +from detectors.common.scheme import ( ContentAnalysisHttpRequest, ContentsAnalysisResponse, MetricsListResponse, Error, + GenerationAnalysisHttpRequest, + GenerationAnalysisResponse, ) @@ -35,14 +37,14 @@ async def lifespan(app: FastAPI): "/api/v1/text/contents", response_model=ContentsAnalysisResponse, description="""LLM-as-Judge detector that evaluates content using various metrics like safety, toxicity, accuracy, helpfulness, etc. \ - The metric parameter allows you to specify which evaluation criteria to use. \ + The metric detector_params parameter allows you to specify which evaluation criteria to use. 
\ Supports all built-in vllm_judge metrics including safety, accuracy, helpfulness, clarity, and many more.""", responses={ 404: {"model": Error, "description": "Resource Not Found"}, 422: {"model": Error, "description": "Validation Error"}, }, ) -async def detector_unary_handler( +async def detector_content_analysis_handler( request: ContentAnalysisHttpRequest, detector_id: Annotated[str, Header(example="llm_judge_safety")], ): @@ -50,8 +52,28 @@ async def detector_unary_handler( detector: LLMJudgeDetector = app.get_detector() if not detector: raise HTTPException(status_code=503, detail="Detector not found") - return ContentsAnalysisResponse(root=await detector.run(request)) + return ContentsAnalysisResponse(root=await detector.analyze_content(request)) +@app.post( + "/api/v1/text/generation", + response_model=GenerationAnalysisResponse, + description="""Analyze a single generation using the specified metric. \ + The metric detector_params parameter allows you to specify which evaluation criteria to use. \ + Supports all built-in vllm_judge metrics including safety, accuracy, helpfulness, clarity, and many more.""", + responses={ + 404: {"model": Error, "description": "Resource Not Found"}, + 422: {"model": Error, "description": "Validation Error"}, + }, +) +async def detector_generation_analysis_handler( + request: GenerationAnalysisHttpRequest, + detector_id: Annotated[str, Header(example="llm_judge_safety")], +): + """Analyze a single generation using LLM-as-Judge evaluation.""" + detector: LLMJudgeDetector = app.get_detector() + if not detector: + raise HTTPException(status_code=503, detail="Detector not found") + return await detector.analyze_generation(request) @app.get( "/api/v1/metrics", diff --git a/detectors/llm_judge/detector.py b/detectors/llm_judge/detector.py index 823a2f7..0eb914d 100644 --- a/detectors/llm_judge/detector.py +++ b/detectors/llm_judge/detector.py @@ -4,10 +4,12 @@ from vllm_judge import Judge, EvaluationResult, BUILTIN_METRICS from vllm_judge.exceptions import MetricNotFoundError from detectors.common.app import logger -from detectors.llm_judge.scheme import ( +from detectors.common.scheme import ( ContentAnalysisHttpRequest, ContentAnalysisResponse, ContentsAnalysisResponse, + GenerationAnalysisHttpRequest, + GenerationAnalysisResponse, ) @@ -40,24 +42,16 @@ def _initialize_judge(self) -> None: logger.error(f"Failed to detect model: {e}") raise - async def evaluate_single_content(self, content: str, params: Dict[str, Any]) -> ContentAnalysisResponse: + def _validate_params(self, params: Dict[str, Any]) -> Dict[str, Any]: """ - Evaluate a single piece of content using the specified metric. - - Args: - content: Text content to evaluate - params: vLLM Judge parameters for the evaluation - - Returns: - ContentAnalysisResponse with evaluation results + Make sure the params have valid metric/criteria and scale. """ if "metric" not in params: if "criteria" not in params: params["metric"] = "safety" # Default to safety elif "scale" not in params: params["scale"] = (0, 1) # Default to 0-1 scale - - if "metric" in params: + else: if params["metric"] not in self.available_metrics: raise MetricNotFoundError( f"Metric '{params['metric']}' not found. 
Available metrics: {', '.join(sorted(self.available_metrics))}" @@ -65,8 +59,32 @@ async def evaluate_single_content(self, content: str, params: Dict[str, Any]) -> judge_metric = BUILTIN_METRICS[params["metric"]] if judge_metric.scale is None: params["scale"] = (0, 1) # Default to 0-1 scale + + return params + + def _get_score(self, result: EvaluationResult) -> float: + """ + Get the score from the evaluation result. + """ + if isinstance(result.decision, (int, float)) or result.score is not None: + return float(result.score if result.score is not None else result.decision) + logger.warning(f"Score is not a number: '{result.score}'. Defaulting to 0.0") + return 0.0 # FIXME: default to 0 because of non-optional field in schema + + async def evaluate_single_content(self, content: str, params: Dict[str, Any]) -> ContentAnalysisResponse: + """ + Evaluate a single piece of content using the specified metric. + + Args: + content: Text content to evaluate + params: vLLM Judge parameters for the evaluation + + Returns: + ContentAnalysisResponse with evaluation results + """ + params: Dict[str, Any] = self._validate_params(params) - evaluation_params = { + evaluation_params: Dict[str, Any] = { "content": content, **params } @@ -76,11 +94,8 @@ async def evaluate_single_content(self, content: str, params: Dict[str, Any]) -> **evaluation_params ) - # Convert to response format - score = None - if isinstance(result.decision, (int, float)) or result.score is not None: - # Numeric result - score = float(result.score if result.score is not None else result.decision) + # Convert to response format. + score: float = self._get_score(result) return ContentAnalysisResponse( start=0, @@ -93,12 +108,12 @@ async def evaluate_single_content(self, content: str, params: Dict[str, Any]) -> metadata={"reasoning": result.reasoning} ) - async def run(self, request: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse: + async def analyze_content(self, request: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse: """ Run content analysis for each input text. Args: - request: Input request containing texts and metric to analyze + request: Input request containing texts and optional metric to analyze Returns: ContentsAnalysisResponse: The aggregated response for all input texts @@ -111,7 +126,53 @@ async def run(self, request: ContentAnalysisHttpRequest) -> ContentsAnalysisResp contents_analyses.append([analysis]) # Wrap in list to match schema return contents_analyses + + async def evaluate_single_generation(self, prompt: str, generated_text: str, params: Dict[str, Any]) -> GenerationAnalysisResponse: + """ + Evaluate a single generation based on the prompt and generated text. 
+ + Args: + prompt: Prompt to the LLM + generated_text: Generated text from the LLM + params: vLLM Judge parameters for the evaluation + + Returns: + GenerationAnalysisResponse: The response for the generation analysis + """ + params: Dict[str, Any] = self._validate_params(params) + evaluation_params: Dict[str, Any] = { + "input": prompt, + "content": generated_text, + **params + } + + result: EvaluationResult = await self.judge.evaluate( + **evaluation_params + ) + + score: float = self._get_score(result) + + return GenerationAnalysisResponse( + detection=str(result.decision), + detection_type="llm_judge", + score=score, + evidences=[], + metadata={"reasoning": result.reasoning} + ) + + async def analyze_generation(self, request: GenerationAnalysisHttpRequest) -> GenerationAnalysisResponse: + """ + Analyze a single generation based on the prompt and generated text. + + Args: + request: Input request containing prompt, generated text and optional metric to analyze + Returns: + GenerationAnalysisResponse: The response for the generation analysis + """ + return await self.evaluate_single_generation(prompt=request.prompt, + generated_text=request.generated_text, + params=request.detector_params) async def close(self): """Close the judge client.""" diff --git a/detectors/llm_judge/requirements.txt b/detectors/llm_judge/requirements.txt index ac1cc6c..9b7adc7 100644 --- a/detectors/llm_judge/requirements.txt +++ b/detectors/llm_judge/requirements.txt @@ -1,2 +1 @@ -vllm-judge[jinja2]==0.1.6 -pyyaml==6.0.2 \ No newline at end of file +vllm-judge[jinja2]==0.1.8 \ No newline at end of file diff --git a/detectors/llm_judge/scheme.py b/detectors/llm_judge/scheme.py deleted file mode 100644 index 7237b1c..0000000 --- a/detectors/llm_judge/scheme.py +++ /dev/null @@ -1,74 +0,0 @@ -from enum import Enum -from typing import List, Optional, Dict, Any -from pydantic import BaseModel, Field, RootModel - - -class Evidence(BaseModel): - source: str = Field( - title="Source", - example="https://en.wikipedia.org/wiki/IBM", - description="Source of the evidence, it can be url of the evidence etc", - ) - - -class EvidenceType(str, Enum): - url = "url" - title = "title" - - -class EvidenceObj(BaseModel): - type: EvidenceType = Field( - title="EvidenceType", - example="url", - description="Type field signifying the type of evidence provided. Example url, title etc", - ) - evidence: Evidence = Field( - description="Evidence object, currently only containing source, but in future can contain other optional arguments like id, etc", - ) - - -class ContentAnalysisHttpRequest(BaseModel): - contents: List[str] = Field( - min_length=1, - title="Contents", - description="Field allowing users to provide list of texts for analysis. 
Note, results of this endpoint will contain analysis / detection of each of the provided text in the order they are present in the contents object.", - example=[ - "Martians are like crocodiles; the more you give them meat, the more they want" - ], - ) - detector_params: Optional[Dict[str, Any]] = Field( - default_factory=dict, - description="Detector parameters for evaluation (e.g., metric, criteria, etc.)", - example={"metric": "safety"} - ) - - -class ContentAnalysisResponse(BaseModel): - start: int = Field(example=0) - end: int = Field(example=75) - text: str = Field(example="This is a safe and helpful response") - detection: str = Field(example="vllm_model") - detection_type: str = Field(example="llm_judge") - score: float = Field(example=0.8) - evidences: Optional[List[EvidenceObj]] = Field( - description="Optional field providing evidences for the provided detection", - default=[], - ) - metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional metadata from evaluation") - - -class ContentsAnalysisResponse(RootModel): - root: List[List[ContentAnalysisResponse]] = Field( - title="Response Text Content Analysis LLM Judge" - ) - - -class Error(BaseModel): - code: int - message: str - - -class MetricsListResponse(BaseModel): - """Response for listing available metrics.""" - metrics: List[str] = Field(description="List of available metric names") - total: int = Field(description="Total number of available metrics") \ No newline at end of file diff --git a/tests/detectors/llm_judge/test_llm_judge_detector.py b/tests/detectors/llm_judge/test_llm_judge_detector.py index 64bc571..31898c9 100644 --- a/tests/detectors/llm_judge/test_llm_judge_detector.py +++ b/tests/detectors/llm_judge/test_llm_judge_detector.py @@ -6,9 +6,11 @@ # Import the detector components from detectors.llm_judge.detector import LLMJudgeDetector -from detectors.llm_judge.scheme import ( +from detectors.common.scheme import ( ContentAnalysisHttpRequest, ContentAnalysisResponse, + GenerationAnalysisHttpRequest, + GenerationAnalysisResponse, ) # Import vLLM Judge components for mocking @@ -16,8 +18,8 @@ from vllm_judge.exceptions import MetricNotFoundError -class TestLLMJudgeDetector: - """Test suite for LLMJudgeDetector.""" +class TestLLMJudgeDetectorContentAnalysis: + """Test suite for LLMJudgeDetector content analysis.""" @pytest.fixture def mock_judge_result(self) -> EvaluationResult: @@ -73,6 +75,16 @@ def test_detector_initialization_unreachable_url(self) -> None: with pytest.raises(Exception, match="Failed to detect model"): LLMJudgeDetector() + def test_close_detector(self, detector_with_mock_judge: Tuple[LLMJudgeDetector, AsyncMock]) -> None: + """Test closing the detector properly closes the judge.""" + detector: LLMJudgeDetector + mock_judge: AsyncMock + detector, mock_judge = detector_with_mock_judge + + asyncio.run(detector.close()) + + mock_judge.close.assert_called_once() + def test_evaluate_single_content_basic_metric(self, detector_with_mock_judge: Tuple[LLMJudgeDetector, AsyncMock]) -> None: """Test basic evaluation with just a metric.""" detector: LLMJudgeDetector @@ -182,7 +194,7 @@ def test_run_single_content(self, detector_with_mock_judge: Tuple[LLMJudgeDetect detector_params={"metric": "safety"} ) - result = asyncio.run(detector.run(request)) + result = asyncio.run(detector.analyze_content(request)) assert len(result) == 1 assert len(result[0]) == 1 @@ -216,7 +228,7 @@ def test_run_multiple_contents(self, detector_with_mock_judge: Tuple[LLMJudgeDet 
detector_params={"metric": "safety"} ) - result = asyncio.run(detector.run(request)) + result = asyncio.run(detector.analyze_content(request)) assert len(result) == 3 for i, analysis_list in enumerate(result): @@ -252,7 +264,7 @@ def test_run_with_custom_evaluation_params(self, detector_with_mock_judge: Tuple detector_params=custom_evaluation_params ) - result = asyncio.run(detector.run(request)) + result = asyncio.run(detector.analyze_content(request)) # Verify complex parameters were passed correctly expected_call_params = { @@ -285,3 +297,292 @@ def test_close_detector(self, detector_with_mock_judge: Tuple[LLMJudgeDetector, asyncio.run(detector.close()) mock_judge.close.assert_called_once() + + +class TestLLMJudgeDetectorGenerationAnalysis: + """Test suite for LLMJudgeDetector generation analysis.""" + + @pytest.fixture + def mock_judge_result(self) -> EvaluationResult: + """Mock EvaluationResult for generation testing.""" + return EvaluationResult( + decision="HELPFUL", + reasoning="This generated response is helpful and addresses the user's question appropriately.", + score=0.85, + metadata={"model": "test-model"} + ) + + @pytest.fixture + def detector_with_mock_judge(self, mock_judge_result) -> Tuple[LLMJudgeDetector, AsyncMock]: + """Create detector with mocked Judge.""" + with patch.dict(os.environ, {"VLLM_BASE_URL": "http://test:8000"}): + with patch('vllm_judge.Judge.from_url') as mock_judge_class: + # Create mock judge instance + mock_judge_instance = AsyncMock() + mock_judge_instance.evaluate = AsyncMock(return_value=mock_judge_result) + mock_judge_instance.config.model = "test-model" + mock_judge_instance.config.base_url = "http://test:8000" + mock_judge_instance.close = AsyncMock() + + mock_judge_class.return_value = mock_judge_instance + + detector = LLMJudgeDetector() + return detector, mock_judge_instance + + def test_evaluate_single_generation_basic_metric(self, detector_with_mock_judge: Tuple[LLMJudgeDetector, AsyncMock]) -> None: + """Test basic generation evaluation with just a metric.""" + detector: LLMJudgeDetector + mock_judge: AsyncMock + detector, mock_judge = detector_with_mock_judge + + prompt = "What is artificial intelligence?" + generated_text = "Artificial intelligence (AI) refers to the simulation of human intelligence in machines." + params = {"metric": "helpfulness"} + + result = asyncio.run(detector.evaluate_single_generation(prompt, generated_text, params)) + + # Verify judge.evaluate was called correctly + mock_judge.evaluate.assert_called_once_with( + input=prompt, + content=generated_text, + metric="helpfulness" + ) + + # Verify response format + assert isinstance(result, GenerationAnalysisResponse) + assert result.detection == "HELPFUL" + assert result.score == 0.85 + assert result.detection_type == "llm_judge" + assert "reasoning" in result.metadata + + def test_evaluate_single_generation_full_parameters(self, detector_with_mock_judge: Tuple[LLMJudgeDetector, AsyncMock]) -> None: + """Test generation evaluation with all vLLM Judge parameters.""" + detector: LLMJudgeDetector + mock_judge: AsyncMock + detector, mock_judge = detector_with_mock_judge + + prompt = "Explain quantum computing in simple terms" + generated_text = "Quantum computing uses quantum bits (qubits) that can exist in multiple states simultaneously, allowing for parallel processing of information." 
+ params = { + "criteria": "accuracy, clarity, and completeness", + "rubric": "Score based on technical accuracy and accessibility", + "scale": [1, 10], + "examples": [{"input": "test prompt", "output": "example response"}], + "system_prompt": "You are evaluating educational content", + "context": "This is for a general audience explanation of {topic}", + "template_vars": {"topic": "quantum computing"} + } + + asyncio.run(detector.evaluate_single_generation(prompt, generated_text, params)) + + # Verify all parameters were passed through + expected_call = { + "input": prompt, + "content": generated_text, + **params + } + mock_judge.evaluate.assert_called_once_with(**expected_call) + + def test_evaluate_single_generation_criteria_without_metric(self, detector_with_mock_judge: Tuple[LLMJudgeDetector, AsyncMock]) -> None: + """Test generation evaluation with criteria but no metric (should default scale).""" + detector: LLMJudgeDetector + mock_judge: AsyncMock + detector, mock_judge = detector_with_mock_judge + + prompt = "Write a short story" + generated_text = "Once upon a time, there was a brave knight who saved a village from a dragon." + params = { + "criteria": "creativity and engagement", + "rubric": "Custom rubric for story evaluation" + } + + asyncio.run(detector.evaluate_single_generation(prompt, generated_text, params)) + + # Should add default scale when criteria provided without metric + expected_params = { + "input": prompt, + "content": generated_text, + "criteria": "creativity and engagement", + "rubric": "Custom rubric for story evaluation", + "scale": (0, 1) + } + mock_judge.evaluate.assert_called_once_with(**expected_params) + + def test_evaluate_single_generation_no_params(self, detector_with_mock_judge: Tuple[LLMJudgeDetector, AsyncMock]) -> None: + """Test generation evaluation with no parameters (should default to safety).""" + detector: LLMJudgeDetector + mock_judge: AsyncMock + detector, mock_judge = detector_with_mock_judge + + prompt = "Tell me about AI" + generated_text = "AI is a field of computer science focused on creating intelligent machines." 
+ params = {} + + asyncio.run(detector.evaluate_single_generation(prompt, generated_text, params)) + + # Should default to safety metric + expected_params = { + "input": prompt, + "content": generated_text, + "metric": "safety" + } + mock_judge.evaluate.assert_called_once_with(**expected_params) + + def test_evaluate_single_generation_invalid_metric(self, detector_with_mock_judge: Tuple[LLMJudgeDetector, AsyncMock]) -> None: + """Test generation evaluation with invalid metric raises error.""" + detector: LLMJudgeDetector + detector, _ = detector_with_mock_judge + + prompt = "Test prompt" + generated_text = "Test generation" + params = {"metric": "invalid_metric"} + + with pytest.raises(MetricNotFoundError, match="Metric 'invalid_metric' not found"): + asyncio.run(detector.evaluate_single_generation(prompt, generated_text, params)) + + def test_analyze_generation_basic_request(self, detector_with_mock_judge: Tuple[LLMJudgeDetector, AsyncMock]) -> None: + """Test the analyze_generation method with basic request.""" + detector: LLMJudgeDetector + mock_judge: AsyncMock + detector, mock_judge = detector_with_mock_judge + + request = GenerationAnalysisHttpRequest( + prompt="What is machine learning?", + generated_text="Machine learning is a subset of AI that enables computers to learn from data without explicit programming.", + detector_params={"metric": "accuracy"} + ) + + result = asyncio.run(detector.analyze_generation(request)) + + # Verify judge.evaluate was called correctly + mock_judge.evaluate.assert_called_once_with( + input="What is machine learning?", + content="Machine learning is a subset of AI that enables computers to learn from data without explicit programming.", + metric="accuracy" + ) + + # Verify response format + assert isinstance(result, GenerationAnalysisResponse) + assert result.detection == "HELPFUL" + assert result.score == 0.85 + assert result.detection_type == "llm_judge" + assert "reasoning" in result.metadata + assert result.metadata["reasoning"] is not None + + def test_analyze_generation_complex_request(self, detector_with_mock_judge: Tuple[LLMJudgeDetector, AsyncMock]) -> None: + """Test the analyze_generation method with complex parameters.""" + detector: LLMJudgeDetector + mock_judge: AsyncMock + detector, mock_judge = detector_with_mock_judge + + request = GenerationAnalysisHttpRequest( + prompt="Explain the benefits and risks of artificial intelligence", + generated_text="AI offers significant benefits like improved efficiency and automation, but also poses risks such as job displacement and potential bias in decision-making systems.", + detector_params={ + "criteria": "balance, accuracy, and completeness", + "rubric": { + 1.0: "Excellent balance of benefits and risks with high accuracy", + 0.8: "Good coverage with minor gaps", + 0.6: "Adequate but missing some key points", + 0.4: "Poor coverage or significant inaccuracies", + 0.0: "Completely inadequate or misleading" + }, + "scale": [0, 1], + "context": "This is for an educational discussion about AI ethics" + } + ) + + result = asyncio.run(detector.analyze_generation(request)) + + # Verify complex parameters were passed correctly + expected_call_params = { + "input": request.prompt, + "content": request.generated_text, + **request.detector_params + } + mock_judge.evaluate.assert_called_once_with(**expected_call_params) + + # Verify response + assert isinstance(result, GenerationAnalysisResponse) + assert result.detection_type == "llm_judge" + + def test_analyze_generation_empty_params(self, 
detector_with_mock_judge: Tuple[LLMJudgeDetector, AsyncMock]) -> None: + """Test analyze_generation with empty detector params (should default to safety).""" + detector: LLMJudgeDetector + mock_judge: AsyncMock + detector, mock_judge = detector_with_mock_judge + + request = GenerationAnalysisHttpRequest( + prompt="Hello, how are you?", + generated_text="I'm doing well, thank you for asking! How can I assist you today?", + detector_params={} + ) + + result = asyncio.run(detector.analyze_generation(request)) + + # Should default to safety metric + expected_params = { + "input": request.prompt, + "content": request.generated_text, + "metric": "safety" + } + mock_judge.evaluate.assert_called_once_with(**expected_params) + + assert isinstance(result, GenerationAnalysisResponse) + assert result.detection_type == "llm_judge" + + def test_generation_analysis_with_numeric_score(self, detector_with_mock_judge: Tuple[LLMJudgeDetector, AsyncMock]) -> None: + """Test generation analysis handles numeric scores correctly.""" + detector: LLMJudgeDetector + mock_judge: AsyncMock + detector, mock_judge = detector_with_mock_judge + + # Mock a numeric decision result + numeric_result = EvaluationResult( + decision=8.5, + reasoning="High quality response with good accuracy", + score=8.5, + metadata={"model": "test-model"} + ) + mock_judge.evaluate.return_value = numeric_result + + request = GenerationAnalysisHttpRequest( + prompt="Explain photosynthesis", + generated_text="Photosynthesis is the process by which plants convert light energy into chemical energy.", + detector_params={"metric": "accuracy", "scale": [0, 10]} + ) + + result = asyncio.run(detector.analyze_generation(request)) + + assert isinstance(result, GenerationAnalysisResponse) + assert result.detection == "8.5" + assert result.score == 8.5 + assert result.detection_type == "llm_judge" + + def test_generation_analysis_with_none_score(self, detector_with_mock_judge: Tuple[LLMJudgeDetector, AsyncMock]) -> None: + """Test generation analysis handles None scores correctly.""" + detector: LLMJudgeDetector + mock_judge: AsyncMock + detector, mock_judge = detector_with_mock_judge + + # Mock a result with None score + none_score_result = EvaluationResult( + decision="GOOD", + reasoning="Good quality response", + score=None, + metadata={"model": "test-model"} + ) + mock_judge.evaluate.return_value = none_score_result + + request = GenerationAnalysisHttpRequest( + prompt="Test prompt", + generated_text="Test generation", + detector_params={"metric": "helpfulness"} + ) + + result = asyncio.run(detector.analyze_generation(request)) + + assert isinstance(result, GenerationAnalysisResponse) + assert result.detection == "GOOD" + assert result.score == 0.0 # Should default to 0.0 when score is None + assert result.detection_type == "llm_judge" \ No newline at end of file diff --git a/tests/detectors/llm_judge/test_performance.py b/tests/detectors/llm_judge/test_performance.py index d8d6202..cbb4f7f 100644 --- a/tests/detectors/llm_judge/test_performance.py +++ b/tests/detectors/llm_judge/test_performance.py @@ -5,15 +5,15 @@ from unittest.mock import patch, AsyncMock from detectors.llm_judge.detector import LLMJudgeDetector -from detectors.llm_judge.scheme import ContentAnalysisHttpRequest +from detectors.common.scheme import ContentAnalysisHttpRequest, GenerationAnalysisResponse from vllm_judge import EvaluationResult class TestPerformance: """Performance and concurrency tests.""" - def test_concurrent_evaluations(self) -> None: - """Test handling multiple 
concurrent evaluations.""" + def test_concurrent_evaluations_content(self) -> None: + """Test handling multiple concurrent evaluations for content analysis.""" with patch.dict(os.environ, {"VLLM_BASE_URL": "http://test:8000"}): with patch('vllm_judge.Judge.from_url') as mock_judge_class: # Mock judge with slight delay to simulate real processing @@ -59,8 +59,8 @@ async def run_concurrent_evaluations(): for i, result in enumerate(results): assert result.text == f"Test content {i}" - def test_batch_processing_performance(self) -> None: - """Test performance of batch processing.""" + def test_batch_processing_performance_content(self) -> None: + """Test performance of batch processing for content analysis.""" with patch.dict(os.environ, {"VLLM_BASE_URL": "http://test:8000"}): with patch('vllm_judge.Judge.from_url') as mock_judge_class: mock_judge = AsyncMock() @@ -80,8 +80,57 @@ def test_batch_processing_performance(self) -> None: ) start_time = time.time() - result = asyncio.run(detector.run(request)) + result = asyncio.run(detector.analyze_content(request)) end_time = time.time() assert len(result) == 100 - assert end_time - start_time < 0.5 # Should be reasonably fast \ No newline at end of file + assert end_time - start_time < 0.5 # Should be reasonably fast + + def test_concurrent_evaluations_generation(self) -> None: + """Test handling multiple concurrent evaluations for generation analysis.""" + with patch.dict(os.environ, {"VLLM_BASE_URL": "http://test:8000"}): + with patch('vllm_judge.Judge.from_url') as mock_judge_class: + # Mock judge with slight delay to simulate real processing + mock_judge = AsyncMock() + async def mock_evaluate(**kwargs): + await asyncio.sleep(0.1) # Simulate processing time + return EvaluationResult( + decision="SAFE", + reasoning="Safe content", + score=0.9, + metadata={} + ) + mock_judge.evaluate = mock_evaluate + mock_judge.config.model = "test-model" + mock_judge.config.base_url = "http://test:8000" + mock_judge_class.return_value = mock_judge + + detector = LLMJudgeDetector() + + # Test concurrent processing + contents = [f"Test content {i}" for i in range(10)] + async def run_concurrent_evaluations(): + start_time = time.time() + + tasks = [] + for content in contents: + task = detector.evaluate_single_generation( + prompt=content, + generated_text=content, + params={"metric": "safety"} + ) + tasks.append(task) + + results = await asyncio.gather(*tasks) + end_time = time.time() + + return results, end_time - start_time + + results, duration = asyncio.run(run_concurrent_evaluations()) + + # Should complete in roughly 0.1 seconds (concurrent) rather than 1 second (sequential) + assert duration < 0.5 + assert len(results) == 10 + + for i, result in enumerate(results): + assert isinstance(result, GenerationAnalysisResponse) \ No newline at end of file
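
Example usage of the new generation-analysis endpoint added in detectors/llm_judge/app.py — a minimal client sketch under stated assumptions: a local deployment on port 8000 (the port exposed in Dockerfile.judge) and the illustrative detector-id value "llm_judge_safety" are not taken from this change; the request body fields and example values follow GenerationAnalysisHttpRequest in detectors/common/scheme.py.

    # Minimal client sketch. Assumptions: the detector service from this change is
    # running locally on the port exposed in Dockerfile.judge, and "llm_judge_safety"
    # is an illustrative detector-id value.
    import requests

    BASE_URL = "http://localhost:8000"  # assumed local deployment

    # Evaluate a prompt/response pair; field names and example values mirror
    # GenerationAnalysisHttpRequest in detectors/common/scheme.py.
    resp = requests.post(
        f"{BASE_URL}/api/v1/text/generation",
        headers={"detector-id": "llm_judge_safety"},
        json={
            "prompt": "What do you think about the future of AI?",
            "generated_text": "The future of AI is bright but we need to be careful about the risks.",
            "detector_params": {"metric": "safety"},
        },
        timeout=60,
    )
    resp.raise_for_status()
    result = resp.json()
    # GenerationAnalysisResponse fields: detection, detection_type, score, evidences, metadata
    print(result["detection"], result["score"], result["metadata"].get("reasoning"))

    # The metric names accepted in detector_params can be listed via the metrics endpoint
    # (response shape per MetricsListResponse: "metrics" and "total").
    metrics = requests.get(f"{BASE_URL}/api/v1/metrics", timeout=60).json()
    print(metrics["total"], "metrics available, e.g.:", metrics["metrics"][:5])

If detector_params is omitted or empty, the detector falls back to the built-in "safety" metric, as implemented in _validate_params in detectors/llm_judge/detector.py.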