diff --git a/package-lock.json b/package-lock.json index 28edbad..57fe12a 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11,7 +11,9 @@ "dependencies": { "@modelcontextprotocol/sdk": "^1.0.4", "dotenv": "^16.4.7", - "ollama": "^0.5.11" + "ollama": "^0.5.11", + "zod": "^3.24.1", + "zod-to-json-schema": "^3.24.1" }, "devDependencies": { "@types/node": "^20.0.0", @@ -839,6 +841,14 @@ "funding": { "url": "https://github.com/sponsors/colinhacks" } + }, + "node_modules/zod-to-json-schema": { + "version": "3.24.1", + "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.1.tgz", + "integrity": "sha512-3h08nf3Vw3Wl3PK+q3ow/lIil81IT2Oa7YpQyUUDsEWbXveMesdfK1xBd2RhCkynwZndAxixji/7SYJJowr62w==", + "peerDependencies": { + "zod": "^3.24.1" + } } } } diff --git a/package.json b/package.json index df4e231..f7768eb 100644 --- a/package.json +++ b/package.json @@ -28,6 +28,8 @@ "dependencies": { "@modelcontextprotocol/sdk": "^1.0.4", "dotenv": "^16.4.7", - "ollama": "^0.5.11" + "ollama": "^0.5.11", + "zod": "^3.24.1", + "zod-to-json-schema": "^3.24.1" } } diff --git a/src/index.ts b/src/index.ts index 240da4c..82e23f8 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,13 +1,11 @@ -import ollama from "ollama"; +import ollama, { Message } from "ollama"; import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; -import { - ReadResourceResultSchema, - ListResourcesResultSchema, - CallToolResultSchema, -} from "@modelcontextprotocol/sdk/types.js"; +import { CallToolResultSchema } from "@modelcontextprotocol/sdk/types.js"; import dotenv from "dotenv"; import { fileURLToPath } from "url"; +import { z } from "zod"; +import { zodToJsonSchema } from "zod-to-json-schema"; // Load environment variables from .env file dotenv.config(); @@ -21,50 +19,165 @@ if (!databaseUrl) { process.exit(1); } -interface DatabaseSchema { - column_name: string; - data_type: string; +const SYSTEM_PROMPT = ` +You have access to a PostgreSQL database. +Use your knowledge of SQL to present an SQL query that will answer the user's question. +The user will execute the query and share the results with you. +Use the query results to verify that the query is correct. +If there are no query results, your answer is not verfied. +If the query results actually answer the question, mark the answer verified, +and provide a good human-readable answer. + +You can use the results to refine your query, if your previous answer was insufficient. +Always include the SQL query in your response. + +If the user tells you that there was an MCP error, analyze the error and respond with a different query. +`; + +const USER_PROMPT = ` +I have a database with the following tables: + +[ + {"table_name": "action_item_status_history"}, + {"table_name": "application_key"}, + {"table_name": "board"}, + {"table_name": "alembic_version"}, + {"table_name": "archival_measurement"}, + {"table_name": "application_audit"}, + {"table_name": "combined_load"}, + {"table_name": "customer_operating_preferences_staging"}, + {"table_name": "application"}, + {"table_name": "bank_account"}, + {"table_name": "baseline_value"}, + {"table_name": "control_profile"}, + {"table_name": "archival_facility_measurement"}, + {"table_name": "account_manager"}, + {"table_name": "action_item"}, + {"table_name": "board_access_control"}, + {"table_name": "email_facility"}, + {"table_name": "email"}, + {"table_name": "device"}, + {"table_name": "dwolla_customer"}, + {"table_name": "facility_contact_association"}, + {"table_name": "facility_operating_preferences"}, + {"table_name": "facility_operating_preferences_account_default"}, + {"table_name": "facility_operating_preferences_account_default_staging"}, + {"table_name": "facility"}, + {"table_name": "facility_enablement"}, + {"table_name": "facility_operating_preferences_staging"}, + {"table_name": "program"}, + {"table_name": "hubspot_email_user"}, + {"table_name": "meter"}, + {"table_name": "firmware_update"}, + {"table_name": "foobar"}, + {"table_name": "identity_role"}, + {"table_name": "facility_transfers"}, + {"table_name": "organization"}, + {"table_name": "feature_access"}, + {"table_name": "generator"}, + {"table_name": "hubspot_email"}, + {"table_name": "program_facility_association"}, + {"table_name": "historical_program_facility_association"}, + {"table_name": "interval_staging"}, + {"table_name": "line_item_event_association"}, + {"table_name": "line_item"}, + {"table_name": "market_timezone_override"}, + {"table_name": "meter_configuration"}, + {"table_name": "miso_lmr_price_offer_selection"}, + {"table_name": "portfolio"}, + {"table_name": "opportunity"}, + {"table_name": "portfolio_facilities"}, + {"table_name": "portfolio_type"}, + {"table_name": "program_geography_association_temp"}, + {"table_name": "program_tmp_migration"}, + {"table_name": "request_job"}, + {"table_name": "meter_provider_configuration"}, + {"table_name": "meter_provider"}, + {"table_name": "payment_program_association"}, + {"table_name": "portfolio_applications"}, + {"table_name": "portfolio_metadata"}, + {"table_name": "program_zipcode_association"}, + {"table_name": "registration_dispatch_performance"}, + {"table_name": "registration_potential_value"}, + {"table_name": "permission"}, + {"table_name": "role_permissions"}, + {"table_name": "portfolio_users"}, + {"table_name": "settlement_payment"}, + {"table_name": "ses_email_user"}, + {"table_name": "settlement_payment_transition_reason"}, + {"table_name": "ses_email"}, + {"table_name": "user_alert_configuration"}, + {"table_name": "user_alert_notification"}, + {"table_name": "temp_portfolio"}, + {"table_name": "scheduled_event"}, + {"table_name": "settlement_baseline_value"}, + {"table_name": "user_audit_impl"}, + {"table_name": "user"}, + {"table_name": "settlement_facility_load"}, + {"table_name": "user_activation_audit"}, + {"table_name": "user_query"}, + {"table_name": "voltus_opportunity_product"}, + {"table_name": "vcrm_group_registration"}, + {"table_name": "vendor_payment"}, + {"table_name": "voltlet_configuration"}, + {"table_name": "event_facility_association"}, + {"table_name": "event_acknowledgment"}, + {"table_name": "action_item_attempt"}, + {"table_name": "utility_account"}, + {"table_name": "role"}, + {"table_name": "line_item_transition_log"}, + {"table_name": "openadr_settings"}, + {"table_name": "settlement_payment_transition_log"} +] + +I have a question: + +`; + +interface SqlModelResponse { + sqlQuery: string; + isVerified: boolean; + answerSummary: string; + getFormattedAnswer(): string; } -interface ColumnMetadata { - description: string; - examples: string[]; - foreignKey?: { - table: string; - column: string; - }; +class JsonSqlModelResponse implements SqlModelResponse { + sqlQuery: string; + isVerified: boolean; + answerSummary: string; + + constructor(sqlQuery: string, isVerified: boolean, answerSummary: string) { + this.sqlQuery = sqlQuery; + this.isVerified = isVerified; + this.answerSummary = answerSummary; + } + + getFormattedAnswer(): string { + let formattedAnswer = `${this.answerSummary}. + + This was obtained using the following query: + + \`\`\`sql + ${this.sqlQuery} + \`\`\` + `; + + if (this.isVerified) { + return formattedAnswer; + } else { + return `${formattedAnswer} (Answer is not verified.)`; + } + } } class OllamaMCPHost { private client: Client; private transport: StdioClientTransport; private modelName: string; - private schemaCache: Map = new Map(); - private columnMetadata: Map> = new Map(); private chatHistory: { role: string; content: string }[] = []; private readonly MAX_HISTORY_LENGTH = 20; private readonly MAX_RETRIES = 5; - private static readonly QUERY_GUIDELINES = ` -When analyzing questions: -1. First write a SQL query to get the necessary information. Identify which tables contain the relevant information by looking at: - - Table names and their purposes - - Column names and descriptions - - Foreign key relationships -2. Use the 'query' tool to execute the SQL query -3. If unsure about table contents, write a sample query first: - SELECT column_name, COUNT(*) FROM table_name GROUP BY column_name LIMIT 5; -4. For complex questions, break down into multiple queries: - - First query to validate data availability - - Second query to get detailed information -5. Always include appropriate JOIN conditions when combining tables -6. Use WHERE clauses to filter irrelevant data -7. Consider using ORDER BY for sorted results - -Important: Only use SELECT statements - no modifications allowed! - -When you are finished, analyze the results and provide a natural language response.`; - constructor(modelName?: string) { this.modelName = modelName || process.env.OLLAMA_MODEL || "qwen2.5-coder:7b-instruct"; @@ -78,124 +191,8 @@ When you are finished, analyze the results and provide a natural language respon ); } - private async detectTableRelationships(): Promise { - // Query the database to find foreign key relationships - const sql = ` - SELECT - tc.table_name as table_name, - kcu.column_name as column_name, - ccu.table_name AS foreign_table_name, - ccu.column_name AS foreign_column_name - FROM information_schema.table_constraints tc - JOIN information_schema.key_column_usage kcu - ON tc.constraint_name = kcu.constraint_name - JOIN information_schema.constraint_column_usage ccu - ON ccu.constraint_name = tc.constraint_name - WHERE constraint_type = 'FOREIGN KEY' - `; - - try { - const result = await this.executeQuery(sql); - const relationships = JSON.parse(result); - - // Create initial metadata for foreign keys - relationships.forEach((rel: any) => { - const tableMetadata = - this.columnMetadata.get(rel.table_name) || new Map(); - - tableMetadata.set(rel.column_name, { - description: `Foreign key referencing ${rel.foreign_table_name}.${rel.foreign_column_name}`, - examples: [], - foreignKey: { - table: rel.foreign_table_name, - column: rel.foreign_column_name, - }, - }); - - this.columnMetadata.set(rel.table_name, tableMetadata); - }); - } catch (error) { - console.error("Error detecting table relationships:", error); - } - } - - private buildSystemPrompt(includeErrorContext: string = ""): string { - let prompt = - "You are a data analyst assistant. You have access to a PostgreSQL database with these tables:\n\n"; - - // Add detailed schema information - for (const [tableName, schema] of this.schemaCache.entries()) { - prompt += `Table: ${tableName}\n`; - prompt += "Columns:\n"; - - for (const column of schema) { - const metadata = this.columnMetadata - .get(tableName) - ?.get(column.column_name); - prompt += `- ${column.column_name} (${column.data_type})`; - - if (metadata) { - prompt += `: ${metadata.description}`; - if (metadata.foreignKey) { - prompt += ` [References ${metadata.foreignKey.table}.${metadata.foreignKey.column}]`; - } - } - prompt += "\n"; - } - prompt += "\n"; - } - - // Add query guidelines - prompt += "\nQuery Guidelines:\n"; - prompt += OllamaMCPHost.QUERY_GUIDELINES; - - if (includeErrorContext) { - prompt += `\nPrevious Error Context: ${includeErrorContext}\n`; - prompt += - "Please revise your approach and try a different query strategy.\n"; - } - - return prompt; - } - async connect() { await this.client.connect(this.transport); - - // First detect relationships - await this.detectTableRelationships(); - - // Then load schemas - const resources = await this.client.request( - { method: "resources/list" }, - ListResourcesResultSchema - ); - - for (const resource of resources.resources) { - if (resource.uri.endsWith("/schema")) { - const schema = await this.client.request( - { - method: "resources/read", - params: { uri: resource.uri }, - }, - ReadResourceResultSchema - ); - - if (schema.contents[0]?.text) { - try { - const tableName = resource.uri.split("/").slice(-2)[0]; - this.schemaCache.set( - tableName, - JSON.parse(schema.contents[0].text as string) - ); - } catch (error) { - console.error( - `Failed to parse schema for resource ${resource.uri}:`, - error instanceof Error ? error.message : String(error) - ); - } - } - } - } } private async executeQuery(sql: string): Promise { @@ -223,71 +220,92 @@ When you are finished, analyze the results and provide a natural language respon } } + private async queryModelJson(messages: Message[]): Promise { + const loc = "queryModelJson(): "; + + const SqlJsonContent = z.object({ + sqlQuery: z.string(), + isVerified: z.boolean(), + answerSummary: z.string(), + }) + + // Get response from Ollama + const response = await ollama.chat({ + model: this.modelName, + messages: messages, + format: zodToJsonSchema(SqlJsonContent) + }); + + const content = response.message.content; + this.addToHistory("assistant", content); + // console.log(loc + `response: ${content}`); + + try { + const parsedData = SqlJsonContent.parse(JSON.parse(content)); + return new JsonSqlModelResponse(parsedData.sqlQuery, parsedData.isVerified, parsedData.answerSummary); + } catch (error) { + throw new Error(`Could not extract SQL from this response: ${response.message.content} because of error: ${error}`); + } + } + async processQuestion(question: string): Promise { + const loc = "processQuestion(): "; + try { let attemptCount = 0; - let lastError: string | undefined; while (attemptCount <= this.MAX_RETRIES) { const messages = [ - { role: "system", content: this.buildSystemPrompt(lastError) }, + { role: "system", content: SYSTEM_PROMPT }, + { + role: "user", + content: `${USER_PROMPT}${question}`, + }, ...this.chatHistory, - { role: "user", content: question }, ]; - if (attemptCount === 0) { - this.addToHistory("user", question); - } - console.log( attemptCount > 0 ? `\nRetry attempt ${attemptCount}...` : "" ); - // Get response from Ollama - const response = await ollama.chat({ - model: this.modelName, - messages: messages, - }); + let sqlModelResponse = null; - // Extract SQL query - const sqlMatch = response.message.content.match( - /```sql\n([\s\S]*?)\n```/ - ); - if (!sqlMatch) { - return response.message.content; + try { + sqlModelResponse = await this.queryModelJson(messages); + } catch (error) { + const errorMessage = + error instanceof Error ? error.message : String(error); + this.addToHistory("user", errorMessage); + return errorMessage; + } + + if (!sqlModelResponse) { + console.log(loc + `Skipping invalid sqlModelResponse: ${sqlModelResponse}`); + continue; } - const sql = sqlMatch[1].trim(); - console.log("Executing SQL:", sql); + if (sqlModelResponse.isVerified) { + return sqlModelResponse.getFormattedAnswer(); + } try { // Execute the query - const queryResult = await this.executeQuery(sql); - this.addToHistory("assistant", response.message.content); - - // Ask for result interpretation - const interpretationMessages = [ - ...messages, - { role: "assistant", content: response.message.content }, - { - role: "user", - content: `Here are the results of the SQL query: ${queryResult}\n\nPlease analyze these results and provide a clear summary.`, - }, - ]; - - const finalResponse = await ollama.chat({ - model: this.modelName, - messages: interpretationMessages, - }); - - this.addToHistory("assistant", finalResponse.message.content); - return finalResponse.message.content; + const queryResult = await this.executeQuery(sqlModelResponse.sqlQuery); + + console.log(loc + "Result from executing query: " + queryResult); + + this.addToHistory( + "user", + `Here are the results of the SQL query: ${queryResult}` + ); } catch (error) { - lastError = error instanceof Error ? error.message : String(error); + const errorMessage = + error instanceof Error ? error.message : String(error); + this.addToHistory("user", errorMessage); if (attemptCount === this.MAX_RETRIES) { return `I apologize, but I was unable to successfully query the database after ${ this.MAX_RETRIES + 1 - } attempts. The last error was: ${lastError}`; + } attempts. The last error was: ${errorMessage}`; } } @@ -320,7 +338,7 @@ async function main() { console.log( "\nConnected to database. You can now ask questions about your data." ); - console.log('Type "exit" to quit.\n'); + console.log('Type "/exit" to quit.\n'); const askQuestion = (prompt: string) => new Promise((resolve) => { @@ -332,7 +350,7 @@ async function main() { "\nWhat would you like to know about your data? " ); - if (userInput.toLowerCase() === "exit") { + if (userInput.toLowerCase().includes("/exit")) { console.log("\nGoodbye!\n"); readline.close(); await host.cleanup();