Skip to content

Commit 0c81a7a

Browse files
committed
fix: resolve syntax errors after rebase
Signed-off-by: Huamin Chen <[email protected]>
1 parent 88f85ed commit 0c81a7a

File tree

2 files changed

+273
-0
lines changed

2 files changed

+273
-0
lines changed

src/semantic-router/cmd/main.go

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -111,6 +111,38 @@ func main() {
111111

112112
observability.Infof("Starting vLLM Semantic Router ExtProc with config: %s", *configPath)
113113

114+
<<<<<<< HEAD
115+
=======
116+
// Initialize embedding models if configured (Long-context support)
117+
cfg, err = config.LoadConfig(*configPath)
118+
if err != nil {
119+
observability.Warnf("Failed to load config for embedding models: %v", err)
120+
} else if cfg.EmbeddingModels.Qwen3ModelPath != "" || cfg.EmbeddingModels.GemmaModelPath != "" {
121+
observability.Infof("Initializing embedding models...")
122+
observability.Infof(" Qwen3 model: %s", cfg.EmbeddingModels.Qwen3ModelPath)
123+
observability.Infof(" Gemma model: %s", cfg.EmbeddingModels.GemmaModelPath)
124+
observability.Infof(" Use CPU: %v", cfg.EmbeddingModels.UseCPU)
125+
126+
if err := candle_binding.InitEmbeddingModels(
127+
cfg.EmbeddingModels.Qwen3ModelPath,
128+
cfg.EmbeddingModels.GemmaModelPath,
129+
cfg.EmbeddingModels.UseCPU,
130+
); err != nil {
131+
observability.Errorf("Failed to initialize embedding models: %v", err)
132+
observability.Warnf("Embedding API endpoints will return placeholder embeddings")
133+
} else {
134+
observability.Infof("Embedding models initialized successfully")
135+
}
136+
} else {
137+
observability.Infof("No embedding models configured, skipping initialization")
138+
observability.Infof("To enable embedding models, add to config.yaml:")
139+
observability.Infof(" embedding_models:")
140+
observability.Infof(" qwen3_model_path: 'models/Qwen3-Embedding-0.6B'")
141+
observability.Infof(" gemma_model_path: 'models/embeddinggemma-300m'")
142+
observability.Infof(" use_cpu: true")
143+
}
144+
145+
>>>>>>> f9802f0 (fix: resolve syntax errors after rebase)
114146
// Start API server if enabled
115147
if *enableAPI {
116148
go func() {

src/semantic-router/pkg/api/server.go

Lines changed: 241 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1239,4 +1239,245 @@ func (s *ClassificationAPIServer) handleUpdateSystemPrompts(w http.ResponseWrite
12391239
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
12401240
return
12411241
}
1242+
<<<<<<< HEAD
1243+
=======
1244+
}
1245+
1246+
// handleEmbeddings handles embedding generation requests
1247+
func (s *ClassificationAPIServer) handleEmbeddings(w http.ResponseWriter, r *http.Request) {
1248+
// Parse request
1249+
var req EmbeddingRequest
1250+
if err := s.parseJSONRequest(r, &req); err != nil {
1251+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", err.Error())
1252+
return
1253+
}
1254+
1255+
// Validate input
1256+
if len(req.Texts) == 0 {
1257+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", "texts array cannot be empty")
1258+
return
1259+
}
1260+
1261+
// Set defaults
1262+
if req.Model == "" {
1263+
req.Model = "auto"
1264+
}
1265+
if req.Dimension == 0 {
1266+
req.Dimension = 768 // Default to full dimension
1267+
}
1268+
if req.QualityPriority == 0 && req.LatencyPriority == 0 {
1269+
req.QualityPriority = 0.5
1270+
req.LatencyPriority = 0.5
1271+
}
1272+
1273+
// Validate dimension
1274+
validDimensions := map[int]bool{128: true, 256: true, 512: true, 768: true, 1024: true}
1275+
if !validDimensions[req.Dimension] {
1276+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_DIMENSION",
1277+
fmt.Sprintf("dimension must be one of: 128, 256, 512, 768, 1024 (got %d)", req.Dimension))
1278+
return
1279+
}
1280+
1281+
// Generate embeddings for each text
1282+
results := make([]EmbeddingResult, 0, len(req.Texts))
1283+
var totalProcessingTime int64
1284+
1285+
for _, text := range req.Texts {
1286+
var output *candle_binding.EmbeddingOutput
1287+
var err error
1288+
1289+
// Choose between manual model selection or automatic routing
1290+
if req.Model == "auto" || req.Model == "" {
1291+
// Automatic routing based on quality/latency priorities
1292+
output, err = candle_binding.GetEmbeddingWithMetadata(
1293+
text,
1294+
req.QualityPriority,
1295+
req.LatencyPriority,
1296+
req.Dimension,
1297+
)
1298+
} else {
1299+
// Manual model selection ("qwen3" or "gemma")
1300+
output, err = candle_binding.GetEmbeddingWithModelType(
1301+
text,
1302+
req.Model,
1303+
req.Dimension,
1304+
)
1305+
}
1306+
1307+
if err != nil {
1308+
s.writeErrorResponse(w, http.StatusInternalServerError, "EMBEDDING_GENERATION_FAILED",
1309+
fmt.Sprintf("failed to generate embedding: %v", err))
1310+
return
1311+
}
1312+
1313+
// Use metadata directly from Rust layer
1314+
processingTime := int64(output.ProcessingTimeMs)
1315+
1316+
results = append(results, EmbeddingResult{
1317+
Text: text,
1318+
Embedding: output.Embedding,
1319+
Dimension: len(output.Embedding),
1320+
ModelUsed: output.ModelType,
1321+
ProcessingTimeMs: processingTime,
1322+
})
1323+
1324+
totalProcessingTime += processingTime
1325+
}
1326+
1327+
// Calculate statistics
1328+
avgProcessingTime := float64(totalProcessingTime) / float64(len(req.Texts))
1329+
1330+
response := EmbeddingResponse{
1331+
Embeddings: results,
1332+
TotalCount: len(results),
1333+
TotalProcessingTimeMs: totalProcessingTime,
1334+
AvgProcessingTimeMs: avgProcessingTime,
1335+
}
1336+
1337+
observability.Infof("Generated %d embeddings in %dms (avg: %.2fms)",
1338+
len(results), totalProcessingTime, avgProcessingTime)
1339+
1340+
s.writeJSONResponse(w, http.StatusOK, response)
1341+
}
1342+
1343+
// handleSimilarity handles text similarity calculation requests
1344+
func (s *ClassificationAPIServer) handleSimilarity(w http.ResponseWriter, r *http.Request) {
1345+
// Parse request
1346+
var req SimilarityRequest
1347+
if err := s.parseJSONRequest(r, &req); err != nil {
1348+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", err.Error())
1349+
return
1350+
}
1351+
1352+
// Validate input
1353+
if req.Text1 == "" || req.Text2 == "" {
1354+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", "both text1 and text2 must be provided")
1355+
return
1356+
}
1357+
1358+
// Set defaults
1359+
if req.Model == "" {
1360+
req.Model = "auto"
1361+
}
1362+
if req.Dimension == 0 {
1363+
req.Dimension = 768 // Default to full dimension
1364+
}
1365+
if req.Model == "auto" && req.QualityPriority == 0 && req.LatencyPriority == 0 {
1366+
req.QualityPriority = 0.5
1367+
req.LatencyPriority = 0.5
1368+
}
1369+
1370+
// Validate dimension
1371+
validDimensions := map[int]bool{128: true, 256: true, 512: true, 768: true, 1024: true}
1372+
if !validDimensions[req.Dimension] {
1373+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_DIMENSION",
1374+
fmt.Sprintf("dimension must be one of: 128, 256, 512, 768, 1024 (got %d)", req.Dimension))
1375+
return
1376+
}
1377+
1378+
// Calculate similarity
1379+
result, err := candle_binding.CalculateEmbeddingSimilarity(
1380+
req.Text1,
1381+
req.Text2,
1382+
req.Model,
1383+
req.Dimension,
1384+
)
1385+
1386+
if err != nil {
1387+
s.writeErrorResponse(w, http.StatusInternalServerError, "SIMILARITY_CALCULATION_FAILED",
1388+
fmt.Sprintf("failed to calculate similarity: %v", err))
1389+
return
1390+
}
1391+
1392+
response := SimilarityResponse{
1393+
Similarity: result.Similarity,
1394+
ModelUsed: result.ModelType,
1395+
ProcessingTimeMs: result.ProcessingTimeMs,
1396+
}
1397+
1398+
observability.Infof("Calculated similarity: %.4f (model: %s, took: %.2fms)",
1399+
result.Similarity, result.ModelType, result.ProcessingTimeMs)
1400+
1401+
s.writeJSONResponse(w, http.StatusOK, response)
1402+
}
1403+
1404+
// handleBatchSimilarity handles batch similarity matching requests
1405+
func (s *ClassificationAPIServer) handleBatchSimilarity(w http.ResponseWriter, r *http.Request) {
1406+
// Parse request
1407+
var req BatchSimilarityRequest
1408+
if err := s.parseJSONRequest(r, &req); err != nil {
1409+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", err.Error())
1410+
return
1411+
}
1412+
1413+
// Validate input
1414+
if req.Query == "" {
1415+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", "query must be provided")
1416+
return
1417+
}
1418+
if len(req.Candidates) == 0 {
1419+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_INPUT", "candidates array cannot be empty")
1420+
return
1421+
}
1422+
1423+
// Set defaults
1424+
if req.Model == "" {
1425+
req.Model = "auto"
1426+
}
1427+
if req.Dimension == 0 {
1428+
req.Dimension = 768 // Default to full dimension
1429+
}
1430+
if req.TopK == 0 {
1431+
req.TopK = len(req.Candidates) // Default to all candidates
1432+
}
1433+
if req.Model == "auto" && req.QualityPriority == 0 && req.LatencyPriority == 0 {
1434+
req.QualityPriority = 0.5
1435+
req.LatencyPriority = 0.5
1436+
}
1437+
1438+
// Validate dimension
1439+
validDimensions := map[int]bool{128: true, 256: true, 512: true, 768: true, 1024: true}
1440+
if !validDimensions[req.Dimension] {
1441+
s.writeErrorResponse(w, http.StatusBadRequest, "INVALID_DIMENSION",
1442+
fmt.Sprintf("dimension must be one of: 128, 256, 512, 768, 1024 (got %d)", req.Dimension))
1443+
return
1444+
}
1445+
1446+
// Calculate batch similarity
1447+
result, err := candle_binding.CalculateSimilarityBatch(
1448+
req.Query,
1449+
req.Candidates,
1450+
req.TopK,
1451+
req.Model,
1452+
req.Dimension,
1453+
)
1454+
1455+
if err != nil {
1456+
s.writeErrorResponse(w, http.StatusInternalServerError, "BATCH_SIMILARITY_FAILED",
1457+
fmt.Sprintf("failed to calculate batch similarity: %v", err))
1458+
return
1459+
}
1460+
1461+
// Build response with matched text included
1462+
matches := make([]BatchSimilarityMatch, len(result.Matches))
1463+
for i, match := range result.Matches {
1464+
matches[i] = BatchSimilarityMatch{
1465+
Index: match.Index,
1466+
Similarity: match.Similarity,
1467+
Text: req.Candidates[match.Index],
1468+
}
1469+
}
1470+
1471+
response := BatchSimilarityResponse{
1472+
Matches: matches,
1473+
TotalCandidates: len(req.Candidates),
1474+
ModelUsed: result.ModelType,
1475+
ProcessingTimeMs: result.ProcessingTimeMs,
1476+
}
1477+
1478+
observability.Infof("Calculated batch similarity: query='%s', %d candidates, top-%d matches (model: %s, took: %.2fms)",
1479+
req.Query, len(req.Candidates), len(matches), result.ModelType, result.ProcessingTimeMs)
1480+
1481+
s.writeJSONResponse(w, http.StatusOK, response)
1482+
>>>>>>> f9802f0 (fix: resolve syntax errors after rebase)
12421483
}

0 commit comments

Comments
 (0)