Skip to content

Commit 64508ef

Browse files
committed
fix:整理代码
1 parent 52d1f87 commit 64508ef

File tree

2 files changed

+160
-157
lines changed

2 files changed

+160
-157
lines changed

src/apiTools.ts

Lines changed: 71 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import {
77
DuckDuckGoSearchOptions,
88
DuckDuckGoSearchResponse,
99
} from '@agent-infra/duckduckgo-search';
10+
import { processDeepSeekResponse} from './deepseekApi';
1011

1112
/** 定义工具的接口 */
1213
export interface Tool {
@@ -57,32 +58,27 @@ export async function handleNativeFunctionCalling(
5758
temperature: temperature,
5859
});
5960

60-
if (streamMode) {
61-
let fullResponse = '';
62-
let toolCalls: Array<Partial<OpenAI.Chat.Completions.ChatCompletionMessageToolCall>> = [];
61+
// 工具调用收集器
62+
let toolCalls: Array<Partial<OpenAI.Chat.Completions.ChatCompletionMessageToolCall>> = [];
6363

64-
// 遍历流式响应的每个 chunk
65-
for await (const chunk of response as AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>) {
66-
if (abortSignal?.aborted) {
67-
return null;
68-
}
64+
// 统一处理响应
65+
const { chunkResponse, nativeToolCalls, completion } = await processDeepSeekResponse({
66+
streamMode,
67+
response: response as AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk> | OpenAI.Chat.Completions.ChatCompletion,
68+
abortSignal,
69+
userStopException: 'operation stopped by user',
70+
infoMessage: 'Processing native tool calls...',
71+
outputChannel,
72+
processingMode: 'native-tools',
73+
onChunk: streamMode ? (chunk) => {
74+
// 流式工具调用处理
6975
const delta = chunk.choices[0]?.delta;
70-
// 累积文本内容
71-
if (delta?.content) {
72-
fullResponse += delta.content;
73-
if (outputChannel) {
74-
outputChannel.append(delta.content);
75-
}
76-
}
77-
// 累积工具调用信息
7876
if (delta?.tool_calls) {
7977
for (const toolCallDelta of delta.tool_calls) {
8078
const index = toolCallDelta.index;
8179
if (!toolCalls[index]) {
8280
toolCalls[index] = { type: "function", function: { name: "", arguments: "" } };
8381
}
84-
85-
// 类型安全地处理 id、name 和 arguments
8682
if (toolCallDelta.id && !toolCalls[index].id) {
8783
toolCalls[index].id = toolCallDelta.id;
8884
}
@@ -97,73 +93,53 @@ export async function handleNativeFunctionCalling(
9793
}
9894
}
9995
}
100-
}
96+
} : undefined
97+
});
10198

102-
// 过滤掉不完整的工具调用
103-
const completeToolCalls = toolCalls.filter(tc => tc.id && tc.function?.name);
99+
// 统一工具调用处理
100+
const resolvedToolCalls = streamMode
101+
? toolCalls.filter(tc => tc.id && tc.function?.name)
102+
: (completion?.choices[0].message.tool_calls || []);
104103

105-
if (completeToolCalls.length > 0) {
106-
// 构造助手消息,包括内容和工具调用
107-
const assistantMessage: OpenAI.ChatCompletionMessageParam = {
108-
role: "assistant",
109-
content: fullResponse,
110-
tool_calls: completeToolCalls.map(tc => ({
104+
if (resolvedToolCalls.length > 0) {
105+
// 构造助手消息
106+
const assistantMessage: OpenAI.ChatCompletionMessageParam = {
107+
role: "assistant",
108+
content: chunkResponse,
109+
...(streamMode ? {
110+
tool_calls: resolvedToolCalls.map(tc => ({
111111
id: tc.id!,
112112
type: tc.type!,
113113
function: {
114-
name: tc.function?.name!,
115-
arguments: tc.function?.arguments!,
116-
},
117-
})),
118-
};
119-
messages.push(assistantMessage);
114+
name: tc.function!.name,
115+
arguments: tc.function!.arguments!
116+
}
117+
}))
118+
} : {})
119+
};
120+
messages.push(assistantMessage);
120121

121-
// 处理每个工具调用
122-
for (const toolCall of completeToolCalls) {
123-
const tool = toolRegistry.get(toolCall.function?.name!);
124-
if (tool) {
125-
const args = JSON.parse(toolCall.function?.arguments!);
126-
const result = await tool.function(args);
127-
messages.push({
128-
role: "tool",
129-
tool_call_id: toolCall.id!,
130-
content: result,
131-
});
132-
}
122+
// 执行工具调用
123+
for (const toolCall of resolvedToolCalls) {
124+
const tool = toolRegistry.get(toolCall.function!.name);
125+
if (tool) {
126+
const args = JSON.parse(toolCall.function!.arguments!);
127+
const result = await tool.function(args);
128+
messages.push({
129+
role: "tool",
130+
tool_call_id: toolCall.id!,
131+
content: result,
132+
});
133133
}
134-
// 递归调用以继续对话
135-
return await handleNativeFunctionCalling(
136-
openai, modelName, messages, tools, streamMode, maxToken, temperature, outputChannel, abortSignal
137-
);
138-
} else {
139-
return fullResponse;
140134
}
141-
} else {
142-
// 非流式模式的处理保持不变
143-
const completion = response as OpenAI.Chat.Completions.ChatCompletion;
144-
const message = completion.choices[0].message;
145-
const fullResponse = message.content || '';
146135

147-
if (message.tool_calls) {
148-
messages.push(message);
149-
for (const toolCall of message.tool_calls) {
150-
const tool = toolRegistry.get(toolCall.function.name);
151-
if (tool) {
152-
const args = JSON.parse(toolCall.function.arguments);
153-
const result = await tool.function(args);
154-
messages.push({
155-
role: "tool",
156-
tool_call_id: toolCall.id,
157-
content: result,
158-
});
159-
}
160-
}
161-
return await handleNativeFunctionCalling(
162-
openai, modelName, messages, tools, streamMode, maxToken, temperature, outputChannel, abortSignal
163-
);
164-
}
165-
return fullResponse;
136+
// 递归继续对话
137+
return await handleNativeFunctionCalling(
138+
openai, modelName, messages, tools, streamMode, maxToken, temperature, outputChannel, abortSignal
139+
);
166140
}
141+
142+
return chunkResponse;
167143
}
168144

169145
export async function handleSimulatedFunctionCalling(
@@ -177,8 +153,8 @@ export async function handleSimulatedFunctionCalling(
177153
outputChannel?: vscode.OutputChannel,
178154
abortSignal?: AbortSignal
179155
): Promise<string | null> {
156+
// 添加工具说明到系统提示
180157
if (!messages[0].content?.toString().includes("To call a tool")) {
181-
// 添加工具说明到系统提示
182158
const toolDescriptions = tools.map(tool =>
183159
`- ${tool.name}: ${tool.description}. Parameters: ${JSON.stringify(tool.parameters.properties)}`
184160
).join('\n');
@@ -194,28 +170,23 @@ export async function handleSimulatedFunctionCalling(
194170
temperature: temperature,
195171
});
196172

197-
let fullResponse = '';
198-
if (streamMode) {
199-
// 累积流式响应的内容
200-
for await (const chunk of response as AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>) {
201-
if (abortSignal?.aborted) {
202-
return null;
203-
}
204-
const content = chunk.choices[0]?.delta?.content || '';
205-
fullResponse += content;
206-
if (outputChannel) {
207-
outputChannel.append(content);
208-
}
209-
}
210-
} else {
211-
fullResponse = (response as OpenAI.Chat.Completions.ChatCompletion).choices[0].message.content || '';
212-
}
173+
// 统一处理响应
174+
const { chunkResponse } = await processDeepSeekResponse({
175+
streamMode,
176+
response: response as AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk> | OpenAI.Chat.Completions.ChatCompletion,
177+
abortSignal,
178+
userStopException: 'operation stopped by user',
179+
infoMessage: 'Processing simulated tools...',
180+
outputChannel,
181+
processingMode: 'simulated-tools'
182+
});
213183

214-
// 解析多个工具调用
215-
const toolCallMatches = fullResponse.matchAll(/<tool_call>(.*?)<\/tool_call>/gs);
184+
// 解析工具调用
185+
const toolCallMatches = chunkResponse.matchAll(/<tool_call>(.*?)<\/tool_call>/gs);
216186
const toolCalls = Array.from(toolCallMatches, match => JSON.parse(match[1]));
217187

218188
if (toolCalls.length > 0) {
189+
// 记录工具调用结果
219190
for (const toolCall of toolCalls) {
220191
const tool = toolRegistry.get(toolCall.name);
221192
if (tool) {
@@ -226,18 +197,20 @@ export async function handleSimulatedFunctionCalling(
226197
});
227198
}
228199
}
229-
// 递归调用以继续对话
200+
201+
// 递归继续对话
230202
return await handleSimulatedFunctionCalling(
231203
openai, modelName, messages, tools, streamMode, maxToken, temperature, outputChannel, abortSignal
232204
);
233205
}
234206

235-
return fullResponse;
207+
return chunkResponse;
236208
}
237209

238-
export function isToolsSupported(apiBaseURL: string): boolean {
210+
211+
export function isToolsSupported(apiBaseURL: string, modelName: string): boolean {
239212
// 示例:假设 DeepSeek 官方 URL 支持 tools
240-
return apiBaseURL === "https://api.deepseek.com";
213+
return apiBaseURL === "https://api.deepseek.com" && !modelName.includes("r1") && !modelName.includes("reasoner");
241214
}
242215

243216
// Define a minimal interface for search results (adjust based on actual response structure)

0 commit comments

Comments
 (0)