1010import os
1111import sys
1212import time
13+ import unittest
1314from collections import defaultdict
1415
1516import requests
1617
1718# Add parent directory to path to allow importing common test utilities
1819sys .path .append (os .path .dirname (os .path .dirname (os .path .abspath (__file__ ))))
19- from tests . test_base import SemanticRouterTestBase
20+ from test_base import SemanticRouterTestBase
2021
2122# Constants
2223ENVOY_URL = "http://localhost:8801"
2324OPENAI_ENDPOINT = "/v1/chat/completions"
2425ROUTER_METRICS_URL = "http://localhost:9190/metrics"
25- DEFAULT_MODEL = "qwen2.5:32b " # Changed from gemma3:27b to match make test-prompt
26+ DEFAULT_MODEL = "Model-A " # Use configured model that matches router config
2627
2728# Category test cases - each designed to trigger a specific classifier category
29+ # Based on config.e2e.yaml: math→Model-B, computer science→Model-B, business→Model-A, history→Model-A
2830CATEGORY_TEST_CASES = [
2931 {
3032 "name" : "Math Query" ,
3133 "expected_category" : "math" ,
32- "content" : "Solve the differential equation dy/dx + 2y = x^2 with the initial condition y(0) = 1." ,
34+ "expected_model" : "Model-B" , # math has Model-B with score 1.0
35+ "content" : "Solve the quadratic equation x^2 + 5x + 6 = 0 and explain the steps." ,
3336 },
3437 {
35- "name" : "Creative Writing Query" ,
36- "expected_category" : "creative" ,
37- "content" : "Write a short story about a space cat." ,
38+ "name" : "Computer Science/Coding Query" ,
39+ "expected_category" : "computer science" ,
40+ "expected_model" : "Model-B" , # computer science has Model-B with score 0.6
41+ "content" : "Write a Python function to implement a linked list with insert and delete operations." ,
3842 },
39- ] # Reduced to just 2 test cases to avoid timeouts
43+ {
44+ "name" : "Business Query" ,
45+ "expected_category" : "business" ,
46+ "expected_model" : "Model-A" , # business has Model-A with score 0.8
47+ "content" : "What are the key principles of supply chain management in modern business?" ,
48+ },
49+ {
50+ "name" : "History Query" ,
51+ "expected_category" : "history" ,
52+ "expected_model" : "Model-A" , # history has Model-A with score 0.8
53+ "content" : "Describe the main causes and key events of World War I." ,
54+ },
55+ ]
4056
4157
4258class RouterClassificationTest (SemanticRouterTestBase ):
@@ -129,7 +145,7 @@ def test_classification_consistency(self):
129145 f"{ ENVOY_URL } { OPENAI_ENDPOINT } " ,
130146 headers = {"Content-Type" : "application/json" },
131147 json = payload ,
132- timeout = 10 ,
148+ timeout = 60 ,
133149 )
134150
135151 passed = response .status_code < 400
@@ -165,7 +181,7 @@ def test_category_classification(self):
165181 self .print_subtest_header (test_case ["name" ])
166182
167183 payload = {
168- "model" : DEFAULT_MODEL ,
184+ "model" : "auto" , # Use "auto" to trigger category-based classification routing
169185 "messages" : [
170186 {
171187 "role" : "assistant" ,
@@ -178,7 +194,7 @@ def test_category_classification(self):
178194
179195 self .print_request_info (
180196 payload = payload ,
181- expectations = f"Expect: Query to be classified as { test_case ['expected_category' ]} and routed accordingly " ,
197+ expectations = f"Expect: Query classified as ' { test_case ['expected_category' ]} ' → routed to { test_case . get ( 'expected_model' , 'appropriate model' ) } " ,
182198 )
183199
184200 response = requests .post (
@@ -188,25 +204,30 @@ def test_category_classification(self):
188204 timeout = 60 ,
189205 )
190206
191- passed = response .status_code < 400
192207 response_json = response .json ()
193- model = response_json .get ("model" , "unknown" )
194- results [test_case ["name" ]] = model
208+ actual_model = response_json .get ("model" , "unknown" )
209+ expected_model = test_case .get ("expected_model" , "unknown" )
210+ results [test_case ["name" ]] = actual_model
211+
212+ model_match = actual_model == expected_model
213+ passed = response .status_code < 400 and model_match
195214
196215 self .print_response_info (
197216 response ,
198217 {
199218 "Expected Category" : test_case ["expected_category" ],
200- "Selected Model" : model ,
219+ "Expected Model" : expected_model ,
220+ "Actual Model" : actual_model ,
221+ "Routing Correct" : "✅" if model_match else "❌" ,
201222 },
202223 )
203224
204225 self .print_test_result (
205226 passed = passed ,
206227 message = (
207- f"Query successfully routed to model: { model } "
208- if passed
209- else f"Request failed with status { response . status_code } "
228+ f"Query correctly routed to { actual_model } "
229+ if model_match
230+ else f"Routing failed: expected { expected_model } , got { actual_model } "
210231 ),
211232 )
212233
@@ -216,22 +237,29 @@ def test_category_classification(self):
216237 f"{ test_case ['name' ]} request failed with status { response .status_code } " ,
217238 )
218239
240+ self .assertEqual (
241+ actual_model ,
242+ expected_model ,
243+ f"{ test_case ['name' ]} : Expected routing to { expected_model } , but got { actual_model } " ,
244+ )
245+
219246 def test_classifier_metrics (self ):
220- """Test that classification metrics are being recorded."""
247+ """Test that router metrics are being recorded and exposed ."""
221248 self .print_test_header (
222- "Classifier Metrics Test" ,
223- "Verifies that classification metrics are being properly recorded and exposed" ,
249+ "Router Metrics Test" ,
250+ "Verifies that router metrics (classification, cache operations) are being properly recorded and exposed" ,
224251 )
225252
226253 # First, let's get the current metrics as a baseline
227254 response = requests .get (ROUTER_METRICS_URL )
228255 baseline_metrics = response .text
229256
230- # Check if classification metrics exist without making additional requests
257+ # Check if classification and routing metrics exist
258+ # These are the actual metrics exposed by the router
231259 classification_metrics = [
232- "llm_router_classification_duration_seconds" ,
233- "llm_router_requests_total" ,
234- "llm_router_model_selection_count" ,
260+ "llm_entropy_classification_latency_seconds" , # Entropy-based classification timing
261+ "llm_cache_hits_total" , # Cache operations (related to classification)
262+ "llm_cache_misses_total" , # Cache misses
235263 ]
236264
237265 metrics_found = 0
@@ -259,13 +287,17 @@ def test_classifier_metrics(self):
259287 self .print_test_result (
260288 passed = passed ,
261289 message = (
262- f"Found { metrics_found } classification metrics"
290+ f"Found { metrics_found } / { len ( classification_metrics ) } router metrics"
263291 if passed
264- else "No classification metrics found"
292+ else "No router metrics found"
265293 ),
266294 )
267295
268- self .assertGreaterEqual (metrics_found , 0 , "No classification metrics found" )
296+ self .assertGreater (
297+ metrics_found ,
298+ 0 ,
299+ f"No router metrics found. Expected at least one of: { ', ' .join (classification_metrics )} " ,
300+ )
269301
270302
271303if __name__ == "__main__" :
0 commit comments