Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 11 additions & 16 deletions packages/server/src/bin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import { createRequire } from "node:module";
import { Command } from "commander";
import { createDocsServer } from "./create.js";
import { startStdioServer } from "./stdio.js";
import { startHttpServer } from "./http.js";
import { createDocsMcpServerFactory } from "./create.js";

const require = createRequire(import.meta.url);
const SERVER_VERSION = readPackageVersion();
Expand Down Expand Up @@ -68,7 +68,14 @@ program
}))
: [];

const app = await createDocsServer({
const serverName =
options.name === "@speakeasy-api/docs-mcp-server" && options.toolPrefix
? `${options.toolPrefix}-docs-server`
: options.name;

const mcpServerFactory = await createDocsMcpServerFactory({
serverName,
serverVersion: options.version,
indexDir: options.indexDir,
toolPrefix: options.toolPrefix,
queryEmbeddingApiKey: options.queryEmbeddingApiKey,
Expand All @@ -80,22 +87,10 @@ program
...(customTools.length > 0 ? { customTools } : {}),
});

const serverName =
options.name === "@speakeasy-api/docs-mcp-server" && options.toolPrefix
? `${options.toolPrefix}-docs-server`
: options.name;

if (options.transport === "http") {
await startHttpServer(app, {
name: serverName,
version: options.version,
port: options.port,
});
await startHttpServer(mcpServerFactory, { port: options.port });
} else {
await startStdioServer(app, {
name: serverName,
version: options.version,
});
await startStdioServer(mcpServerFactory);
}
});

Expand Down
37 changes: 26 additions & 11 deletions packages/server/src/create.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import {
type EmbeddingProvider,
type SearchEngine,
} from "@speakeasy-api/docs-mcp-core";
import { McpDocsServer } from "./server.js";
import { createMcpServer } from "./server.js";
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";

