Skip to content

Commit df49a47

Browse files
committed
refactor: ai wip
1 parent 3f04de1 commit df49a47

File tree

13 files changed

+336
-373
lines changed

13 files changed

+336
-373
lines changed

packages/ai/src/augmentation.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import { AiContext } from './context';
2+
import { Awaitable } from 'discord.js';
3+
4+
declare module 'commandkit' {
5+
interface CustomAppCommandProps {
6+
ai?: (ctx: AiContext) => Awaitable<unknown>;
7+
}
8+
}

packages/ai/src/configure.ts

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import { Message } from 'discord.js';
2+
import { MessageFilter, SelectAiModel } from './types';
3+
import { createSystemPrompt } from './system-prompt';
4+
5+
/**
6+
* Represents the configuration options for the AI model.
7+
*/
8+
export interface ConfigureAI {
9+
/**
10+
* A filter function that determines whether a message should be processed by the AI.
11+
* CommandKit invokes this function before processing the message.
12+
*/
13+
messageFilter?: MessageFilter;
14+
/**
15+
* A function that selects the AI model to use based on the message.
16+
* This function should return a promise that resolves to an object containing the model and options.
17+
*/
18+
selectAiModel?: SelectAiModel;
19+
/**
20+
* A function that generates a system prompt based on the message.
21+
* This function should return a promise that resolves to a string containing the system prompt.
22+
* If not provided, a default system prompt will be used.
23+
*/
24+
systemPrompt?: (message: Message) => Promise<string>;
25+
/**
26+
* A function that prepares the prompt for the AI model.
27+
*/
28+
preparePrompt?: (message: Message) => Promise<string>;
29+
}
30+
31+
const AIConfig: Required<ConfigureAI> = {
32+
messageFilter: async (message) =>
33+
message.mentions.users.has(message.client.user.id),
34+
systemPrompt: async (message) => createSystemPrompt(message),
35+
async preparePrompt(message) {
36+
const userInfo = `<user>
37+
<id>${message.author.id}</id>
38+
<name>${message.author.username}</name>
39+
<displayName>${message.author.displayName}</displayName>
40+
<avatar>${message.author.avatarURL()}</avatar>
41+
</user>`;
42+
43+
return `${userInfo}\nUser: ${message.content}\nAI:`;
44+
},
45+
selectAiModel: async () => {
46+
throw new Error(
47+
'No AI model selected. Please configure the AI plugin using configureAI() function, making sure to include a selectAiModel function.',
48+
);
49+
},
50+
};
51+
52+
/**
53+
* Configures the AI plugin with the provided options.
54+
* This function allows you to set a message filter, select an AI model, and generate a system prompt.
55+
* @param config The configuration options for the AI plugin.
56+
*/
57+
export function configureAI(config: ConfigureAI): void {
58+
if (config.messageFilter) {
59+
AIConfig.messageFilter = config.messageFilter;
60+
}
61+
62+
if (config.selectAiModel) {
63+
AIConfig.selectAiModel = config.selectAiModel;
64+
}
65+
66+
if (config.systemPrompt) {
67+
AIConfig.systemPrompt = config.systemPrompt;
68+
}
69+
}

packages/ai/src/plugin.ts

Lines changed: 22 additions & 240 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@ import { Collection, Events, Message, TextChannel } from 'discord.js';
66
import { tool, Tool, generateText, generateObject } from 'ai';
77
import { z } from 'zod';
88
import { getAiWorkerContext, runInAiWorkerContext } from './ai-context-worker';
9-
import { AiResponseSchema, pollSchema } from './schema';
9+
import { getAvailableCommands } from './tools/get-available-commands';
10+
import { getChannelById } from './tools/get-channel-by-id';
11+
import { getCurrentClientInfo } from './tools/get-current-client-info';
12+
import { getGuildById } from './tools/get-guild-by-id';
13+
import { getUserById } from './tools/get-user-by-id';
14+
import { createSystemPrompt } from './system-prompt';
1015

