@@ -6,7 +6,12 @@ import { Collection, Events, Message, TextChannel } from 'discord.js';
6
6
import { tool , Tool , generateText , generateObject } from 'ai' ;
7
7
import { z } from 'zod' ;
8
8
import { getAiWorkerContext , runInAiWorkerContext } from './ai-context-worker' ;
9
- import { AiResponseSchema , pollSchema } from './schema' ;
9
+ import { getAvailableCommands } from './tools/get-available-commands' ;
10
+ import { getChannelById } from './tools/get-channel-by-id' ;
11
+ import { getCurrentClientInfo } from './tools/get-current-client-info' ;
12
+ import { getGuildById } from './tools/get-guild-by-id' ;
13
+ import { getUserById } from './tools/get-user-by-id' ;
14
+ import { createSystemPrompt } from './system-prompt' ;
10
15
11
16
type WithAI < T extends LoadedCommand > = T & {
12
17
data : {
@@ -30,55 +35,18 @@ export interface AiConfig {
30
35
parameters : any ;
31
36
}
32
37
33
- let messageFilter : MessageFilter | null = null ;
34
- let selectAiModel : SelectAiModel | null = null ;
35
- let generateSystemPrompt : ( ( message : Message ) => Promise < string > ) | undefined ;
36
-
37
- /**
38
- * Represents the configuration options for the AI model.
39
- */
40
- export interface ConfigureAI {
41
- /**
42
- * A filter function that determines whether a message should be processed by the AI.
43
- * CommandKit invokes this function before processing the message.
44
- */
45
- messageFilter ?: MessageFilter ;
46
- /**
47
- * A function that selects the AI model to use based on the message.
48
- * This function should return a promise that resolves to an object containing the model and options.
49
- */
50
- selectAiModel ?: SelectAiModel ;
51
- /**
52
- * A function that generates a system prompt based on the message.
53
- * This function should return a promise that resolves to a string containing the system prompt.
54
- * If not provided, a default system prompt will be used.
55
- */
56
- systemPrompt ?: ( message : Message ) => Promise < string > ;
57
- }
58
-
59
- /**
60
- * Configures the AI plugin with the provided options.
61
- * This function allows you to set a message filter, select an AI model, and generate a system prompt.
62
- * @param config The configuration options for the AI plugin.
63
- */
64
- export function configureAI ( config : ConfigureAI ) : void {
65
- if ( config . messageFilter ) {
66
- messageFilter = config . messageFilter ;
67
- }
68
-
69
- if ( config . selectAiModel ) {
70
- selectAiModel = config . selectAiModel ;
71
- }
72
-
73
- if ( config . systemPrompt ) {
74
- generateSystemPrompt = config . systemPrompt ;
75
- }
76
- }
38
+ const defaultTools : Record < string , Tool > = {
39
+ getAvailableCommands,
40
+ getChannelById,
41
+ getCurrentClientInfo,
42
+ getGuildById,
43
+ getUserById,
44
+ } ;
77
45
78
46
export class AiPlugin extends RuntimePlugin < AiPluginOptions > {
79
47
public readonly name = 'AiPlugin' ;
80
48
private toolsRecord : Record < string , Tool > = { } ;
81
- private defaultTools : Record < string , Tool > = { } ;
49
+ private defaultTools = defaultTools ;
82
50
private onMessageFunc : ( ( message : Message ) => Promise < void > ) | null = null ;
83
51
84
52
public constructor ( options : AiPluginOptions ) {
@@ -89,113 +57,9 @@ export class AiPlugin extends RuntimePlugin<AiPluginOptions> {
89
57
this . onMessageFunc = ( message ) => this . handleMessage ( ctx , message ) ;
90
58
ctx . commandkit . client . on ( Events . MessageCreate , this . onMessageFunc ) ;
91
59
92
- this . createDefaultTools ( ctx ) ;
93
-
94
60
Logger . info ( `Plugin ${ this . name } activated` ) ;
95
61
}
96
62
97
- private createDefaultTools ( ctx : CommandKitPluginRuntime ) : void {
98
- const { commandkit } = ctx ;
99
- const client = commandkit . client ;
100
-
101
- this . defaultTools . getUserById = tool ( {
102
- description : 'Get user information by ID' ,
103
- parameters : z . object ( {
104
- userId : z
105
- . string ( )
106
- . describe (
107
- 'The ID of the user to retrieve. This is a Discord snowflake string.' ,
108
- ) ,
109
- } ) ,
110
- execute : async ( params ) => {
111
- const user = await client . users . fetch ( params . userId , {
112
- force : false ,
113
- cache : true ,
114
- } ) ;
115
-
116
- return user . toJSON ( ) ;
117
- } ,
118
- } ) ;
119
-
120
- this . defaultTools . getChannelById = tool ( {
121
- description : 'Get channel information by ID' ,
122
- parameters : z . object ( {
123
- channelId : z
124
- . string ( )
125
- . describe (
126
- 'The ID of the channel to retrieve. This is a Discord snowflake string.' ,
127
- ) ,
128
- } ) ,
129
- execute : async ( params ) => {
130
- const channel = await client . channels . fetch ( params . channelId , {
131
- force : false ,
132
- cache : true ,
133
- } ) ;
134
-
135
- if ( ! channel ) {
136
- throw new Error ( `Channel with ID ${ params . channelId } not found.` ) ;
137
- }
138
-
139
- return channel . toJSON ( ) ;
140
- } ,
141
- } ) ;
142
-
143
- this . defaultTools . getGuildById = tool ( {
144
- description : 'Get guild information by ID' ,
145
- parameters : z . object ( {
146
- guildId : z
147
- . string ( )
148
- . describe (
149
- 'The ID of the guild to retrieve. This is a Discord snowflake string.' ,
150
- ) ,
151
- } ) ,
152
- execute : async ( params ) => {
153
- const guild = await client . guilds . fetch ( {
154
- guild : params . guildId ,
155
- force : false ,
156
- cache : true ,
157
- } ) ;
158
-
159
- if ( ! guild ) {
160
- throw new Error ( `Guild with ID ${ params . guildId } not found.` ) ;
161
- }
162
-
163
- return {
164
- id : guild . id ,
165
- name : guild . name ,
166
- icon : guild . iconURL ( ) ,
167
- memberCount : guild . memberCount ,
168
- } ;
169
- } ,
170
- } ) ;
171
-
172
- this . defaultTools . getCurrentUser = tool ( {
173
- description : 'Get information about the current discord bot user' ,
174
- parameters : z . object ( { } ) ,
175
- execute : async ( ) => {
176
- const user = client . user ;
177
-
178
- if ( ! user ) {
179
- throw new Error ( 'Bot user is not available.' ) ;
180
- }
181
-
182
- return user . toJSON ( ) ;
183
- } ,
184
- } ) ;
185
-
186
- this . defaultTools . getAvailableCommands = tool ( {
187
- description : 'Get all available commands' ,
188
- parameters : z . object ( { } ) ,
189
- execute : async ( ) => {
190
- return ctx . commandkit . commandHandler . getCommandsArray ( ) . map ( ( cmd ) => ( {
191
- name : cmd . data . command . name ,
192
- description : cmd . data . command . description ,
193
- category : cmd . command . category ,
194
- } ) ) ;
195
- } ,
196
- } ) ;
197
- }
198
-
199
63
public async deactivate ( ctx : CommandKitPluginRuntime ) : Promise < void > {
200
64
this . toolsRecord = { } ;
201
65
if ( this . onMessageFunc ) {
@@ -226,25 +90,7 @@ export class AiPlugin extends RuntimePlugin<AiPluginOptions> {
226
90
} ) ;
227
91
228
92
const systemPrompt =
229
- ( await generateSystemPrompt ?.( message ) ) ||
230
- `You are a helpful AI discord bot. Your name is ${ message . client . user . username } and your id is ${ message . client . user . id } .
231
- You are designed to assist users with their questions and tasks. You also have access to various tools that can help you perform tasks.
232
- Tools are basically like commands that you can execute to perform specific actions based on user input.
233
- Keep the response short and concise, and only use tools when necessary. Keep the response length under 2000 characters.
234
- Do not include your own text in the response unless necessary. For text formatting, you can use discord's markdown syntax.
235
- The current channel is ${
236
- 'name' in message . channel
237
- ? message . channel . name
238
- : message . channel . recipient ?. displayName || 'DM'
239
- } whose id is ${ message . channelId } . ${
240
- message . channel . isSendable ( )
241
- ? 'You can send messages in this channel.'
242
- : 'You cannot send messages in this channel.'
243
- }
244
- ${ message . inGuild ( ) ? `\nYou are currently in a guild named ${ message . guild . name } whose id is ${ message . guildId } . While in guild, you can fetch member information if needed.` : '\nYou are currently in a direct message with the user.' }
245
- If the user asks you to create a poll or embeds, create a text containing the poll or embed information as a markdown instead of json. If structured response is possible, use the structured response format instead.
246
- If the user asks you to perform a task that requires a tool, use the tool to perform the task and return the result. Reject any requests that are not related to the tools you have access to.
247
- ` ;
93
+ ( await generateSystemPrompt ?.( message ) ) || createSystemPrompt ( message ) ;
248
94
249
95
const userInfo = `<user>
250
96
<id>${ message . author . id } </id>
@@ -258,18 +104,10 @@ export class AiPlugin extends RuntimePlugin<AiPluginOptions> {
258
104
const stopTyping = await this . startTyping ( channel ) ;
259
105
260
106
try {
261
- const {
262
- model,
263
- options,
264
- objectMode = false ,
265
- } = await aiModelSelector ( message ) ;
107
+ const { model, options } = await aiModelSelector ( message ) ;
266
108
267
109
const originalPrompt = `${ userInfo } \nUser: ${ message . content } \nAI:` ;
268
110
269
- let result : Awaited <
270
- ReturnType < typeof generateText | typeof generateObject >
271
- > ;
272
-
273
111
const config = {
274
112
model,
275
113
abortSignal : AbortSignal . timeout ( 60_000 ) ,
@@ -278,71 +116,15 @@ export class AiPlugin extends RuntimePlugin<AiPluginOptions> {
278
116
providerOptions : options ,
279
117
} ;
280
118
281
- if ( objectMode ) {
282
- result = await generateObject ( {
283
- ...config ,
284
- schema : AiResponseSchema ,
285
- } ) ;
286
- } else {
287
- result = await generateText ( {
288
- ...config ,
289
- tools : { ...this . toolsRecord , ...this . defaultTools } ,
290
- maxSteps : 5 ,
291
- } ) ;
292
- }
119
+ const result = await generateText ( {
120
+ ...config ,
121
+ tools : { ...this . toolsRecord , ...this . defaultTools } ,
122
+ maxSteps : 5 ,
123
+ } ) ;
293
124
294
125
stopTyping ( ) ;
295
126
296
- let structuredResult : z . infer < typeof AiResponseSchema > | null = null ;
297
-
298
- structuredResult = ! ( 'text' in result )
299
- ? ( result . object as z . infer < typeof AiResponseSchema > )
300
- : null ;
301
-
302
- if ( structuredResult ) {
303
- const { poll, content, embed } = structuredResult ;
304
-
305
- if ( ! poll && ! content && ! embed ) {
306
- Logger . warn (
307
- 'AI response did not include any content, embed, or poll.' ,
308
- ) ;
309
- return ;
310
- }
311
-
312
- await message . reply ( {
313
- content : content ?. substring ( 0 , 2000 ) ,
314
- embeds : embed
315
- ? [
316
- {
317
- title : embed . title ,
318
- description : embed . description ,
319
- url : embed . url ,
320
- color : embed . color ,
321
- image : embed . image ? { url : embed . image } : undefined ,
322
- thumbnail : embed . thumbnailImage
323
- ? { url : embed . thumbnailImage }
324
- : undefined ,
325
- fields : embed . fields ?. map ( ( field ) => ( {
326
- name : field . name ,
327
- value : field . value ,
328
- inline : field . inline ,
329
- } ) ) ,
330
- } ,
331
- ]
332
- : [ ] ,
333
- poll : poll
334
- ? {
335
- allowMultiselect : poll . allow_multiselect ,
336
- answers : poll . answers . map ( ( answer ) => ( {
337
- text : answer . text ,
338
- emoji : answer . emoji ,
339
- } ) ) ,
340
- duration : poll . duration ,
341
- question : { text : poll . question . text } ,
342
- }
343
- : undefined ,
344
- } ) ;
345
- } else if ( 'text' in result && ! ! result . text ) {
127
+ if ( ! ! result . text ) {
346
128
await message . reply ( {
347
129
content : result . text . substring ( 0 , 2000 ) ,
348
130
allowedMentions : { parse : [ ] } ,
0 commit comments