@@ -83,133 +83,139 @@ export class AiQueryService {
83
83
dto : SendAiQueryMessageDto ,
84
84
res : Response ,
85
85
) {
86
- let socket : Socket ;
86
+ return this . aiQueryAuthProvider . callWithAuthRetry ( sessionMetadata , async ( ) => {
87
+ let socket : Socket ;
87
88
88
- try {
89
- const auth = await this . aiQueryAuthProvider . getAuthData ( sessionMetadata ) ;
90
- const history = await this . aiQueryMessageRepository . list ( sessionMetadata , databaseId , auth . accountId ) ;
89
+ try {
90
+ const auth = await this . aiQueryAuthProvider . getAuthData ( sessionMetadata ) ;
91
+ const history = await this . aiQueryMessageRepository . list ( sessionMetadata , databaseId , auth . accountId ) ;
91
92
92
- const client = await this . databaseClientFactory . getOrCreateClient ( {
93
- sessionMetadata,
94
- databaseId,
95
- context : ClientContext . AI ,
96
- } ) ;
97
-
98
- let context = await this . aiQueryContextRepository . getFullDbContext ( sessionMetadata , databaseId , auth . accountId ) ;
99
-
100
- if ( ! context ) {
101
- context = await this . aiQueryContextRepository . setFullDbContext (
93
+ const client = await this . databaseClientFactory . getOrCreateClient ( {
102
94
sessionMetadata,
103
95
databaseId,
104
- auth . accountId ,
105
- await getFullDbContext ( client ) ,
106
- ) ;
107
- }
96
+ context : ClientContext . AI ,
97
+ } ) ;
108
98
109
- const question = classToClass ( AiQueryMessage , {
110
- type : AiQueryMessageType . HumanMessage ,
111
- content : dto . content ,
112
- databaseId,
113
- accountId : auth . accountId ,
114
- createdAt : new Date ( ) ,
115
- } ) ;
116
-
117
- const answer = classToClass ( AiQueryMessage , {
118
- type : AiQueryMessageType . AiMessage ,
119
- content : '' ,
120
- databaseId,
121
- accountId : auth . accountId ,
122
- } ) ;
123
-
124
- socket = await this . aiQueryProvider . getSocket ( auth ) ;
125
-
126
- socket . on ( AiQueryWsEvents . REPLY_CHUNK , ( chunk ) => {
127
- answer . content += chunk ;
128
- res . write ( chunk ) ;
129
- } ) ;
130
-
131
- socket . on ( AiQueryWsEvents . GET_INDEX , async ( index , cb ) => {
132
- try {
133
- const indexContext = await this . aiQueryContextRepository . getIndexContext (
99
+ let context = await this . aiQueryContextRepository . getFullDbContext ( sessionMetadata , databaseId , auth . accountId ) ;
100
+
101
+ if ( ! context ) {
102
+ context = await this . aiQueryContextRepository . setFullDbContext (
134
103
sessionMetadata ,
135
104
databaseId ,
136
105
auth . accountId ,
137
- index ,
106
+ await getFullDbContext ( client ) ,
138
107
) ;
108
+ }
109
+
110
+ const question = classToClass ( AiQueryMessage , {
111
+ type : AiQueryMessageType . HumanMessage ,
112
+ content : dto . content ,
113
+ databaseId,
114
+ accountId : auth . accountId ,
115
+ createdAt : new Date ( ) ,
116
+ } ) ;
117
+
118
+ const answer = classToClass ( AiQueryMessage , {
119
+ type : AiQueryMessageType . AiMessage ,
120
+ content : '' ,
121
+ databaseId,
122
+ accountId : auth . accountId ,
123
+ } ) ;
124
+
125
+ socket = await this . aiQueryProvider . getSocket ( sessionMetadata , auth ) ;
126
+
127
+ socket . on ( AiQueryWsEvents . REPLY_CHUNK , ( chunk ) => {
128
+ answer . content += chunk ;
129
+ res . write ( chunk ) ;
130
+ } ) ;
139
131
140
- if ( ! context ) {
141
- return cb ( await this . aiQueryContextRepository . setIndexContext (
132
+ socket . on ( AiQueryWsEvents . GET_INDEX , async ( index , cb ) => {
133
+ try {
134
+ const indexContext = await this . aiQueryContextRepository . getIndexContext (
142
135
sessionMetadata ,
143
136
databaseId ,
144
137
auth . accountId ,
145
138
index ,
146
- await getIndexContext ( client , index ) ,
147
- ) ) ;
139
+ ) ;
140
+
141
+ if ( ! indexContext ) {
142
+ return cb ( await this . aiQueryContextRepository . setIndexContext (
143
+ sessionMetadata ,
144
+ databaseId ,
145
+ auth . accountId ,
146
+ index ,
147
+ await getIndexContext ( client , index ) ,
148
+ ) ) ;
149
+ }
150
+
151
+ return cb ( indexContext ) ;
152
+ } catch ( e ) {
153
+ this . logger . warn ( 'Unable to create index content' , e ) ;
154
+ return cb ( e . message ) ;
148
155
}
149
-
150
- return cb ( indexContext ) ;
151
- } catch ( e ) {
152
- this . logger . warn ( 'Unable to create index content' , e ) ;
153
- return cb ( e . message ) ;
154
- }
155
- } ) ;
156
-
157
- socket . on ( AiQueryWsEvents . RUN_QUERY , async ( data , cb ) => {
158
- try {
159
- if ( ! COMMANDS_WHITELIST [ ( data ?. [ 0 ] || '' ) . toLowerCase ( ) ] ) {
160
- return cb ( '-ERR: This command is not allowed' ) ;
156
+ } ) ;
157
+
158
+ socket . on ( AiQueryWsEvents . RUN_QUERY , async ( data , cb ) => {
159
+ try {
160
+ if ( ! COMMANDS_WHITELIST [ ( data ?. [ 0 ] || '' ) . toLowerCase ( ) ] ) {
161
+ return cb ( '-ERR: This command is not allowed' ) ;
162
+ }
163
+
164
+ return cb ( await client . sendCommand ( data , { replyEncoding : 'utf8' } ) ) ;
165
+ } catch ( e ) {
166
+ this . logger . warn ( 'Query execution error' , e ) ;
167
+ return cb ( e . message ) ;
161
168
}
162
-
163
- return cb ( await client . sendCommand ( data , { replyEncoding : 'utf8' } ) ) ;
164
- } catch ( e ) {
165
- this . logger . warn ( 'Query execution error' , e ) ;
166
- return cb ( e . message ) ;
167
- }
168
- } ) ;
169
-
170
- socket . on ( AiQueryWsEvents . TOOL_CALL , async ( data ) => {
171
- answer . steps . push ( plainToClass ( AiQueryIntermediateStep , {
172
- type : AiQueryIntermediateStepType . TOOL_CALL ,
173
- data,
174
- } ) ) ;
175
- } ) ;
176
-
177
- socket . on ( AiQueryWsEvents . TOOL_REPLY , async ( data ) => {
178
- answer . steps . push ( plainToClass ( AiQueryIntermediateStep , {
179
- type : AiQueryIntermediateStepType . TOOL ,
180
- data,
181
- } ) ) ;
182
- } ) ;
183
-
184
- await socket . emitWithAck ( 'stream' , dto . content , context , AiQueryService . prepareHistory ( history ) ) ;
185
- socket . close ( ) ;
186
- await this . aiQueryMessageRepository . createMany ( sessionMetadata , [ question , answer ] ) ;
187
-
188
- return res . end ( ) ;
189
- } catch ( e ) {
190
- socket ?. close ?.( ) ;
191
- throw wrapAiQueryError ( e , 'Unable to send the question' ) ;
192
- }
169
+ } ) ;
170
+
171
+ socket . on ( AiQueryWsEvents . TOOL_CALL , async ( data ) => {
172
+ answer . steps . push ( plainToClass ( AiQueryIntermediateStep , {
173
+ type : AiQueryIntermediateStepType . TOOL_CALL ,
174
+ data,
175
+ } ) ) ;
176
+ } ) ;
177
+
178
+ socket . on ( AiQueryWsEvents . TOOL_REPLY , async ( data ) => {
179
+ answer . steps . push ( plainToClass ( AiQueryIntermediateStep , {
180
+ type : AiQueryIntermediateStepType . TOOL ,
181
+ data,
182
+ } ) ) ;
183
+ } ) ;
184
+
185
+ await socket . emitWithAck ( 'stream' , dto . content , context , AiQueryService . prepareHistory ( history ) ) ;
186
+ socket . close ( ) ;
187
+ await this . aiQueryMessageRepository . createMany ( sessionMetadata , [ question , answer ] ) ;
188
+
189
+ return res . end ( ) ;
190
+ } catch ( e ) {
191
+ socket ?. close ?.( ) ;
192
+ throw wrapAiQueryError ( e , 'Unable to send the question' ) ;
193
+ }
194
+ } ) ;
193
195
}
194
196
195
197
async getHistory ( sessionMetadata : SessionMetadata , databaseId : string ) : Promise < AiQueryMessage [ ] > {
196
- try {
197
- const auth = await this . aiQueryAuthProvider . getAuthData ( sessionMetadata ) ;
198
- return await this . aiQueryMessageRepository . list ( sessionMetadata , databaseId , auth . accountId ) ;
199
- } catch ( e ) {
200
- throw wrapAiQueryError ( e , 'Unable to get history' ) ;
201
- }
198
+ return this . aiQueryAuthProvider . callWithAuthRetry ( sessionMetadata , async ( ) => {
199
+ try {
200
+ const auth = await this . aiQueryAuthProvider . getAuthData ( sessionMetadata ) ;
201
+ return await this . aiQueryMessageRepository . list ( sessionMetadata , databaseId , auth . accountId ) ;
202
+ } catch ( e ) {
203
+ throw wrapAiQueryError ( e , 'Unable to get history' ) ;
204
+ }
205
+ } ) ;
202
206
}
203
207
204
208
async clearHistory ( sessionMetadata : SessionMetadata , databaseId : string ) : Promise < void > {
205
- try {
206
- const auth = await this . aiQueryAuthProvider . getAuthData ( sessionMetadata ) ;
209
+ return this . aiQueryAuthProvider . callWithAuthRetry ( sessionMetadata , async ( ) => {
210
+ try {
211
+ const auth = await this . aiQueryAuthProvider . getAuthData ( sessionMetadata ) ;
207
212
208
- await this . aiQueryContextRepository . reset ( sessionMetadata , databaseId , auth . accountId ) ;
213
+ await this . aiQueryContextRepository . reset ( sessionMetadata , databaseId , auth . accountId ) ;
209
214
210
- return this . aiQueryMessageRepository . clearHistory ( sessionMetadata , databaseId , auth . accountId ) ;
211
- } catch ( e ) {
212
- throw wrapAiQueryError ( e , 'Unable to clear history' ) ;
213
- }
215
+ return this . aiQueryMessageRepository . clearHistory ( sessionMetadata , databaseId , auth . accountId ) ;
216
+ } catch ( e ) {
217
+ throw wrapAiQueryError ( e , 'Unable to clear history' ) ;
218
+ }
219
+ } ) ;
214
220
}
215
221
}
0 commit comments