1116
type WithAI<T extends LoadedCommand> = T & {
1217
data: {
@@ -30,55 +35,18 @@ export interface AiConfig {
3035
parameters: any;
3136
}
3237

33-
let messageFilter: MessageFilter | null = null;
34-
let selectAiModel: SelectAiModel | null = null;
35-
let generateSystemPrompt: ((message: Message) => Promise<string>) | undefined;
36-
37-
/**
38-
* Represents the configuration options for the AI model.
39-
*/
40-
export interface ConfigureAI {
41-
/**
42-
* A filter function that determines whether a message should be processed by the AI.
43-
* CommandKit invokes this function before processing the message.
44-
*/
45-
messageFilter?: MessageFilter;
46-
/**
47-
* A function that selects the AI model to use based on the message.
48-
* This function should return a promise that resolves to an object containing the model and options.
49-
*/
50-
selectAiModel?: SelectAiModel;
51-
/**
52-
* A function that generates a system prompt based on the message.
53-
* This function should return a promise that resolves to a string containing the system prompt.
54-
* If not provided, a default system prompt will be used.
55-
*/
56-
systemPrompt?: (message: Message) => Promise<string>;
57-
}
58-
59-
/**
60-
* Configures the AI plugin with the provided options.
61-
* This function allows you to set a message filter, select an AI model, and generate a system prompt.
62-
* @param config The configuration options for the AI plugin.
63-
*/
64-
export function configureAI(config: ConfigureAI): void {
65-
if (config.messageFilter) {
66-
messageFilter = config.messageFilter;
67-
}
68-
69-
if (config.selectAiModel) {
70-
selectAiModel = config.selectAiModel;
71-
}
72-
73-
if (config.systemPrompt) {
74-
generateSystemPrompt = config.systemPrompt;
75-
}
76-
}
38+
const defaultTools: Record<string, Tool> = {
39+
getAvailableCommands,
40+
getChannelById,
41+
getCurrentClientInfo,
42+
getGuildById,
43+
getUserById,
44+
};
7745

7846
export class AiPlugin extends RuntimePlugin<AiPluginOptions> {
7947
public readonly name = 'AiPlugin';
8048
private toolsRecord: Record<string, Tool> = {};
81-
private defaultTools: Record<string, Tool> = {};
49+
private defaultTools = defaultTools;
8250
private onMessageFunc: ((message: Message) => Promise<void>) | null = null;
8351

8452
public constructor(options: AiPluginOptions) {
@@ -89,113 +57,9 @@ export class AiPlugin extends RuntimePlugin<AiPluginOptions> {
8957
this.onMessageFunc = (message) => this.handleMessage(ctx, message);
9058
ctx.commandkit.client.on(Events.MessageCreate, this.onMessageFunc);
9159

92-
this.createDefaultTools(ctx);
93-
9460
Logger.info(`Plugin ${this.name} activated`);
9561
}
9662

97-
private createDefaultTools(ctx: CommandKitPluginRuntime): void {
98-
const { commandkit } = ctx;
99-
const client = commandkit.client;
100-
101-
this.defaultTools.getUserById = tool({
102-
description: 'Get user information by ID',
103-
parameters: z.object({
104-
userId: z
105-
.string()
106-
.describe(
107-
'The ID of the user to retrieve. This is a Discord snowflake string.',
108-
),
109-
}),
110-
execute: async (params) => {
111-
const user = await client.users.fetch(params.userId, {
112-
force: false,
113-
cache: true,
114-
});
115-
116-
return user.toJSON();
117-
},
118-
});
119-
120-
this.defaultTools.getChannelById = tool({
121-
description: 'Get channel information by ID',
122-
parameters: z.object({
123-
channelId: z
124-
.string()
125-
.describe(
126-
'The ID of the channel to retrieve. This is a Discord snowflake string.',
127-
),
128-
}),
129-
execute: async (params) => {
130-
const channel = await client.channels.fetch(params.channelId, {
131-
force: false,
132-
cache: true,
133-
});
134-
135-
if (!channel) {
136-
throw new Error(`Channel with ID ${params.channelId} not found.`);
137-
}
138-
139-
return channel.toJSON();
140-
},
141-
});
142-
143-
this.defaultTools.getGuildById = tool({
144-
description: 'Get guild information by ID',
145-
parameters: z.object({
146-
guildId: z
147-
.string()
148-
.describe(
149-
'The ID of the guild to retrieve. This is a Discord snowflake string.',
150-
),
151-
}),
152-
execute: async (params) => {
153-
const guild = await client.guilds.fetch({
154-
guild: params.guildId,
155-
force: false,
156-
cache: true,
157-
});
158-
159-
if (!guild) {
160-
throw new Error(`Guild with ID ${params.guildId} not found.`);
161-
}
162-
163-
return {
164-
id: guild.id,
165-
name: guild.name,
166-
icon: guild.iconURL(),
167-
memberCount: guild.memberCount,
168-
};
169-
},
170-
});
171-
172-
this.defaultTools.getCurrentUser = tool({
173-
description: 'Get information about the current discord bot user',
174-
parameters: z.object({}),
175-
execute: async () => {
176-
const user = client.user;
177-
178-
if (!user) {
179-
throw new Error('Bot user is not available.');
180-
}
181-
182-
return user.toJSON();
183-
},
184-
});
185-
186-
this.defaultTools.getAvailableCommands = tool({
187-
description: 'Get all available commands',
188-
parameters: z.object({}),
189-
execute: async () => {
190-
return ctx.commandkit.commandHandler.getCommandsArray().map((cmd) => ({
191-
name: cmd.data.command.name,
192-
description: cmd.data.command.description,
193-
category: cmd.command.category,
194-
}));
195-
},
196-
});
197-
}
198-
19963
public async deactivate(ctx: CommandKitPluginRuntime): Promise<void> {
20064
this.toolsRecord = {};
20165
if (this.onMessageFunc) {
@@ -226,25 +90,7 @@ export class AiPlugin extends RuntimePlugin<AiPluginOptions> {
22690
});
22791

22892
const systemPrompt =
229-
(await generateSystemPrompt?.(message)) ||
230-
`You are a helpful AI discord bot. Your name is ${message.client.user.username} and your id is ${message.client.user.id}.
231-
You are designed to assist users with their questions and tasks. You also have access to various tools that can help you perform tasks.
232-
Tools are basically like commands that you can execute to perform specific actions based on user input.
233-
Keep the response short and concise, and only use tools when necessary. Keep the response length under 2000 characters.
234-
Do not include your own text in the response unless necessary. For text formatting, you can use discord's markdown syntax.
235-
The current channel is ${
236-
'name' in message.channel
237-
? message.channel.name
238-
: message.channel.recipient?.displayName || 'DM'
239-
} whose id is ${message.channelId}. ${
240-
message.channel.isSendable()
241-
? 'You can send messages in this channel.'
242-
: 'You cannot send messages in this channel.'
243-
}
244-
${message.inGuild() ? `\nYou are currently in a guild named ${message.guild.name} whose id is ${message.guildId}. While in guild, you can fetch member information if needed.` : '\nYou are currently in a direct message with the user.'}
245-
If the user asks you to create a poll or embeds, create a text containing the poll or embed information as a markdown instead of json. If structured response is possible, use the structured response format instead.
246-
If the user asks you to perform a task that requires a tool, use the tool to perform the task and return the result. Reject any requests that are not related to the tools you have access to.
247-
`;
93+
(await generateSystemPrompt?.(message)) || createSystemPrompt(message);
24894

24995
const userInfo = `<user>
25096
<id>${message.author.id}</id>
@@ -258,18 +104,10 @@ export class AiPlugin extends RuntimePlugin<AiPluginOptions> {
258104
const stopTyping = await this.startTyping(channel);
259105

260106
try {
261-
const {
262-
model,
263-
options,
264-
objectMode = false,
265-
} = await aiModelSelector(message);
107+
const { model, options } = await aiModelSelector(message);
266108

267109
const originalPrompt = `${userInfo}\nUser: ${message.content}\nAI:`;
268110

269-
let result: Awaited<
270-
ReturnType<typeof generateText | typeof generateObject>
271-
>;
272-
273111
const config = {
274112
model,
275113
abortSignal: AbortSignal.timeout(60_000),
@@ -278,71 +116,15 @@ export class AiPlugin extends RuntimePlugin<AiPluginOptions> {
278116
providerOptions: options,
279117
};
280118

281-
if (objectMode) {
282-
result = await generateObject({
283-
...config,
284-
schema: AiResponseSchema,
285-
});
286-
} else {
287-
result = await generateText({
288-
...config,
289-
tools: { ...this.toolsRecord, ...this.defaultTools },
290-
maxSteps: 5,
291-
});
292-
}
119+
const result = await generateText({
120+
...config,
121+
tools: { ...this.toolsRecord, ...this.defaultTools },
122+
maxSteps: 5,
123+
});
293124

294125
stopTyping();
295126

296-
let structuredResult: z.infer<typeof AiResponseSchema> | null = null;
297-
298-
structuredResult = !('text' in result)
299-
? (result.object as z.infer<typeof AiResponseSchema>)
300-
: null;
301-
302-
if (structuredResult) {
303-
const { poll, content, embed } = structuredResult;
304-
305-
if (!poll && !content && !embed) {
306-
Logger.warn(
307-
'AI response did not include any content, embed, or poll.',
308-
);
309-
return;
310-
}
311-
312-
await message.reply({
313-
content: content?.substring(0, 2000),
314-
embeds: embed
315-
? [
316-
{
317-
title: embed.title,
318-
description: embed.description,
319-
url: embed.url,
320-
color: embed.color,
321-
image: embed.image ? { url: embed.image } : undefined,
322-
thumbnail: embed.thumbnailImage
323-
? { url: embed.thumbnailImage }
324-
: undefined,
325-
fields: embed.fields?.map((field) => ({
326-
name: field.name,
327-
value: field.value,
328-
inline: field.inline,
329-
})),
330-
},
331-
]
332-
: [],
333-
poll: poll
334-
? {
335-
allowMultiselect: poll.allow_multiselect,
336-
answers: poll.answers.map((answer) => ({
337-
text: answer.text,
338-
emoji: answer.emoji,
339-
})),
340-
duration: poll.duration,
341-
question: { text: poll.question.text },
342-
}
343-
: undefined,
344-
});
345-
} else if ('text' in result && !!result.text) {
127+
if (!!result.text) {
346128
await message.reply({
347129
content: result.text.substring(0, 2000),
348130
allowedMentions: { parse: [] },

0 commit comments

Comments
 (0)