Skip to content

Commit 3f04de1

Browse files
committed
feat: use object generation for object mode in ai plugin
1 parent 586035d commit 3f04de1

File tree

3 files changed

+62
-80
lines changed

3 files changed

+62
-80
lines changed

apps/test-bot/src/ai.ts

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@ configureAI({
1111
selectAiModel: async () => {
1212
return {
1313
model,
14-
objectMode: true,
1514
};
1615
},
1716
messageFilter: async (message) => {

packages/ai/src/plugin.ts

Lines changed: 37 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,10 @@ import { AiPluginOptions, MessageFilter, SelectAiModel } from './types';
33
import { LoadedCommand, Logger } from 'commandkit';
44
import { AiContext } from './context';
55
import { Collection, Events, Message, TextChannel } from 'discord.js';
6-
import { tool, Tool, generateText, Output } from 'ai';
6+
import { tool, Tool, generateText, generateObject } from 'ai';
77
import { z } from 'zod';
88
import { getAiWorkerContext, runInAiWorkerContext } from './ai-context-worker';
9-
import { AiResponseSchema } from './schema';
9+
import { AiResponseSchema, pollSchema } from './schema';
1010

1111
type WithAI<T extends LoadedCommand> = T & {
1212
data: {
@@ -232,9 +232,18 @@ export class AiPlugin extends RuntimePlugin<AiPluginOptions> {
232232
Tools are basically like commands that you can execute to perform specific actions based on user input.
233233
Keep the response short and concise, and only use tools when necessary. Keep the response length under 2000 characters.
234234
Do not include your own text in the response unless necessary. For text formatting, you can use discord's markdown syntax.
235+
The current channel is ${
236+
'name' in message.channel
237+
? message.channel.name
238+
: message.channel.recipient?.displayName || 'DM'
239+
} whose id is ${message.channelId}. ${
240+
message.channel.isSendable()
241+
? 'You can send messages in this channel.'
242+
: 'You cannot send messages in this channel.'
243+
}
235244
${message.inGuild() ? `\nYou are currently in a guild named ${message.guild.name} whose id is ${message.guildId}. While in guild, you can fetch member information if needed.` : '\nYou are currently in a direct message with the user.'}
236245
If the user asks you to create a poll or embeds, create a text containing the poll or embed information as a markdown instead of json. If structured response is possible, use the structured response format instead.
237-
If the user asks you to perform a task that requires a tool, use the tool to perform the task and return the result.
246+
If the user asks you to perform a task that requires a tool, use the tool to perform the task and return the result. Reject any requests that are not related to the tools you have access to.
238247
`;
239248

240249
const userInfo = `<user>
@@ -257,65 +266,38 @@ export class AiPlugin extends RuntimePlugin<AiPluginOptions> {
257266

258267
const originalPrompt = `${userInfo}\nUser: ${message.content}\nAI:`;
259268

260-
const call = ({
261-
prompt = originalPrompt,
262-
includeTools = true,
263-
disableObjectMode = false,
264-
}) =>
265-
generateText({
266-
abortSignal: AbortSignal.timeout(60_000),
267-
model,
268-
...(includeTools && {
269-
tools: { ...this.toolsRecord, ...this.defaultTools },
270-
}),
271-
prompt,
272-
system: systemPrompt,
273-
maxSteps: 5,
274-
providerOptions: options,
275-
...(objectMode && !disableObjectMode
276-
? {
277-
experimental_output: Output.object({
278-
schema: AiResponseSchema,
279-
}),
280-
}
281-
: {}),
282-
});
283-
284-
let result: any;
285-
286-
try {
287-
result = await call({});
288-
} catch {
289-
if (objectMode) {
290-
const r1 = await call({
291-
includeTools: true,
292-
disableObjectMode: true,
293-
});
294-
295-
if (!r1.text) throw new Error('No text response from AI');
269+
let result: Awaited<
270+
ReturnType<typeof generateText | typeof generateObject>
271+
>;
296272

297-
const r2 = await call({
298-
includeTools: false,
299-
disableObjectMode: false,
300-
prompt: `Original context: ${originalPrompt} ${r1.text}\n\nGenerate a structured response based on the previous response`,
301-
});
273+
const config = {
274+
model,
275+
abortSignal: AbortSignal.timeout(60_000),
276+
prompt: originalPrompt,
277+
system: systemPrompt,
278+
providerOptions: options,
279+
};
302280

303-
result = r2;
304-
}
281+
if (objectMode) {
282+
result = await generateObject({
283+
...config,
284+
schema: AiResponseSchema,
285+
});
286+
} else {
287+
result = await generateText({
288+
...config,
289+
tools: { ...this.toolsRecord, ...this.defaultTools },
290+
maxSteps: 5,
291+
});
305292
}
306293

307294
stopTyping();
308295

309296
let structuredResult: z.infer<typeof AiResponseSchema> | null = null;
310297

311-
try {
312-
const val =
313-
'experimental_output' in result && result.experimental_output;
314-
315-
if (val) {
316-
structuredResult = val;
317-
}
318-
} catch {}
298+
structuredResult = !('text' in result)
299+
? (result.object as z.infer<typeof AiResponseSchema>)
300+
: null;
319301

320302
if (structuredResult) {
321303
const { poll, content, embed } = structuredResult;
@@ -360,7 +342,7 @@ export class AiPlugin extends RuntimePlugin<AiPluginOptions> {
360342
}
361343
: undefined,
362344
});
363-
} else if (!!result.text) {
345+
} else if ('text' in result && !!result.text) {
364346
await message.reply({
365347
content: result.text.substring(0, 2000),
366348
allowedMentions: { parse: [] },

packages/ai/src/schema.ts

Lines changed: 25 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,30 @@ const pollMediaObject = z
1313
'An object representing the media for a poll question, containing the text of the question. Emoji cannot be used in question text.',
1414
);
1515

16+
export const pollSchema = z
17+
.object({
18+
question: pollMediaObject,
19+
answers: z
20+
.array(pollMediaObject)
21+
.min(1)
22+
.max(10)
23+
.describe('An array of answers for the poll'),
24+
allow_multiselect: z
25+
.boolean()
26+
.optional()
27+
.default(false)
28+
.describe('Whether the poll allows multiple selections'),
29+
duration: z
30+
.number()
31+
.int()
32+
.min(1)
33+
.max(32)
34+
.optional()
35+
.default(24)
36+
.describe('The duration of the poll in hours'),
37+
})
38+
.describe('An object representing a poll to include in the message');
39+
1640
export const AiResponseSchema = z
1741
.object({
1842
content: z
@@ -95,30 +119,7 @@ export const AiResponseSchema = z
95119
.describe(
96120
'An object representing embeds to include in the discord message. This is an optional field.',
97121
),
98-
poll: z
99-
.object({
100-
question: pollMediaObject,
101-
answers: z
102-
.array(pollMediaObject)
103-
.min(1)
104-
.max(10)
105-
.describe('An array of answers for the poll'),
106-
allow_multiselect: z
107-
.boolean()
108-
.optional()
109-
.default(false)
110-
.describe('Whether the poll allows multiple selections'),
111-
duration: z
112-
.number()
113-
.int()
114-
.min(1)
115-
.max(32)
116-
.optional()
117-
.default(24)
118-
.describe('The duration of the poll in hours'),
119-
})
120-
.optional()
121-
.describe('An object representing a poll to include in the message'),
122+
poll: pollSchema.optional(),
122123
})
123124
.describe(
124125
'The schema for an AI response message to be sent to discord, including content and embeds. At least one of content, embeds, or poll must be present.',

0 commit comments

Comments (0)