Skip to content

Commit d56da5c

Browse files
committed
feat(component): option to use dumb llm in table selection
1 parent b02b43f commit d56da5c

File tree

4 files changed

+125
-18
lines changed

4 files changed

+125
-18
lines changed

src/__tests__/db-query/unit/nodes/get-tables.node.unit.ts

Lines changed: 108 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,16 @@ import {IAuthUserWithPermissions} from 'loopback4-authorization';
2525

2626
describe('GetTablesNode Unit', function () {
2727
let node: GetTablesNode;
28-
let llmStub: sinon.SinonStub;
28+
let smartllmStub: sinon.SinonStub;
29+
let dumbllmStub: sinon.SinonStub;
2930
let schemaHelper: DbSchemaHelperService;
3031
let schemaStore: SchemaStore;
3132
let tableSearchStub: StubbedInstanceWithSinonAccessor<TableSearchService>;
3233

3334
beforeEach(async () => {
34-
llmStub = sinon.stub();
35-
const llm = llmStub as unknown as LLMProvider;
35+
smartllmStub = sinon.stub();
36+
dumbllmStub = sinon.stub();
37+
const llm = dumbllmStub as unknown as LLMProvider;
3638

3739
schemaHelper = new DbSchemaHelperService(
3840
new SqliteConnector(
@@ -48,9 +50,17 @@ describe('GetTablesNode Unit', function () {
4850
);
4951
schemaStore = new SchemaStore();
5052
tableSearchStub = createStubInstance(TableSearchService);
51-
node = new GetTablesNode(llm, schemaHelper, schemaStore, tableSearchStub, [
52-
'test context',
53-
]);
53+
node = new GetTablesNode(
54+
llm,
55+
dumbllmStub as unknown as LLMProvider,
56+
{
57+
models: [],
58+
},
59+
schemaHelper,
60+
schemaStore,
61+
tableSearchStub,
62+
['test context'],
63+
);
5464
});
5565

5666
it('should return state with minimal schema based on prompt and table search', async () => {
@@ -69,13 +79,99 @@ describe('GetTablesNode Unit', function () {
6979
schema: originalSchema,
7080
} as unknown as DbQueryState;
7181

72-
llmStub.resolves({
82+
dumbllmStub.resolves({
83+
content: 'employees',
84+
});
85+
86+
const result = await node.execute(state, {});
87+
88+
expect(dumbllmStub.getCalls()[0].args[0].value.trim()).equal(
89+
`<instructions>
90+
You are an AI assistant that extracts table names that are relevant to the users query that will be used to generate an SQL query later.
91+
- Consider not just the user query but also the context and the table descriptions while selecting the tables.
92+
- Carefully consider each and every table before including or excluding it.
93+
- If doubtful about a table's relevance, include it anyway to give the SQL generation step more options to choose from.
94+
- Assume that the table would have appropriate columns for relating them to any other table even if the description does not mention it.
95+
- If you are not sure about the tables to select from the given schema, just return your doubt asking the user for more details or to rephrase the question in the following format -
96+
failed attempt: reason for failure
97+
</instructions>
98+
99+
<tables-with-description>
100+
employees: ${Employee.definition.settings.description}
101+
102+
exchange_rates: ${ExchangeRate.definition.settings.description}
103+
</tables-with-description>
104+
105+
<user-question>
106+
Get me the employee with name Akshat
107+
</user-question>
108+
109+
<must-follow-rules>
110+
- test context
111+
- employee salary must be converted to USD, using the currency_id column and the exchange rate table
112+
</must-follow-rules>
113+
114+
115+
116+
<output-format>
117+
The output should be just a comma separated list of table names with no other text, comments or formatting.
118+
Ensure that table names are exact and match the names in the input including schema if given.
119+
<example-output>
120+
public.employees, public.departments
121+
</example-output>
122+
In case of failure, return the failure message in the format -
123+
failed attempt: <reason for failure>
124+
<example-failure>
125+
failed attempt: reason for failure
126+
</example-failure>
127+
</output-format>`,
128+
);
129+
130+
expect(result.schema).to.deepEqual(
131+
schemaStore.filteredSchema(['employees']),
132+
);
133+
});
134+
135+
it('should return state with minimal schema based on prompt and table search with smart llm', async () => {
136+
node = new GetTablesNode(
137+
dumbllmStub as unknown as LLMProvider,
138+
smartllmStub as unknown as LLMProvider,
139+
{
140+
models: [],
141+
nodes: {
142+
// config to use smart llm for this node
143+
getTablesNode: {
144+
useSmartLLM: true,
145+
},
146+
},
147+
},
148+
schemaHelper,
149+
schemaStore,
150+
tableSearchStub,
151+
['test context'],
152+
);
153+
tableSearchStub.stubs.getTables.resolves(['employees', 'exchange_rates']);
154+
const originalSchema = schemaHelper.modelToSchema('', [
155+
Employee,
156+
ExchangeRate,
157+
Currency,
158+
Skill,
159+
EmployeeSkill,
160+
]);
161+
await schemaStore.save(originalSchema);
162+
163+
const state = {
164+
prompt: 'Get me the employee with name Akshat',
165+
schema: originalSchema,
166+
} as unknown as DbQueryState;
167+
168+
smartllmStub.resolves({
73169
content: 'employees',
74170
});
75171

76172
const result = await node.execute(state, {});
77173

78-
expect(llmStub.getCalls()[0].args[0].value.trim()).equal(
174+
expect(smartllmStub.getCalls()[0].args[0].value.trim()).equal(
79175
`<instructions>
80176
You are an AI assistant that extracts table names that are relevant to the users query that will be used to generate an SQL query later.
81177
- Consider not just the user query but also the context and the table descriptions while selecting the tables.
@@ -138,7 +234,7 @@ failed attempt: reason for failure
138234
schema: originalSchema,
139235
} as unknown as DbQueryState;
140236

141-
llmStub.resolves({
237+
dumbllmStub.resolves({
142238
content: 'employees',
143239
});
144240

@@ -163,16 +259,16 @@ failed attempt: reason for failure
163259
schema: originalSchema,
164260
} as unknown as DbQueryState;
165261

166-
llmStub.onFirstCall().resolves({
262+
dumbllmStub.onFirstCall().resolves({
167263
content: 'non_existing_table',
168264
});
169-
llmStub.onSecondCall().resolves({
265+
dumbllmStub.onSecondCall().resolves({
170266
content: 'employees',
171267
});
172268

173269
const result = await node.execute(state, {});
174270

175-
expect(llmStub.callCount).to.equal(2);
271+
expect(dumbllmStub.callCount).to.equal(2);
176272
expect(result.schema).to.deepEqual(
177273
schemaStore.filteredSchema(['employees']),
178274
);

src/components/db-query/nodes/get-tables.node.ts

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,17 @@ import {DbSchemaHelperService} from '../services';
1313
import {SchemaStore} from '../services/schema.store';
1414
import {TableSearchService} from '../services/search/table-search.service';
1515
import {DbQueryState} from '../state';
16-
import {DatabaseSchema, GenerationError} from '../types';
16+
import {DatabaseSchema, DbQueryConfig, GenerationError} from '../types';
1717

1818
@graphNode(DbQueryNodes.GetTables)
1919
export class GetTablesNode implements IGraphNode<DbQueryState> {
2020
constructor(
21+
@inject(AiIntegrationBindings.CheapLLM)
22+
private readonly llmCheap: LLMProvider,
2123
@inject(AiIntegrationBindings.SmartLLM)
22-
private readonly llm: LLMProvider,
24+
private readonly llmSmart: LLMProvider,
25+
@inject(DbQueryAIExtensionBindings.Config)
26+
private readonly config: DbQueryConfig,
2327
@service(DbSchemaHelperService)
2428
private readonly schemaHelper: DbSchemaHelperService,
2529
@service(SchemaStore)
@@ -95,7 +99,10 @@ Use these if they are relevant to the table selection, otherwise ignore them, th
9599
);
96100
}
97101

98-
const chain = RunnableSequence.from([this.prompt, this.llm]);
102+
const useSmartLLM = this.config.nodes?.getTablesNode?.useSmartLLM ?? false;
103+
const llm = useSmartLLM ? this.llmSmart : this.llmCheap;
104+
105+
const chain = RunnableSequence.from([this.prompt, llm]);
99106
config.writer?.({
100107
type: LLMStreamEventType.ToolStatus,
101108
data: {

src/components/db-query/nodes/sql-generation.node.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ In the last attempt, you generated this SQL query -
114114
});
115115

116116
const generateDesc =
117-
this.config.nodes?.sqlGenerationWithDescription !== false;
117+
this.config.nodes?.sqlGenerationNode?.generateDescription !== false;
118118

119119
const output = await chain.invoke({
120120
dialect: this.config.db?.dialect ?? SupportedDBs.PostgreSQL,

src/components/db-query/types.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,12 @@ export type DbQueryConfig = {
113113
maxClusterSize?: number;
114114
};
115115
nodes?: {
116-
sqlGenerationWithDescription?: boolean;
117-
renderNode?: boolean;
116+
sqlGenerationNode?: {
117+
generateDescription?: boolean;
118+
};
119+
getTablesNode?: {
120+
useSmartLLM?: boolean;
121+
};
118122
};
119123
columnSelection?: boolean;
120124
};

0 commit comments

Comments
 (0)