
Commit 37f7d83

mrubens and cte authored
Add xAI provider (RooCodeInc#2667)
* Add xAI provider
* Add model reasoning effort
* DRY this up
* Handle undefined delta
* Cleanup getModel to fix test
* Add missing translations
* Small type cleanup
* Support temperature

Co-authored-by: cte <[email protected]>
1 parent 3b19d7a commit 37f7d83

File tree

27 files changed (+621, −23 lines)


src/api/index.ts

Lines changed: 3 additions & 0 deletions

```diff
@@ -21,6 +21,7 @@ import { UnboundHandler } from "./providers/unbound"
 import { RequestyHandler } from "./providers/requesty"
 import { HumanRelayHandler } from "./providers/human-relay"
 import { FakeAIHandler } from "./providers/fake-ai"
+import { XAIHandler } from "./providers/xai"
 
 export interface SingleCompletionHandler {
 	completePrompt(prompt: string): Promise<string>
@@ -78,6 +79,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
 			return new HumanRelayHandler(options)
 		case "fake-ai":
 			return new FakeAIHandler(options)
+		case "xai":
+			return new XAIHandler(options)
 		default:
 			return new AnthropicHandler(options)
 	}
```
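For context, a minimal usage sketch of the new case. The `"xai"` provider string and handler wiring come from the diff above; the exact `ApiConfiguration` field names are assumptions, with the xai-specific options mirroring those exercised in the tests below:

```ts
import { buildApiHandler } from "../api"

// Hypothetical configuration values; `apiProvider` is assumed to be the
// discriminator that buildApiHandler switches on, and the xai-specific
// fields mirror the option names used in the tests below.
const handler = buildApiHandler({
	apiProvider: "xai",
	xaiApiKey: process.env.XAI_API_KEY ?? "",
	apiModelId: "grok-3-mini-beta",
	reasoningEffort: "high",
})

console.log(handler.getModel().id) // expected: "grok-3-mini-beta"
```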
src/api/providers/__tests__/xai.test.ts (new file)

Lines changed: 292 additions & 0 deletions

```ts
import { XAIHandler } from "../xai"
import { xaiDefaultModelId, xaiModels } from "../../../shared/api"
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

// Mock OpenAI client
jest.mock("openai", () => {
	const createMock = jest.fn()
	return jest.fn(() => ({
		chat: {
			completions: {
				create: createMock,
			},
		},
	}))
})

describe("XAIHandler", () => {
	let handler: XAIHandler
	let mockCreate: jest.Mock

	beforeEach(() => {
		// Reset all mocks
		jest.clearAllMocks()

		// Get the mock create function
		mockCreate = (OpenAI as unknown as jest.Mock)().chat.completions.create

		// Create handler with mock
		handler = new XAIHandler({})
	})

	test("should use the correct X.AI base URL", () => {
		expect(OpenAI).toHaveBeenCalledWith(
			expect.objectContaining({
				baseURL: "https://api.x.ai/v1",
			}),
		)
	})

	test("should use the provided API key", () => {
		// Clear mocks before this specific test
		jest.clearAllMocks()

		// Create a handler with our API key
		const xaiApiKey = "test-api-key"
		new XAIHandler({ xaiApiKey })

		// Verify the OpenAI constructor was called with our API key
		expect(OpenAI).toHaveBeenCalledWith(
			expect.objectContaining({
				apiKey: xaiApiKey,
			}),
		)
	})

	test("should return default model when no model is specified", () => {
		const model = handler.getModel()
		expect(model.id).toBe(xaiDefaultModelId)
		expect(model.info).toEqual(xaiModels[xaiDefaultModelId])
	})

	test("should return specified model when valid model is provided", () => {
		const testModelId = "grok-2-latest"
		const handlerWithModel = new XAIHandler({ apiModelId: testModelId })
		const model = handlerWithModel.getModel()

		expect(model.id).toBe(testModelId)
		expect(model.info).toEqual(xaiModels[testModelId])
	})

	test("should include reasoning_effort parameter for mini models", async () => {
		const miniModelHandler = new XAIHandler({
			apiModelId: "grok-3-mini-beta",
			reasoningEffort: "high",
		})

		// Setup mock for streaming response
		mockCreate.mockImplementationOnce(() => {
			return {
				[Symbol.asyncIterator]: () => ({
					async next() {
						return { done: true }
					},
				}),
			}
		})

		// Start generating a message
		const messageGenerator = miniModelHandler.createMessage("test prompt", [])
		await messageGenerator.next() // Start the generator

		// Check that reasoning_effort was included
		expect(mockCreate).toHaveBeenCalledWith(
			expect.objectContaining({
				reasoning_effort: "high",
			}),
		)
	})

	test("should not include reasoning_effort parameter for non-mini models", async () => {
		const regularModelHandler = new XAIHandler({
			apiModelId: "grok-2-latest",
			reasoningEffort: "high",
		})

		// Setup mock for streaming response
		mockCreate.mockImplementationOnce(() => {
			return {
				[Symbol.asyncIterator]: () => ({
					async next() {
						return { done: true }
					},
				}),
			}
		})

		// Start generating a message
		const messageGenerator = regularModelHandler.createMessage("test prompt", [])
		await messageGenerator.next() // Start the generator

		// Check call args for reasoning_effort
		const calls = mockCreate.mock.calls
		const lastCall = calls[calls.length - 1][0]
		expect(lastCall).not.toHaveProperty("reasoning_effort")
	})

	test("completePrompt method should return text from OpenAI API", async () => {
		const expectedResponse = "This is a test response"

		mockCreate.mockResolvedValueOnce({
			choices: [
				{
					message: {
						content: expectedResponse,
					},
				},
			],
		})

		const result = await handler.completePrompt("test prompt")
		expect(result).toBe(expectedResponse)
	})

	test("should handle errors in completePrompt", async () => {
		const errorMessage = "API error"
		mockCreate.mockRejectedValueOnce(new Error(errorMessage))

		await expect(handler.completePrompt("test prompt")).rejects.toThrow(`xAI completion error: ${errorMessage}`)
	})

	test("createMessage should yield text content from stream", async () => {
		const testContent = "This is test content"

		// Setup mock for streaming response
		mockCreate.mockImplementationOnce(() => {
			return {
				[Symbol.asyncIterator]: () => ({
					next: jest
						.fn()
						.mockResolvedValueOnce({
							done: false,
							value: {
								choices: [{ delta: { content: testContent } }],
							},
						})
						.mockResolvedValueOnce({ done: true }),
				}),
			}
		})

		// Create and consume the stream
		const stream = handler.createMessage("system prompt", [])
		const firstChunk = await stream.next()

		// Verify the content
		expect(firstChunk.done).toBe(false)
		expect(firstChunk.value).toEqual({
			type: "text",
			text: testContent,
		})
	})

	test("createMessage should yield reasoning content from stream", async () => {
		const testReasoning = "Test reasoning content"

		// Setup mock for streaming response
		mockCreate.mockImplementationOnce(() => {
			return {
				[Symbol.asyncIterator]: () => ({
					next: jest
						.fn()
						.mockResolvedValueOnce({
							done: false,
							value: {
								choices: [{ delta: { reasoning_content: testReasoning } }],
							},
						})
						.mockResolvedValueOnce({ done: true }),
				}),
			}
		})

		// Create and consume the stream
		const stream = handler.createMessage("system prompt", [])
		const firstChunk = await stream.next()

		// Verify the reasoning content
		expect(firstChunk.done).toBe(false)
		expect(firstChunk.value).toEqual({
			type: "reasoning",
			text: testReasoning,
		})
	})

	test("createMessage should yield usage data from stream", async () => {
		// Setup mock for streaming response that includes usage data
		mockCreate.mockImplementationOnce(() => {
			return {
				[Symbol.asyncIterator]: () => ({
					next: jest
						.fn()
						.mockResolvedValueOnce({
							done: false,
							value: {
								choices: [{ delta: {} }], // Needs to have choices array to avoid error
								usage: {
									prompt_tokens: 10,
									completion_tokens: 20,
									cache_read_input_tokens: 5,
									cache_creation_input_tokens: 15,
								},
							},
						})
						.mockResolvedValueOnce({ done: true }),
				}),
			}
		})

		// Create and consume the stream
		const stream = handler.createMessage("system prompt", [])
		const firstChunk = await stream.next()

		// Verify the usage data
		expect(firstChunk.done).toBe(false)
		expect(firstChunk.value).toEqual({
			type: "usage",
			inputTokens: 10,
			outputTokens: 20,
			cacheReadTokens: 5,
			cacheWriteTokens: 15,
		})
	})

	test("createMessage should pass correct parameters to OpenAI client", async () => {
		// Setup a handler with specific model
		const modelId = "grok-2-latest"
		const modelInfo = xaiModels[modelId]
		const handlerWithModel = new XAIHandler({ apiModelId: modelId })

		// Setup mock for streaming response
		mockCreate.mockImplementationOnce(() => {
			return {
				[Symbol.asyncIterator]: () => ({
					async next() {
						return { done: true }
					},
				}),
			}
		})

		// System prompt and messages
		const systemPrompt = "Test system prompt"
		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]

		// Start generating a message
		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
		await messageGenerator.next() // Start the generator

		// Check that all parameters were passed correctly
		expect(mockCreate).toHaveBeenCalledWith(
			expect.objectContaining({
				model: modelId,
				max_tokens: modelInfo.maxTokens,
				temperature: 0,
				messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
				stream: true,
				stream_options: { include_usage: true },
			}),
		)
	})
})
```
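The provider implementation itself, src/api/providers/xai.ts, is among the 27 changed files but is not shown in this excerpt. For orientation only, here is a sketch consistent with the behavior the tests pin down: an OpenAI-compatible client against https://api.x.ai/v1, reasoning_effort gated by the REASONING_MODELS set introduced in the constants change below, and the "xAI completion error:" prefix. The base-class details, option types, and stream plumbing are assumptions, not the committed code:

```ts
import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"
import { ApiHandlerOptions, ModelInfo, xaiDefaultModelId, xaiModels } from "../../shared/api"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"
import { SingleCompletionHandler } from ".."
import { DEFAULT_HEADERS, REASONING_MODELS } from "./constants"

export class XAIHandler extends BaseProvider implements SingleCompletionHandler {
	protected options: ApiHandlerOptions
	private client: OpenAI

	constructor(options: ApiHandlerOptions) {
		super()
		this.options = options
		// xAI exposes an OpenAI-compatible API, so the OpenAI SDK is reused.
		this.client = new OpenAI({
			baseURL: "https://api.x.ai/v1",
			apiKey: this.options.xaiApiKey ?? "not-provided",
			defaultHeaders: DEFAULT_HEADERS,
		})
	}

	override getModel(): { id: string; info: ModelInfo } {
		const id = this.options.apiModelId ?? xaiDefaultModelId
		const info = xaiModels[id as keyof typeof xaiModels] ?? xaiModels[xaiDefaultModelId]
		return { id, info }
	}

	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
		const { id: model, info } = this.getModel()

		const stream = await this.client.chat.completions.create({
			model,
			max_tokens: info.maxTokens,
			temperature: this.options.modelTemperature ?? 0,
			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
			stream: true,
			stream_options: { include_usage: true },
			// Only the mini Grok models accept a reasoning-effort knob (see REASONING_MODELS below).
			...(REASONING_MODELS.has(model) ? { reasoning_effort: this.options.reasoningEffort } : {}),
		})

		for await (const chunk of stream) {
			// The delta can be undefined on some chunks (e.g. the usage-only final chunk).
			const delta = chunk.choices[0]?.delta
			if (delta?.content) {
				yield { type: "text", text: delta.content }
			}
			// reasoning_content is an xAI extension not present in the SDK types.
			if (delta && "reasoning_content" in delta) {
				yield { type: "reasoning", text: (delta as any).reasoning_content }
			}
			if (chunk.usage) {
				const usage = chunk.usage as any
				yield {
					type: "usage",
					inputTokens: usage.prompt_tokens || 0,
					outputTokens: usage.completion_tokens || 0,
					cacheReadTokens: usage.cache_read_input_tokens,
					cacheWriteTokens: usage.cache_creation_input_tokens,
				}
			}
		}
	}

	async completePrompt(prompt: string): Promise<string> {
		try {
			const response = await this.client.chat.completions.create({
				model: this.getModel().id,
				messages: [{ role: "user", content: prompt }],
			})
			return response.choices[0]?.message.content || ""
		} catch (error) {
			if (error instanceof Error) {
				// Matches the error-prefix assertion in the tests above.
				throw new Error(`xAI completion error: ${error.message}`)
			}
			throw error
		}
	}
}
```

Again: a sketch inferred from the tests, not the file as committed.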

src/api/providers/constants.ts

Lines changed: 9 additions & 0 deletions

```diff
@@ -1,3 +1,12 @@
+export const DEFAULT_HEADERS = {
+	"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
+	"X-Title": "Roo Code",
+}
+
 export const ANTHROPIC_DEFAULT_MAX_TOKENS = 8192
 
 export const DEEP_SEEK_DEFAULT_TEMPERATURE = 0.6
+
+export const AZURE_AI_INFERENCE_PATH = "/models/chat/completions"
+
+export const REASONING_MODELS = new Set(["x-ai/grok-3-mini-beta", "grok-3-mini-beta", "grok-3-mini-fast-beta"])
```

src/api/providers/openai.ts

Lines changed: 4 additions & 11 deletions

```diff
@@ -15,17 +15,10 @@ import { convertToSimpleMessages } from "../transform/simple-format"
 import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
 import { BaseProvider } from "./base-provider"
 import { XmlMatcher } from "../../utils/xml-matcher"
-import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
-
-export const defaultHeaders = {
-	"HTTP-Referer": "https://github.com/RooVetGit/Roo-Cline",
-	"X-Title": "Roo Code",
-}
+import { DEEP_SEEK_DEFAULT_TEMPERATURE, DEFAULT_HEADERS, AZURE_AI_INFERENCE_PATH } from "./constants"
 
 export interface OpenAiHandlerOptions extends ApiHandlerOptions {}
 
-const AZURE_AI_INFERENCE_PATH = "/models/chat/completions"
-
 export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: OpenAiHandlerOptions
 	private client: OpenAI
@@ -45,7 +38,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler {
 			this.client = new OpenAI({
 				baseURL,
 				apiKey,
-				defaultHeaders,
+				defaultHeaders: DEFAULT_HEADERS,
 				defaultQuery: { "api-version": this.options.azureApiVersion || "2024-05-01-preview" },
 			})
 		} else if (isAzureOpenAi) {
@@ -56,7 +49,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler {
 				apiKey,
 				apiVersion: this.options.azureApiVersion || azureOpenAiDefaultApiVersion,
 				defaultHeaders: {
-					...defaultHeaders,
+					...DEFAULT_HEADERS,
 					...(this.options.openAiHostHeader ? { Host: this.options.openAiHostHeader } : {}),
 				},
 			})
@@ -65,7 +58,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler {
 				baseURL,
 				apiKey,
 				defaultHeaders: {
-					...defaultHeaders,
+					...DEFAULT_HEADERS,
 					...(this.options.openAiHostHeader ? { Host: this.options.openAiHostHeader } : {}),
 				},
 			})
```

src/api/providers/openrouter.ts

Lines changed: 2 additions & 3 deletions

```diff
@@ -9,10 +9,9 @@ import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream"
 import { convertToR1Format } from "../transform/r1-format"
 
-import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
+import { DEFAULT_HEADERS, DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
 import { getModelParams, SingleCompletionHandler } from ".."
 import { BaseProvider } from "./base-provider"
-import { defaultHeaders } from "./openai"
 
 const OPENROUTER_DEFAULT_PROVIDER_NAME = "[default]"
 
@@ -40,7 +39,7 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionHandler {
 		const baseURL = this.options.openRouterBaseUrl || "https://openrouter.ai/api/v1"
 		const apiKey = this.options.openRouterApiKey ?? "not-provided"
 
-		this.client = new OpenAI({ baseURL, apiKey, defaultHeaders })
+		this.client = new OpenAI({ baseURL, apiKey, defaultHeaders: DEFAULT_HEADERS })
 	}
 
 	override async *createMessage(
```
