Skip to content

Commit 1a974c7

Browse files
fix chat retries
1 parent 2b4524e commit 1a974c7

File tree

4 files changed

+119
-115
lines changed

4 files changed

+119
-115
lines changed

redisinsight/api/src/modules/ai/query/ai-query.service.ts

Lines changed: 93 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -83,113 +83,115 @@ export class AiQueryService {
8383
dto: SendAiQueryMessageDto,
8484
res: Response,
8585
) {
86-
let socket: Socket;
87-
88-
try {
89-
const auth = await this.aiQueryAuthProvider.getAuthData(sessionMetadata);
90-
const history = await this.aiQueryMessageRepository.list(sessionMetadata, databaseId, auth.accountId);
91-
92-
const client = await this.databaseClientFactory.getOrCreateClient({
93-
sessionMetadata,
94-
databaseId,
95-
context: ClientContext.AI,
96-
});
86+
return this.aiQueryAuthProvider.callWithAuthRetry(sessionMetadata, async () => {
87+
let socket: Socket;
9788

98-
let context = await this.aiQueryContextRepository.getFullDbContext(sessionMetadata, databaseId, auth.accountId);
89+
try {
90+
const auth = await this.aiQueryAuthProvider.getAuthData(sessionMetadata);
91+
const history = await this.aiQueryMessageRepository.list(sessionMetadata, databaseId, auth.accountId);
9992

100-
if (!context) {
101-
context = await this.aiQueryContextRepository.setFullDbContext(
93+
const client = await this.databaseClientFactory.getOrCreateClient({
10294
sessionMetadata,
10395
databaseId,
104-
auth.accountId,
105-
await getFullDbContext(client),
106-
);
107-
}
96+
context: ClientContext.AI,
97+
});
10898

109-
const question = classToClass(AiQueryMessage, {
110-
type: AiQueryMessageType.HumanMessage,
111-
content: dto.content,
112-
databaseId,
113-
accountId: auth.accountId,
114-
createdAt: new Date(),
115-
});
116-
117-
const answer = classToClass(AiQueryMessage, {
118-
type: AiQueryMessageType.AiMessage,
119-
content: '',
120-
databaseId,
121-
accountId: auth.accountId,
122-
});
123-
124-
socket = await this.aiQueryProvider.getSocket(sessionMetadata, auth);
125-
126-
socket.on(AiQueryWsEvents.REPLY_CHUNK, (chunk) => {
127-
answer.content += chunk;
128-
res.write(chunk);
129-
});
130-
131-
socket.on(AiQueryWsEvents.GET_INDEX, async (index, cb) => {
132-
try {
133-
const indexContext = await this.aiQueryContextRepository.getIndexContext(
99+
let context = await this.aiQueryContextRepository.getFullDbContext(sessionMetadata, databaseId, auth.accountId);
100+
101+
if (!context) {
102+
context = await this.aiQueryContextRepository.setFullDbContext(
134103
sessionMetadata,
135104
databaseId,
136105
auth.accountId,
137-
index,
106+
await getFullDbContext(client),
138107
);
108+
}
109+
110+
const question = classToClass(AiQueryMessage, {
111+
type: AiQueryMessageType.HumanMessage,
112+
content: dto.content,
113+
databaseId,
114+
accountId: auth.accountId,
115+
createdAt: new Date(),
116+
});
117+
118+
const answer = classToClass(AiQueryMessage, {
119+
type: AiQueryMessageType.AiMessage,
120+
content: '',
121+
databaseId,
122+
accountId: auth.accountId,
123+
});
124+
125+
socket = await this.aiQueryProvider.getSocket(sessionMetadata, auth);
126+
127+
socket.on(AiQueryWsEvents.REPLY_CHUNK, (chunk) => {
128+
answer.content += chunk;
129+
res.write(chunk);
130+
});
139131

140-
if (!indexContext) {
141-
return cb(await this.aiQueryContextRepository.setIndexContext(
132+
socket.on(AiQueryWsEvents.GET_INDEX, async (index, cb) => {
133+
try {
134+
const indexContext = await this.aiQueryContextRepository.getIndexContext(
142135
sessionMetadata,
143136
databaseId,
144137
auth.accountId,
145138
index,
146-
await getIndexContext(client, index),
147-
));
139+
);
140+
141+
if (!indexContext) {
142+
return cb(await this.aiQueryContextRepository.setIndexContext(
143+
sessionMetadata,
144+
databaseId,
145+
auth.accountId,
146+
index,
147+
await getIndexContext(client, index),
148+
));
149+
}
150+
151+
return cb(indexContext);
152+
} catch (e) {
153+
this.logger.warn('Unable to create index content', e);
154+
return cb(e.message);
148155
}
149-
150-
return cb(indexContext);
151-
} catch (e) {
152-
this.logger.warn('Unable to create index content', e);
153-
return cb(e.message);
154-
}
155-
});
156-
157-
socket.on(AiQueryWsEvents.RUN_QUERY, async (data, cb) => {
158-
try {
159-
if (!COMMANDS_WHITELIST[(data?.[0] || '').toLowerCase()]) {
160-
return cb('-ERR: This command is not allowed');
156+
});
157+
158+
socket.on(AiQueryWsEvents.RUN_QUERY, async (data, cb) => {
159+
try {
160+
if (!COMMANDS_WHITELIST[(data?.[0] || '').toLowerCase()]) {
161+
return cb('-ERR: This command is not allowed');
162+
}
163+
164+
return cb(await client.sendCommand(data, { replyEncoding: 'utf8' }));
165+
} catch (e) {
166+
this.logger.warn('Query execution error', e);
167+
return cb(e.message);
161168
}
162-
163-
return cb(await client.sendCommand(data, { replyEncoding: 'utf8' }));
164-
} catch (e) {
165-
this.logger.warn('Query execution error', e);
166-
return cb(e.message);
167-
}
168-
});
169-
170-
socket.on(AiQueryWsEvents.TOOL_CALL, async (data) => {
171-
answer.steps.push(plainToClass(AiQueryIntermediateStep, {
172-
type: AiQueryIntermediateStepType.TOOL_CALL,
173-
data,
174-
}));
175-
});
176-
177-
socket.on(AiQueryWsEvents.TOOL_REPLY, async (data) => {
178-
answer.steps.push(plainToClass(AiQueryIntermediateStep, {
179-
type: AiQueryIntermediateStepType.TOOL,
180-
data,
181-
}));
182-
});
183-
184-
await socket.emitWithAck('stream', dto.content, context, AiQueryService.prepareHistory(history));
185-
socket.close();
186-
await this.aiQueryMessageRepository.createMany(sessionMetadata, [question, answer]);
187-
188-
return res.end();
189-
} catch (e) {
190-
socket?.close?.();
191-
throw wrapAiQueryError(e, 'Unable to send the question');
192-
}
169+
});
170+
171+
socket.on(AiQueryWsEvents.TOOL_CALL, async (data) => {
172+
answer.steps.push(plainToClass(AiQueryIntermediateStep, {
173+
type: AiQueryIntermediateStepType.TOOL_CALL,
174+
data,
175+
}));
176+
});
177+
178+
socket.on(AiQueryWsEvents.TOOL_REPLY, async (data) => {
179+
answer.steps.push(plainToClass(AiQueryIntermediateStep, {
180+
type: AiQueryIntermediateStepType.TOOL,
181+
data,
182+
}));
183+
});
184+
185+
await socket.emitWithAck('stream', dto.content, context, AiQueryService.prepareHistory(history));
186+
socket.close();
187+
await this.aiQueryMessageRepository.createMany(sessionMetadata, [question, answer]);
188+
189+
return res.end();
190+
} catch (e) {
191+
socket?.close?.();
192+
throw wrapAiQueryError(e, 'Unable to send the question');
193+
}
194+
});
193195
}
194196

195197
async getHistory(sessionMetadata: SessionMetadata, databaseId: string): Promise<AiQueryMessage[]> {

redisinsight/api/src/modules/ai/query/providers/ai-query.provider.ts

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,32 +18,30 @@ export class AiQueryProvider {
1818
) {}
1919

2020
async getSocket(sessionMetadata: SessionMetadata, auth: AiQueryAuthData): Promise<Socket> {
21-
return this.aiQueryAuthProvider.callWithAuthRetry(sessionMetadata, async () => {
22-
try {
23-
return await new Promise((resolve, reject) => {
24-
const socket = io(aiConfig.querySocketUrl, {
25-
path: aiConfig.querySocketPath,
26-
reconnection: false,
27-
transports: ['websocket'],
28-
extraHeaders: {
29-
'X-Csrf-Token': auth.csrf,
30-
Cookie: `JSESSIONID=${auth.sessionId}`,
31-
},
32-
});
21+
try {
22+
return await new Promise((resolve, reject) => {
23+
const socket = io(aiConfig.querySocketUrl, {
24+
path: aiConfig.querySocketPath,
25+
reconnection: false,
26+
transports: ['websocket'],
27+
extraHeaders: {
28+
'X-Csrf-Token': auth.csrf,
29+
Cookie: `JSESSIONID=${auth.sessionId}`,
30+
},
31+
});
3332

34-
socket.on(AiQueryWsEvents.CONNECT_ERROR, (e) => {
35-
this.logger.error('Unable to establish AI socket connection', e);
36-
reject(e);
37-
});
33+
socket.on(AiQueryWsEvents.CONNECT_ERROR, (e) => {
34+
this.logger.error('Unable to establish AI socket connection', e);
35+
reject(e);
36+
});
3837

39-
socket.on(AiQueryWsEvents.CONNECT, async () => {
40-
this.logger.debug('AI socket connection established');
41-
resolve(socket);
42-
});
38+
socket.on(AiQueryWsEvents.CONNECT, async () => {
39+
this.logger.debug('AI socket connection established');
40+
resolve(socket);
4341
});
44-
} catch (e) {
45-
throw wrapAiQueryError(e, 'Unable to establish connection');
46-
}
47-
});
42+
});
43+
} catch (e) {
44+
throw wrapAiQueryError(e, 'Unable to establish connection');
45+
}
4846
}
4947
}

redisinsight/api/src/modules/cloud/auth/cloud-auth.service.spec.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,8 @@ describe('CloudAuthService', () => {
314314
{
315315
accessToken: mockCloudAccessTokenNew,
316316
refreshToken: mockCloudRefreshTokenNew,
317+
csrf: null,
318+
apiSessionId: null,
317319
},
318320
);
319321
});

redisinsight/api/src/modules/cloud/auth/cloud-auth.service.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,8 @@ export class CloudAuthService {
239239
await this.sessionService.updateSessionData(sessionMetadata.sessionId, {
240240
accessToken: data.access_token,
241241
refreshToken: data.refresh_token,
242+
csrf: null,
243+
apiSessionId: null,
242244
});
243245
} catch (e) {
244246
throw new CloudApiUnauthorizedException();

0 commit comments

Comments
 (0)