const TaxonomyFieldSchema = z
.object({
Expand Down Expand Up @@ -113,6 +114,12 @@ const CustomToolSchema = z.object({

/** Zod schema for `createDocsServer()` options. Consumers can use this to validate config. */
export const CreateDocsServerOptionsSchema = z.object({
/** Name of the MCP server. */
serverName: z.string().min(1, "serverName must be a non-empty string"),

/** Version of the MCP server. */
serverVersion: z.string().min(1, "serverVersion must be a non-empty string"),

/** Directory containing chunks.json and metadata.json produced by `docs-mcp build`. */
indexDir: z.string().min(1, "indexDir must be a non-empty string"),

Expand Down Expand Up @@ -179,9 +186,9 @@ export type CreateDocsServerOptions = z.output<typeof CreateDocsServerOptionsSch
* opens the search engine, and returns a server ready to be passed to `startStdioServer()` or
* `startHttpServer()`.
*/
export async function createDocsServer(
export async function createDocsMcpServerFactory(
input: CreateDocsServerOptionsInput,
): Promise<McpDocsServer> {
): Promise<() => McpServer> {
const options = CreateDocsServerOptionsSchema.parse(input);

const indexDir = path.resolve(options.indexDir);
Expand Down Expand Up @@ -238,14 +245,22 @@ export async function createDocsServer(

const index = await loadSearchEngine(loadInput);

return new McpDocsServer({
index,
metadata,
vectorSearchAvailable:
queryEmbeddingProvider !== undefined && queryEmbeddingProvider.name !== "hash",
...(options.toolPrefix ? { toolPrefix: options.toolPrefix } : {}),
...(options.customTools.length > 0 ? { customTools: options.customTools } : {}),
});
return () => {
return createMcpServer({
mcp: {
name: options.serverName,
version: options.serverVersion,
},
app: {
index,
metadata,
vectorSearchAvailable:
queryEmbeddingProvider !== undefined && queryEmbeddingProvider.name !== "hash",
...(options.toolPrefix ? { toolPrefix: options.toolPrefix } : {}),
...(options.customTools.length > 0 ? { customTools: options.customTools } : {}),
},
});
};
}

async function loadSearchEngine(input: {
Expand Down
156 changes: 23 additions & 133 deletions packages/server/src/http.ts
Original file line number Diff line number Diff line change
@@ -1,31 +1,11 @@
import crypto from "node:crypto";
import http from "node:http";
import { createRequire } from "node:module";
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import type { Transport } from "@modelcontextprotocol/sdk/shared/transport.js";
import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js";
import {
CallToolRequestSchema,
GetPromptRequestSchema,
ListPromptsRequestSchema,
ListResourcesRequestSchema,
ListResourceTemplatesRequestSchema,
ListToolsRequestSchema,
ReadResourceRequestSchema,
type GetPromptResult,
type ListPromptsResult,
type ListToolsResult,
type ListResourcesResult,
type ListResourceTemplatesResult,
} from "@modelcontextprotocol/sdk/types.js";
import type { AuthInfo, ToolCallContext, ToolProvider } from "./types.js";

const require = createRequire(import.meta.url);
const PKG_VERSION = readPackageVersion();
import type { AuthInfo } from "./types.js";

export interface StartHttpServerOptions {
name?: string;
version?: string;
port?: number;
/**
* Async hook called before each request is processed.
Expand Down Expand Up @@ -72,129 +52,43 @@ class SessionManager {
}
}

evict(sessionId: string): void {
async evict(sessionId: string): Promise<void> {
const entry = this.sessions.get(sessionId);
if (entry) {
this.sessions.delete(sessionId);
entry.transport.close();
entry.server.close().catch(() => {});
await entry.transport.close().catch(() => {});
await entry.server.close().catch(() => {});
}
}
}

function createMcpServer(
app: ToolProvider,
options: StartHttpServerOptions,
includeClientInfo = true,
): McpServer {
const instructions = app.getInstructions();
const server = new McpServer(
{
name: options.name ?? "@speakeasy-api/docs-mcp-server",
version: options.version ?? PKG_VERSION,
},
{
capabilities: {
tools: {},
resources: {},
prompts: {},
},
...(instructions ? { instructions } : {}),
},
);

server.server.setRequestHandler(ListToolsRequestSchema, async () => {
const tools = app.getTools().map((tool) => ({
name: tool.name,
description: tool.description,
inputSchema: tool.inputSchema,
}));
return { tools } satisfies ListToolsResult;
});

server.server.setRequestHandler(CallToolRequestSchema, async (request, extra) => {
const context: ToolCallContext = { signal: extra.signal };
if (extra.authInfo) {
context.authInfo = extra.authInfo;
}
if (extra.requestInfo?.headers) {
context.headers = extra.requestInfo.headers;
}
if (includeClientInfo) {
const clientVersion = server.server.getClientVersion();
if (clientVersion) {
context.clientInfo = { name: clientVersion.name, version: clientVersion.version };
}
}
return app.callTool(request.params.name, request.params.arguments ?? {}, context);
});

server.server.setRequestHandler(ListResourcesRequestSchema, async () => {
const resources = await app.getResources();
return {
resources: resources.map((r) => ({
uri: r.uri,
name: r.name,
title: r.title,
description: r.description,
mimeType: r.mimeType,
})),
} satisfies ListResourcesResult;
});

server.server.setRequestHandler(ListResourceTemplatesRequestSchema, async () => {
return { resourceTemplates: [] } satisfies ListResourceTemplatesResult;
});

server.server.setRequestHandler(ReadResourceRequestSchema, async (request) => {
const result = await app.readResource(request.params.uri);
return result;
});

server.server.setRequestHandler(ListPromptsRequestSchema, async () => {
return {
prompts: app.getPrompts(),
} satisfies ListPromptsResult;
});

server.server.setRequestHandler(GetPromptRequestSchema, async (request) => {
const result = await app.getPrompt(request.params.name, request.params.arguments);
return result as GetPromptResult;
});

return server;
}

function createSessionServer(
app: ToolProvider,
options: StartHttpServerOptions,
function createStatefulTransport(
server: McpServer,
sessionManager: SessionManager,
): { server: McpServer; transport: StreamableHTTPServerTransport } {
const server = createMcpServer(app, options);

): StreamableHTTPServerTransport {
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => crypto.randomUUID(),
onsessioninitialized: (sid: string) => {
sessionManager.add(sid, server, transport);
},
onsessionclosed: (sid: string) => {
sessionManager.evict(sid);
onsessionclosed: async (sid: string) => {
await sessionManager.evict(sid);
},
});

return { server, transport };
return transport;
}

export async function startHttpServer(
app: ToolProvider,
factory: () => McpServer,
options: StartHttpServerOptions = {},
): Promise<HttpServerHandle> {
const port = options.port ?? 20310;
const sessionManager = new SessionManager();

const httpServer = http.createServer(async (req, res) => {
try {
await handleRequest(req, res, app, options, sessionManager);
await handleRequest(factory, req, res, options, sessionManager);
} catch (error) {
console.error("Unhandled error in request handler:", error);
if (!res.headersSent) {
Expand Down Expand Up @@ -236,9 +130,9 @@ function setCorsHeaders(res: http.ServerResponse): void {
}

async function handleRequest(
factory: () => McpServer,
req: http.IncomingMessage,
res: http.ServerResponse,
app: ToolProvider,
options: StartHttpServerOptions,
sessionManager: SessionManager,
): Promise<void> {
Expand Down Expand Up @@ -333,11 +227,12 @@ async function handleRequest(
await entry.transport.handleRequest(req, res, parsed);
return;
}
await handleWithStatelessServer(req, res, parsed, app, options);
await handleWithStatelessServer(factory, req, res, parsed);
return;
}

const { server, transport } = createSessionServer(app, options, sessionManager);
const server = factory();
const transport = createStatefulTransport(server, sessionManager);
try {
await server.connect(transport as unknown as Transport);
await transport.handleRequest(req, res, parsed);
Expand All @@ -353,27 +248,27 @@ async function handleRequest(
}),
);
}
transport.close();
void server.close();

await transport.close().catch(() => {});
await server.close().catch(() => {});
}
}

async function handleWithStatelessServer(
factory: () => McpServer,
req: http.IncomingMessage,
res: http.ServerResponse,
parsed: unknown,
app: ToolProvider,
options: StartHttpServerOptions,
): Promise<void> {
const server = createMcpServer(app, options, false);
const transport = new StreamableHTTPServerTransport();
const server = factory();

try {
await server.connect(transport as unknown as Transport);
await transport.handleRequest(req, res, parsed);
} finally {
transport.close();
void server.close();
await transport.close().catch(() => {});
await server.close().catch(() => {});
}
}

Expand Down Expand Up @@ -439,8 +334,3 @@ function readBody(req: http.IncomingMessage): Promise<string> {
function getHeaderValue(header: string | string[] | undefined): string | undefined {
return typeof header === "string" ? header : undefined;
}

function readPackageVersion(): string {
const pkg = require("../package.json");
return typeof pkg?.version === "string" ? pkg.version : "0.0.0";
}
Loading
Loading