Skip to content

Commit a989d62

Browse files
Merge pull request #3350 from RedisInsight/be/feature/RI-5711-context-cache
add cache for context
2 parents 4d49e37 + eb89795 commit a989d62

File tree

4 files changed

+200
-14
lines changed

4 files changed

+200
-14
lines changed

redisinsight/api/src/modules/ai/query/ai-query.module.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,17 @@ import { AiQueryAuthProvider } from 'src/modules/ai/query/providers/auth/ai-quer
66
import { LocalAiQueryAuthProvider } from 'src/modules/ai/query/providers/auth/local.ai-query-auth.provider';
77
import { AiQueryMessageRepository } from 'src/modules/ai/query/repositories/ai-query.message.repository';
88
import { LocalAiQueryMessageRepository } from 'src/modules/ai/query/repositories/local.ai-query.message.repository';
9+
import { AiQueryContextRepository } from 'src/modules/ai/query/repositories/ai-query.context.repository';
10+
import {
11+
InMemoryAiQueryContextRepository,
12+
} from 'src/modules/ai/query/repositories/in-memory.ai-query.context.repository';
913

1014
@Module({})
1115
export class AiQueryModule {
1216
static register(
1317
aiQueryAuthProvider: Type<AiQueryAuthProvider> = LocalAiQueryAuthProvider,
1418
aiQueryMessageRepository: Type<AiQueryMessageRepository> = LocalAiQueryMessageRepository,
19+
aiQueryContextRepository: Type<AiQueryContextRepository> = InMemoryAiQueryContextRepository,
1520
) {
1621
return {
1722
module: AiQueryModule,
@@ -27,6 +32,10 @@ export class AiQueryModule {
2732
provide: AiQueryMessageRepository,
2833
useClass: aiQueryMessageRepository,
2934
},
35+
{
36+
provide: AiQueryContextRepository,
37+
useClass: aiQueryContextRepository,
38+
},
3039
],
3140
};
3241
}

redisinsight/api/src/modules/ai/query/ai-query.service.ts

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import { AiQueryMessageRepository } from 'src/modules/ai/query/repositories/ai-q
1919
import { AiQueryAuthProvider } from 'src/modules/ai/query/providers/auth/ai-query-auth.provider';
2020
import { classToClass } from 'src/utils';
2121
import { plainToClass } from 'class-transformer';
22+
import { AiQueryContextRepository } from 'src/modules/ai/query/repositories/ai-query.context.repository';
2223

