diff --git a/config/config-mcp-classifier-example.yaml b/config/config-mcp-classifier-example.yaml index 1aaca432..22468df6 100644 --- a/config/config-mcp-classifier-example.yaml +++ b/config/config-mcp-classifier-example.yaml @@ -45,7 +45,19 @@ classifier: # # How it works: # 1. Router connects to MCP server at startup -# 2. Calls 'list_categories' tool: MCP returns {"categories": ["business", "law", ...]} +# 2. Calls 'list_categories' tool and MCP returns: +# { +# "categories": ["math", "science", "technology", "history", "general"], +# "category_system_prompts": { +# "math": "You are a mathematics expert. When answering math questions...", +# "science": "You are a science expert. When answering science questions...", +# "technology": "You are a technology expert..." +# }, +# "category_descriptions": { +# "math": "Mathematical and computational queries", +# "science": "Scientific concepts and queries" +# } +# } # 3. For each request, calls 'classify_text' tool which returns: # { # "class": 3, @@ -55,14 +67,28 @@ classifier: # } # 4. Router uses the model and reasoning settings from MCP response # +# PER-CATEGORY SYSTEM PROMPT INJECTION: +# - The MCP server provides SEPARATE system prompts for EACH category +# - Each category gets its own specialized instructions and context +# - The router stores these prompts and injects the appropriate one per query +# - Use classifier.GetCategorySystemPrompt(categoryName) to retrieve for a specific category +# - Examples: +# * Math category: "You are a mathematics expert. Show step-by-step solutions..." +# * Science category: "You are a science expert. Provide evidence-based answers..." +# * Technology category: "You are a tech expert. Include practical code examples..." +# - This allows domain-specific expertise per category +# # BENEFITS: # - MCP server makes intelligent routing decisions per query # - No hardcoded routing rules needed in config # - MCP can adapt routing based on query complexity, content, etc. -# - Centralized routing logic in MCP server +# - Centralized routing logic and per-category system prompts in MCP server +# - Category descriptions available for logging and debugging +# - Domain-specific LLM behavior for each category # # FALLBACK: # - If MCP doesn't return model/use_reasoning, uses default_model below +# - If MCP doesn't return category_system_prompts, router can use default prompts # - Can also add category-specific overrides here if needed # categories: [] diff --git a/examples/mcp-classifier-server/README.md b/examples/mcp-classifier-server/README.md index f8aaf7f1..cc639073 100644 --- a/examples/mcp-classifier-server/README.md +++ b/examples/mcp-classifier-server/README.md @@ -5,6 +5,7 @@ Example MCP server that provides text classification with intelligent routing fo ## Features - **Dynamic Categories**: Loaded from MCP server at runtime via `list_categories` +- **Per-Category System Prompts**: Each category has its own specialized system prompt for LLM context - **Intelligent Routing**: Returns `model` and `use_reasoning` in classification response - **Regex-Based**: Simple pattern matching (replace with ML models for production) - **Dual Transport**: Supports both HTTP and stdio @@ -81,8 +82,23 @@ github.com/vllm-project/semantic-router/src/semantic-router/pkg/connectivity/mcp 1. **`list_categories`** - Returns `ListCategoriesResponse`: ```json - {"categories": ["math", "science", "technology", ...]} + { + "categories": ["math", "science", "technology", "history", "general"], + "category_system_prompts": { + "math": "You are a mathematics expert. When answering math questions...", + "science": "You are a science expert. When answering science questions...", + "technology": "You are a technology expert. When answering tech questions..." + }, + "category_descriptions": { + "math": "Mathematical and computational queries", + "science": "Scientific concepts and queries" + } + } ``` + + The `category_system_prompts` and `category_descriptions` fields are optional but recommended. + Per-category system prompts allow the MCP server to provide specialized instructions for each + category that the router can inject when processing queries in that specific category. 2. **`classify_text`** - Returns `ClassifyResponse`: @@ -109,18 +125,24 @@ See the `api` package for full type definitions and documentation. ## Customization -Edit `CATEGORIES` to add categories: +**Edit `CATEGORIES` to add categories with per-category system prompts:** ```python CATEGORIES = { "your_category": { "patterns": [r"\b(keyword1|keyword2)\b"], - "description": "Your description" + "description": "Your description", + "system_prompt": """You are an expert in your_category. When answering: +- Provide specific guidance +- Use domain-specific terminology +- Follow best practices for this domain""" } } ``` -Edit `decide_routing()` for custom routing logic: +Each category can have its own specialized system prompt tailored to that domain. + +**Edit `decide_routing()` for custom routing logic:** ```python def decide_routing(text, category, confidence): @@ -129,6 +151,19 @@ def decide_routing(text, category, confidence): return "openai/gpt-oss-20b", True ``` +**Using Per-Category System Prompts in the Router:** + +The router stores per-category system prompts when loading categories. To use them: + +```go +// After classifying a query, get the category-specific system prompt +category := "math" // from classification result +if systemPrompt, ok := classifier.GetCategorySystemPrompt(category); ok { + // Inject the category-specific system prompt when making LLM requests + // Each category gets its own specialized instructions +} +``` + ## License MIT diff --git a/examples/mcp-classifier-server/server.py b/examples/mcp-classifier-server/server.py index b4bbc4fa..a649ff18 100755 --- a/examples/mcp-classifier-server/server.py +++ b/examples/mcp-classifier-server/server.py @@ -8,11 +8,22 @@ 3. Intelligent routing decisions (model selection and reasoning control) The server implements two MCP tools: -- 'list_categories': Returns available categories for dynamic loading +- 'list_categories': Returns available categories with per-category system prompts and descriptions - 'classify_text': Classifies text and returns routing recommendations Protocol: -- list_categories returns: {"categories": ["math", "science", "technology", ...]} +- list_categories returns: { + "categories": ["math", "science", "technology", ...], + "category_system_prompts": { # optional, per-category system prompts + "math": "You are a mathematics expert. When answering math questions...", + "science": "You are a science expert. When answering science questions...", + "technology": "You are a technology expert. When answering tech questions..." + }, + "category_descriptions": { # optional + "math": "Mathematical and computational queries", + "science": "Scientific concepts and queries" + } + } - classify_text returns: { "class": 0, "confidence": 0.85, @@ -46,7 +57,8 @@ ) logger = logging.getLogger(__name__) -# Define classification categories and their regex patterns +# Define classification categories with their regex patterns, descriptions, and system prompts +# Each category has its own system prompt for specialized context CATEGORIES = { "math": { "patterns": [ @@ -56,6 +68,12 @@ r"\b(sin|cos|tan|log|sqrt|sum|average|mean)\b", ], "description": "Mathematical and computational queries", + "system_prompt": """You are a mathematics expert. When answering math questions: +- Show step-by-step solutions with clear explanations +- Use proper mathematical notation and terminology +- Verify calculations and provide intermediate steps +- Explain the underlying concepts and principles +- Offer alternative approaches when applicable""", }, "science": { "patterns": [ @@ -65,6 +83,12 @@ r"\b(planet|star|galaxy|universe|ecosystem|organism)\b", ], "description": "Scientific concepts and queries", + "system_prompt": """You are a science expert. When answering science questions: +- Provide evidence-based answers grounded in scientific research +- Explain relevant scientific concepts and principles +- Use appropriate scientific terminology +- Cite the scientific method and experimental evidence when relevant +- Distinguish between established facts and current theories""", }, "technology": { "patterns": [ @@ -74,6 +98,12 @@ r"\b(python|java|javascript|C\+\+|golang|rust)\b", ], "description": "Technology and computing topics", + "system_prompt": """You are a technology expert. When answering tech questions: +- Include practical examples and code snippets when relevant +- Follow best practices and industry standards +- Explain both high-level concepts and implementation details +- Consider security, performance, and maintainability +- Recommend appropriate tools and technologies for the use case""", }, "history": { "patterns": [ @@ -83,10 +113,22 @@ r"\b(BCE|CE|AD|BC|\d{4})\b.*\b(year|century|ago)\b", ], "description": "Historical events and topics", + "system_prompt": """You are a history expert. When answering historical questions: +- Provide accurate dates, names, and historical context +- Cite time periods and geographical locations +- Explain the causes, events, and consequences +- Consider multiple perspectives and historical interpretations +- Connect historical events to their broader significance""", }, "general": { "patterns": [r".*"], # Catch-all pattern "description": "General questions and topics", + "system_prompt": """You are a knowledgeable assistant. When answering general questions: +- Provide balanced, well-rounded responses +- Draw from multiple domains of knowledge when relevant +- Be clear, concise, and accurate +- Adapt your explanation to the complexity of the question +- Acknowledge limitations and uncertainties when appropriate""", }, } @@ -300,8 +342,9 @@ async def list_tools() -> list[Tool]: Tool( name="list_categories", description=( - "List all available classification categories. " - "Returns a simple array of category names that the router will use for dynamic category loading." + "List all available classification categories with per-category system prompts and descriptions. " + "Returns: categories (array), category_system_prompts (object), category_descriptions (object). " + "Each category can have its own system prompt that the router injects for category-specific LLM context." ), inputSchema={"type": "object", "properties": {}}, ), @@ -328,9 +371,27 @@ async def call_tool(name: str, arguments: Any) -> list[TextContent]: return [TextContent(type="text", text=json.dumps({"error": str(e)}))] elif name == "list_categories": - # Return simple list of category names as expected by semantic router - categories_response = {"categories": CATEGORY_NAMES} - logger.info(f"Returning {len(CATEGORY_NAMES)} categories: {CATEGORY_NAMES}") + # Return category information including per-category system prompts and descriptions + # This allows the router to get category-specific instructions from the MCP server + category_descriptions = { + name: CATEGORIES[name]["description"] for name in CATEGORY_NAMES + } + + category_system_prompts = { + name: CATEGORIES[name]["system_prompt"] + for name in CATEGORY_NAMES + if "system_prompt" in CATEGORIES[name] + } + + categories_response = { + "categories": CATEGORY_NAMES, + "category_system_prompts": category_system_prompts, + "category_descriptions": category_descriptions, + } + + logger.info( + f"Returning {len(CATEGORY_NAMES)} categories with {len(category_system_prompts)} system prompts: {CATEGORY_NAMES}" + ) return [TextContent(type="text", text=json.dumps(categories_response))] else: diff --git a/src/semantic-router/pkg/connectivity/mcp/api/types.go b/src/semantic-router/pkg/connectivity/mcp/api/types.go index 94b468ab..27506f93 100644 --- a/src/semantic-router/pkg/connectivity/mcp/api/types.go +++ b/src/semantic-router/pkg/connectivity/mcp/api/types.go @@ -88,7 +88,17 @@ type ClassifyWithProbabilitiesResponse struct { // Example JSON: // // { -// "categories": ["business", "law", "medical", "technical", "general"] +// "categories": ["business", "law", "medical", "technical", "general"], +// "category_system_prompts": { +// "business": "You are a business and finance expert. Provide detailed financial analysis...", +// "law": "You are a legal expert. Provide accurate legal information and cite relevant laws...", +// "medical": "You are a medical professional. Provide evidence-based health information..." +// }, +// "category_descriptions": { +// "business": "Business and finance related queries", +// "law": "Legal questions and regulations", +// "medical": "Healthcare and medical information" +// } // } type ListCategoriesResponse struct { // Categories is the ordered list of category names. @@ -98,4 +108,14 @@ type ListCategoriesResponse struct { // - class 1 = "law" // - class 2 = "medical" Categories []string `json:"categories"` + + // CategorySystemPrompts provides optional per-category system prompts that the router + // can inject when processing queries in specific categories. This allows the MCP server + // to provide category-specific instructions that guide the LLM's behavior. + // The map key is the category name, and the value is the system prompt for that category. + CategorySystemPrompts map[string]string `json:"category_system_prompts,omitempty"` + + // CategoryDescriptions provides optional human-readable descriptions for each category. + // This can be used for logging, debugging, or providing context to downstream systems. + CategoryDescriptions map[string]string `json:"category_descriptions,omitempty"` } diff --git a/src/semantic-router/pkg/utils/classification/classifier.go b/src/semantic-router/pkg/utils/classification/classifier.go index 82660e15..ac5e5c0e 100644 --- a/src/semantic-router/pkg/utils/classification/classifier.go +++ b/src/semantic-router/pkg/utils/classification/classifier.go @@ -858,6 +858,27 @@ func (c *Classifier) GetCategoryByName(categoryName string) *config.Category { return c.findCategory(categoryName) } +// GetCategorySystemPrompt returns the system prompt for a specific category if available. +// This is useful when the MCP server provides category-specific system prompts that should +// be injected when processing queries in that category. +// Returns empty string and false if no system prompt is available for the category. +func (c *Classifier) GetCategorySystemPrompt(category string) (string, bool) { + if c.CategoryMapping == nil { + return "", false + } + return c.CategoryMapping.GetCategorySystemPrompt(category) +} + +// GetCategoryDescription returns the description for a given category if available. +// This is useful for logging, debugging, or providing context to downstream systems. +// Returns empty string and false if the category has no description. +func (c *Classifier) GetCategoryDescription(category string) (string, bool) { + if c.CategoryMapping == nil { + return "", false + } + return c.CategoryMapping.GetCategoryDescription(category) +} + // buildCategoryNameMappings builds translation maps between MMLU-Pro and generic categories func (c *Classifier) buildCategoryNameMappings() { c.MMLUToGeneric = make(map[string]string) diff --git a/src/semantic-router/pkg/utils/classification/mapping.go b/src/semantic-router/pkg/utils/classification/mapping.go index 0c45e0ec..aab7ce05 100644 --- a/src/semantic-router/pkg/utils/classification/mapping.go +++ b/src/semantic-router/pkg/utils/classification/mapping.go @@ -8,8 +8,10 @@ import ( // CategoryMapping holds the mapping between indices and domain categories type CategoryMapping struct { - CategoryToIdx map[string]int `json:"category_to_idx"` - IdxToCategory map[string]string `json:"idx_to_category"` + CategoryToIdx map[string]int `json:"category_to_idx"` + IdxToCategory map[string]string `json:"idx_to_category"` + CategorySystemPrompts map[string]string `json:"category_system_prompts,omitempty"` // Optional per-category system prompts from MCP server + CategoryDescriptions map[string]string `json:"category_descriptions,omitempty"` // Optional category descriptions } // PIIMapping holds the mapping between indices and PII types @@ -98,6 +100,24 @@ func (cm *CategoryMapping) GetCategoryCount() int { return len(cm.CategoryToIdx) } +// GetCategorySystemPrompt returns the system prompt for a specific category if available +func (cm *CategoryMapping) GetCategorySystemPrompt(category string) (string, bool) { + if cm.CategorySystemPrompts == nil { + return "", false + } + prompt, ok := cm.CategorySystemPrompts[category] + return prompt, ok +} + +// GetCategoryDescription returns the description for a given category +func (cm *CategoryMapping) GetCategoryDescription(category string) (string, bool) { + if cm.CategoryDescriptions == nil { + return "", false + } + desc, ok := cm.CategoryDescriptions[category] + return desc, ok +} + // GetPIITypeCount returns the number of PII types in the mapping func (pm *PIIMapping) GetPIITypeCount() int { return len(pm.LabelToIdx) diff --git a/src/semantic-router/pkg/utils/classification/mcp_classifier.go b/src/semantic-router/pkg/utils/classification/mcp_classifier.go index 18ebda33..9c591dde 100644 --- a/src/semantic-router/pkg/utils/classification/mcp_classifier.go +++ b/src/semantic-router/pkg/utils/classification/mcp_classifier.go @@ -319,8 +319,10 @@ func (m *MCPCategoryClassifier) ListCategories(ctx context.Context) (*CategoryMa // Build CategoryMapping from the list mapping := &CategoryMapping{ - CategoryToIdx: make(map[string]int), - IdxToCategory: make(map[string]string), + CategoryToIdx: make(map[string]int), + IdxToCategory: make(map[string]string), + CategorySystemPrompts: response.CategorySystemPrompts, + CategoryDescriptions: response.CategoryDescriptions, } for idx, category := range response.Categories { @@ -328,7 +330,13 @@ func (m *MCPCategoryClassifier) ListCategories(ctx context.Context) (*CategoryMa mapping.IdxToCategory[fmt.Sprintf("%d", idx)] = category } - observability.Infof("Loaded %d categories from MCP server: %v", len(response.Categories), response.Categories) + if len(response.CategorySystemPrompts) > 0 { + observability.Infof("Loaded %d categories with %d system prompts from MCP server: %v", + len(response.Categories), len(response.CategorySystemPrompts), response.Categories) + } else { + observability.Infof("Loaded %d categories from MCP server: %v", len(response.Categories), response.Categories) + } + return mapping, nil } diff --git a/src/semantic-router/pkg/utils/classification/mcp_classifier_test.go b/src/semantic-router/pkg/utils/classification/mcp_classifier_test.go index 9a693784..3bad1b9b 100644 --- a/src/semantic-router/pkg/utils/classification/mcp_classifier_test.go +++ b/src/semantic-router/pkg/utils/classification/mcp_classifier_test.go @@ -487,6 +487,116 @@ var _ = Describe("MCP Category Classifier", func() { }) }) + Context("when MCP tool returns categories with per-category system prompts", func() { + It("should store system prompts in mapping", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: `{ + "categories": ["math", "science", "technology"], + "category_system_prompts": { + "math": "You are a mathematics expert. Show step-by-step solutions.", + "science": "You are a science expert. Provide evidence-based answers.", + "technology": "You are a technology expert. Include practical examples." + }, + "category_descriptions": { + "math": "Mathematical and computational queries", + "science": "Scientific concepts and queries", + "technology": "Technology and computing topics" + } + }`, + }, + }, + } + mapping, err := mcpClassifier.ListCategories(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(mapping).ToNot(BeNil()) + Expect(mapping.CategoryToIdx).To(HaveLen(3)) + + // Verify system prompts are stored + Expect(mapping.CategorySystemPrompts).ToNot(BeNil()) + Expect(mapping.CategorySystemPrompts).To(HaveLen(3)) + + mathPrompt, ok := mapping.GetCategorySystemPrompt("math") + Expect(ok).To(BeTrue()) + Expect(mathPrompt).To(ContainSubstring("mathematics expert")) + + sciencePrompt, ok := mapping.GetCategorySystemPrompt("science") + Expect(ok).To(BeTrue()) + Expect(sciencePrompt).To(ContainSubstring("science expert")) + + techPrompt, ok := mapping.GetCategorySystemPrompt("technology") + Expect(ok).To(BeTrue()) + Expect(techPrompt).To(ContainSubstring("technology expert")) + + // Verify descriptions are stored + Expect(mapping.CategoryDescriptions).ToNot(BeNil()) + Expect(mapping.CategoryDescriptions).To(HaveLen(3)) + + mathDesc, ok := mapping.GetCategoryDescription("math") + Expect(ok).To(BeTrue()) + Expect(mathDesc).To(Equal("Mathematical and computational queries")) + }) + }) + + Context("when MCP tool returns categories without system prompts", func() { + It("should handle missing system prompts gracefully", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: `{"categories": ["math", "science"]}`, + }, + }, + } + mapping, err := mcpClassifier.ListCategories(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(mapping).ToNot(BeNil()) + Expect(mapping.CategoryToIdx).To(HaveLen(2)) + + // System prompts should be nil or empty + mathPrompt, ok := mapping.GetCategorySystemPrompt("math") + Expect(ok).To(BeFalse()) + Expect(mathPrompt).To(Equal("")) + }) + }) + + Context("when MCP tool returns partial system prompts", func() { + It("should store only provided system prompts", func() { + mockClient.callToolResult = &mcp.CallToolResult{ + IsError: false, + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: `{ + "categories": ["math", "science", "history"], + "category_system_prompts": { + "math": "You are a mathematics expert.", + "science": "You are a science expert." + } + }`, + }, + }, + } + mapping, err := mcpClassifier.ListCategories(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(mapping).ToNot(BeNil()) + Expect(mapping.CategoryToIdx).To(HaveLen(3)) + Expect(mapping.CategorySystemPrompts).To(HaveLen(2)) + + mathPrompt, ok := mapping.GetCategorySystemPrompt("math") + Expect(ok).To(BeTrue()) + Expect(mathPrompt).To(ContainSubstring("mathematics expert")) + + historyPrompt, ok := mapping.GetCategorySystemPrompt("history") + Expect(ok).To(BeFalse()) + Expect(historyPrompt).To(Equal("")) + }) + }) + Context("when MCP tool returns error", func() { It("should return error", func() { mockClient.callToolResult = &mcp.CallToolResult{ @@ -535,6 +645,79 @@ var _ = Describe("MCP Category Classifier", func() { }) }) }) + + Describe("CategoryMapping System Prompt Methods", func() { + var mapping *CategoryMapping + + BeforeEach(func() { + mapping = &CategoryMapping{ + CategoryToIdx: map[string]int{"math": 0, "science": 1, "tech": 2}, + IdxToCategory: map[string]string{"0": "math", "1": "science", "2": "tech"}, + CategorySystemPrompts: map[string]string{ + "math": "You are a mathematics expert. Show step-by-step solutions.", + "science": "You are a science expert. Provide evidence-based answers.", + }, + CategoryDescriptions: map[string]string{ + "math": "Mathematical queries", + "science": "Scientific queries", + "tech": "Technology queries", + }, + } + }) + + Describe("GetCategorySystemPrompt", func() { + Context("when category has system prompt", func() { + It("should return the prompt", func() { + prompt, ok := mapping.GetCategorySystemPrompt("math") + Expect(ok).To(BeTrue()) + Expect(prompt).To(Equal("You are a mathematics expert. Show step-by-step solutions.")) + }) + }) + + Context("when category exists but has no system prompt", func() { + It("should return empty string and false", func() { + prompt, ok := mapping.GetCategorySystemPrompt("tech") + Expect(ok).To(BeFalse()) + Expect(prompt).To(Equal("")) + }) + }) + + Context("when category does not exist", func() { + It("should return empty string and false", func() { + prompt, ok := mapping.GetCategorySystemPrompt("nonexistent") + Expect(ok).To(BeFalse()) + Expect(prompt).To(Equal("")) + }) + }) + + Context("when CategorySystemPrompts is nil", func() { + It("should return empty string and false", func() { + mapping.CategorySystemPrompts = nil + prompt, ok := mapping.GetCategorySystemPrompt("math") + Expect(ok).To(BeFalse()) + Expect(prompt).To(Equal("")) + }) + }) + }) + + Describe("GetCategoryDescription", func() { + Context("when category has description", func() { + It("should return the description", func() { + desc, ok := mapping.GetCategoryDescription("math") + Expect(ok).To(BeTrue()) + Expect(desc).To(Equal("Mathematical queries")) + }) + }) + + Context("when category does not have description", func() { + It("should return empty string and false", func() { + desc, ok := mapping.GetCategoryDescription("nonexistent") + Expect(ok).To(BeFalse()) + Expect(desc).To(Equal("")) + }) + }) + }) + }) }) var _ = Describe("Classifier MCP Methods", func() { @@ -565,6 +748,16 @@ var _ = Describe("Classifier MCP Methods", func() { CategoryMapping: &CategoryMapping{ CategoryToIdx: map[string]int{"tech": 0, "sports": 1, "politics": 2}, IdxToCategory: map[string]string{"0": "tech", "1": "sports", "2": "politics"}, + CategorySystemPrompts: map[string]string{ + "tech": "You are a technology expert. Include practical examples.", + "sports": "You are a sports expert. Provide game analysis.", + "politics": "You are a politics expert. Provide balanced perspectives.", + }, + CategoryDescriptions: map[string]string{ + "tech": "Technology and computing topics", + "sports": "Sports and athletics", + "politics": "Political topics and governance", + }, }, } }) @@ -777,3 +970,108 @@ var _ = Describe("MCP Helper Functions", func() { }) }) }) + +var _ = Describe("Classifier Per-Category System Prompts", func() { + var classifier *Classifier + + BeforeEach(func() { + cfg := &config.RouterConfig{} + cfg.Classifier.MCPCategoryModel.Enabled = true + + classifier = &Classifier{ + Config: cfg, + CategoryMapping: &CategoryMapping{ + CategoryToIdx: map[string]int{"math": 0, "science": 1, "tech": 2}, + IdxToCategory: map[string]string{"0": "math", "1": "science", "2": "tech"}, + CategorySystemPrompts: map[string]string{ + "math": "You are a mathematics expert. Show step-by-step solutions with clear explanations.", + "science": "You are a science expert. Provide evidence-based answers grounded in research.", + "tech": "You are a technology expert. Include practical examples and code snippets.", + }, + CategoryDescriptions: map[string]string{ + "math": "Mathematical and computational queries", + "science": "Scientific concepts and queries", + "tech": "Technology and computing topics", + }, + }, + } + }) + + Describe("GetCategorySystemPrompt", func() { + Context("when category exists with system prompt", func() { + It("should return the category-specific system prompt", func() { + prompt, ok := classifier.GetCategorySystemPrompt("math") + Expect(ok).To(BeTrue()) + Expect(prompt).To(ContainSubstring("mathematics expert")) + Expect(prompt).To(ContainSubstring("step-by-step solutions")) + }) + }) + + Context("when requesting different categories", func() { + It("should return different system prompts for each category", func() { + mathPrompt, ok := classifier.GetCategorySystemPrompt("math") + Expect(ok).To(BeTrue()) + + sciencePrompt, ok := classifier.GetCategorySystemPrompt("science") + Expect(ok).To(BeTrue()) + + techPrompt, ok := classifier.GetCategorySystemPrompt("tech") + Expect(ok).To(BeTrue()) + + // Verify they are different + Expect(mathPrompt).ToNot(Equal(sciencePrompt)) + Expect(mathPrompt).ToNot(Equal(techPrompt)) + Expect(sciencePrompt).ToNot(Equal(techPrompt)) + + // Verify each has category-specific content + Expect(mathPrompt).To(ContainSubstring("mathematics")) + Expect(sciencePrompt).To(ContainSubstring("science")) + Expect(techPrompt).To(ContainSubstring("technology")) + }) + }) + + Context("when category does not exist", func() { + It("should return empty string and false", func() { + prompt, ok := classifier.GetCategorySystemPrompt("nonexistent") + Expect(ok).To(BeFalse()) + Expect(prompt).To(Equal("")) + }) + }) + + Context("when CategoryMapping is nil", func() { + It("should return empty string and false", func() { + classifier.CategoryMapping = nil + prompt, ok := classifier.GetCategorySystemPrompt("math") + Expect(ok).To(BeFalse()) + Expect(prompt).To(Equal("")) + }) + }) + }) + + Describe("GetCategoryDescription", func() { + Context("when category has description", func() { + It("should return the description", func() { + desc, ok := classifier.GetCategoryDescription("math") + Expect(ok).To(BeTrue()) + Expect(desc).To(Equal("Mathematical and computational queries")) + }) + }) + + Context("when category does not exist", func() { + It("should return empty string and false", func() { + desc, ok := classifier.GetCategoryDescription("nonexistent") + Expect(ok).To(BeFalse()) + Expect(desc).To(Equal("")) + }) + }) + + Context("when CategoryMapping is nil", func() { + It("should return empty string and false", func() { + classifier.CategoryMapping = nil + desc, ok := classifier.GetCategoryDescription("math") + Expect(ok).To(BeFalse()) + Expect(desc).To(Equal("")) + }) + }) + }) +})