2324
const COMMANDS_WHITELIST = {
2425
'ft.search': true,
@@ -34,6 +35,7 @@ export class AiQueryService {
3435
private readonly databaseClientFactory: DatabaseClientFactory,
3536
private readonly aiQueryMessageRepository: AiQueryMessageRepository,
3637
private readonly aiQueryAuthProvider: AiQueryAuthProvider,
38+
private readonly aiQueryContextRepository: AiQueryContextRepository,
3739
) {}
3840

3941
static prepareHistoryIntermediateSteps(message: AiQueryMessage): [AiQueryMessageRole, string][] {
@@ -93,7 +95,16 @@ export class AiQueryService {
9395
context: ClientContext.AI,
9496
});
9597

96-
const context = await getFullDbContext(client);
98+
let context = await this.aiQueryContextRepository.getFullDbContext(sessionMetadata, databaseId, auth.accountId);
99+
100+
if (!context) {
101+
context = await this.aiQueryContextRepository.setFullDbContext(
102+
sessionMetadata,
103+
databaseId,
104+
auth.accountId,
105+
await getFullDbContext(client),
106+
);
107+
}
97108

98109
const question = classToClass(AiQueryMessage, {
99110
type: AiQueryMessageType.HumanMessage,
@@ -119,11 +130,27 @@ export class AiQueryService {
119130

120131
socket.on(AiQueryWsEvents.GET_INDEX, async (index, cb) => {
121132
try {
122-
const indexContext = await getIndexContext(client, index);
123-
cb(indexContext);
133+
const indexContext = await this.aiQueryContextRepository.getIndexContext(
134+
sessionMetadata,
135+
databaseId,
136+
auth.accountId,
137+
index,
138+
);
139+
140+
if (!context) {
141+
return cb(await this.aiQueryContextRepository.setIndexContext(
142+
sessionMetadata,
143+
databaseId,
144+
auth.accountId,
145+
index,
146+
await getIndexContext(client, index),
147+
));
148+
}
149+
150+
return cb(indexContext);
124151
} catch (e) {
125152
this.logger.warn('Unable to create index content', e);
126-
cb(e.message);
153+
return cb(e.message);
127154
}
128155
});
129156

@@ -140,16 +167,6 @@ export class AiQueryService {
140167
}
141168
});
142169

143-
socket.on(AiQueryWsEvents.GET_INDEX, async (index, cb) => {
144-
try {
145-
const indexContext = await getIndexContext(client, index);
146-
return cb(indexContext);
147-
} catch (e) {
148-
this.logger.warn('Unable to create index content', e);
149-
return cb(e.message);
150-
}
151-
});
152-
153170
socket.on(AiQueryWsEvents.TOOL_CALL, async (data) => {
154171
answer.steps.push(plainToClass(AiQueryIntermediateStep, {
155172
type: AiQueryIntermediateStepType.TOOL_CALL,
@@ -187,6 +204,9 @@ export class AiQueryService {
187204
async clearHistory(sessionMetadata: SessionMetadata, databaseId: string): Promise<void> {
188205
try {
189206
const auth = await this.aiQueryAuthProvider.getAuthData(sessionMetadata);
207+
208+
await this.aiQueryContextRepository.reset(sessionMetadata, databaseId, auth.accountId);
209+
190210
return this.aiQueryMessageRepository.clearHistory(sessionMetadata, databaseId, auth.accountId);
191211
} catch (e) {
192212
throw wrapAiQueryError(e, 'Unable to clear history');
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import { SessionMetadata } from 'src/common/models';
2+
3+
export abstract class AiQueryContextRepository {
4+
/**
5+
* Should return saved db context if exists in particular chat
6+
* @param sessionMetadata
7+
* @param databaseId
8+
* @param accountId
9+
*/
10+
abstract getFullDbContext(
11+
sessionMetadata: SessionMetadata,
12+
databaseId: string,
13+
accountId: string,
14+
): Promise<object>;
15+
16+
/**
17+
* Should save db context for particular chat
18+
* @param sessionMetadata
19+
* @param databaseId
20+
* @param accountId
21+
* @param context
22+
*/
23+
abstract setFullDbContext(
24+
sessionMetadata: SessionMetadata,
25+
databaseId: string,
26+
accountId: string,
27+
context: object,
28+
): Promise<object>;
29+
30+
/**
31+
* Should return saved index context if exists in particular chat
32+
* @param sessionMetadata
33+
* @param databaseId
34+
* @param accountId
35+
* @param index
36+
*/
37+
abstract getIndexContext(
38+
sessionMetadata: SessionMetadata,
39+
databaseId: string,
40+
accountId: string,
41+
index: string,
42+
): Promise<object>;
43+
44+
/**
45+
* Should save index context for particular chat
46+
* @param sessionMetadata
47+
* @param databaseId
48+
* @param accountId
49+
* @param index
50+
* @param context
51+
*/
52+
abstract setIndexContext(
53+
sessionMetadata: SessionMetadata,
54+
databaseId: string,
55+
accountId: string,
56+
index: string,
57+
context: object,
58+
): Promise<object>;
59+
60+
/**
61+
* Reset all index and db contexts for particular chat
62+
* @param sessionMetadata
63+
* @param databaseId
64+
* @param accountId
65+
*/
66+
abstract reset(
67+
sessionMetadata: SessionMetadata,
68+
databaseId: string,
69+
accountId: string,
70+
): Promise<void>;
71+
}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import { Injectable } from '@nestjs/common';
2+
import { get, set, unset } from 'lodash';
3+
import { SessionMetadata } from 'src/common/models';
4+
import { AiQueryContextRepository } from 'src/modules/ai/query/repositories/ai-query.context.repository';
5+
6+
@Injectable()
7+
export class InMemoryAiQueryContextRepository extends AiQueryContextRepository {
8+
private chats: Record<string, { index: Record<string, object>, db: object }> = {};
9+
10+
static getChatId(databaseId: string, accountId: string) {
11+
return `${databaseId}_${accountId}`;
12+
}
13+
14+
/**
15+
* @inheritdoc
16+
*/
17+
async getFullDbContext(
18+
_sessionMetadata: SessionMetadata,
19+
databaseId: string,
20+
accountId: string,
21+
): Promise<object> {
22+
const chatId = InMemoryAiQueryContextRepository.getChatId(databaseId, accountId);
23+
24+
return get(this.chats, [chatId, 'db'], null);
25+
}
26+
27+
/**
28+
* @inheritdoc
29+
*/
30+
async setFullDbContext(
31+
_sessionMetadata: SessionMetadata,
32+
databaseId: string,
33+
accountId: string,
34+
context: object,
35+
): Promise<object> {
36+
const chatId = InMemoryAiQueryContextRepository.getChatId(databaseId, accountId);
37+
38+
set(this.chats, [chatId, 'db'], context);
39+
40+
return context;
41+
}
42+
43+
/**
44+
* @inheritdoc
45+
*/
46+
async getIndexContext(
47+
_sessionMetadata: SessionMetadata,
48+
databaseId: string,
49+
accountId: string,
50+
index: string,
51+
): Promise<object> {
52+
const chatId = InMemoryAiQueryContextRepository.getChatId(databaseId, accountId);
53+
54+
return get(this.chats, [chatId, 'index', index], null);
55+
}
56+
57+
/**
58+
* @inheritdoc
59+
*/
60+
async setIndexContext(
61+
_sessionMetadata: SessionMetadata,
62+
databaseId: string,
63+
accountId: string,
64+
index: string,
65+
context: object,
66+
): Promise<object> {
67+
const chatId = InMemoryAiQueryContextRepository.getChatId(databaseId, accountId);
68+
69+
set(this.chats, [chatId, 'index', index], context);
70+
71+
return context;
72+
}
73+
74+
/**
75+
* @inheritdoc
76+
*/
77+
async reset(
78+
_sessionMetadata:SessionMetadata,
79+
databaseId:string,
80+
accountId:string,
81+
): Promise<void> {
82+
const chatId = InMemoryAiQueryContextRepository.getChatId(databaseId, accountId);
83+
84+
unset(this.chats, [chatId]);
85+
}
86+
}

0 commit comments

Comments
 (0)