@@ -25,14 +25,16 @@ import {IAuthUserWithPermissions} from 'loopback4-authorization';
2525
2626describe ( 'GetTablesNode Unit' , function ( ) {
2727 let node : GetTablesNode ;
28- let llmStub : sinon . SinonStub ;
28+ let smartllmStub : sinon . SinonStub ;
29+ let dumbllmStub : sinon . SinonStub ;
2930 let schemaHelper : DbSchemaHelperService ;
3031 let schemaStore : SchemaStore ;
3132 let tableSearchStub : StubbedInstanceWithSinonAccessor < TableSearchService > ;
3233
3334 beforeEach ( async ( ) => {
34- llmStub = sinon . stub ( ) ;
35- const llm = llmStub as unknown as LLMProvider ;
35+ smartllmStub = sinon . stub ( ) ;
36+ dumbllmStub = sinon . stub ( ) ;
37+ const llm = dumbllmStub as unknown as LLMProvider ;
3638
3739 schemaHelper = new DbSchemaHelperService (
3840 new SqliteConnector (
@@ -48,9 +50,17 @@ describe('GetTablesNode Unit', function () {
4850 ) ;
4951 schemaStore = new SchemaStore ( ) ;
5052 tableSearchStub = createStubInstance ( TableSearchService ) ;
51- node = new GetTablesNode ( llm , schemaHelper , schemaStore , tableSearchStub , [
52- 'test context' ,
53- ] ) ;
53+ node = new GetTablesNode (
54+ llm ,
55+ dumbllmStub as unknown as LLMProvider ,
56+ {
57+ models : [ ] ,
58+ } ,
59+ schemaHelper ,
60+ schemaStore ,
61+ tableSearchStub ,
62+ [ 'test context' ] ,
63+ ) ;
5464 } ) ;
5565
5666 it ( 'should return state with minimal schema based on prompt and table search' , async ( ) => {
@@ -69,13 +79,99 @@ describe('GetTablesNode Unit', function () {
6979 schema : originalSchema ,
7080 } as unknown as DbQueryState ;
7181
72- llmStub . resolves ( {
82+ dumbllmStub . resolves ( {
83+ content : 'employees' ,
84+ } ) ;
85+
86+ const result = await node . execute ( state , { } ) ;
87+
88+ expect ( dumbllmStub . getCalls ( ) [ 0 ] . args [ 0 ] . value . trim ( ) ) . equal (
89+ `<instructions>
90+ You are an AI assistant that extracts table names that are relevant to the users query that will be used to generate an SQL query later.
91+ - Consider not just the user query but also the context and the table descriptions while selecting the tables.
92+ - Carefully consider each and every table before including or excluding it.
93+ - If doubtful about a table's relevance, include it anyway to give the SQL generation step more options to choose from.
94+ - Assume that the table would have appropriate columns for relating them to any other table even if the description does not mention it.
95+ - If you are not sure about the tables to select from the given schema, just return your doubt asking the user for more details or to rephrase the question in the following format -
96+ failed attempt: reason for failure
97+ </instructions>
98+
99+ <tables-with-description>
100+ employees: ${ Employee . definition . settings . description }
101+
102+ exchange_rates: ${ ExchangeRate . definition . settings . description }
103+ </tables-with-description>
104+
105+ <user-question>
106+ Get me the employee with name Akshat
107+ </user-question>
108+
109+ <must-follow-rules>
110+ - test context
111+ - employee salary must be converted to USD, using the currency_id column and the exchange rate table
112+ </must-follow-rules>
113+
114+
115+
116+ <output-format>
117+ The output should be just a comma separated list of table names with no other text, comments or formatting.
118+ Ensure that table names are exact and match the names in the input including schema if given.
119+ <example-output>
120+ public.employees, public.departments
121+ </example-output>
122+ In case of failure, return the failure message in the format -
123+ failed attempt: <reason for failure>
124+ <example-failure>
125+ failed attempt: reason for failure
126+ </example-failure>
127+ </output-format>` ,
128+ ) ;
129+
130+ expect ( result . schema ) . to . deepEqual (
131+ schemaStore . filteredSchema ( [ 'employees' ] ) ,
132+ ) ;
133+ } ) ;
134+
135+ it ( 'should return state with minimal schema based on prompt and table search with smart llm' , async ( ) => {
136+ node = new GetTablesNode (
137+ dumbllmStub as unknown as LLMProvider ,
138+ smartllmStub as unknown as LLMProvider ,
139+ {
140+ models : [ ] ,
141+ nodes : {
142+ // config to use smart llm for this node
143+ getTablesNode : {
144+ useSmartLLM : true ,
145+ } ,
146+ } ,
147+ } ,
148+ schemaHelper ,
149+ schemaStore ,
150+ tableSearchStub ,
151+ [ 'test context' ] ,
152+ ) ;
153+ tableSearchStub . stubs . getTables . resolves ( [ 'employees' , 'exchange_rates' ] ) ;
154+ const originalSchema = schemaHelper . modelToSchema ( '' , [
155+ Employee ,
156+ ExchangeRate ,
157+ Currency ,
158+ Skill ,
159+ EmployeeSkill ,
160+ ] ) ;
161+ await schemaStore . save ( originalSchema ) ;
162+
163+ const state = {
164+ prompt : 'Get me the employee with name Akshat' ,
165+ schema : originalSchema ,
166+ } as unknown as DbQueryState ;
167+
168+ smartllmStub . resolves ( {
73169 content : 'employees' ,
74170 } ) ;
75171
76172 const result = await node . execute ( state , { } ) ;
77173
78- expect ( llmStub . getCalls ( ) [ 0 ] . args [ 0 ] . value . trim ( ) ) . equal (
174+ expect ( smartllmStub . getCalls ( ) [ 0 ] . args [ 0 ] . value . trim ( ) ) . equal (
79175 `<instructions>
80176You are an AI assistant that extracts table names that are relevant to the users query that will be used to generate an SQL query later.
81177- Consider not just the user query but also the context and the table descriptions while selecting the tables.
@@ -138,7 +234,7 @@ failed attempt: reason for failure
138234 schema : originalSchema ,
139235 } as unknown as DbQueryState ;
140236
141- llmStub . resolves ( {
237+ dumbllmStub . resolves ( {
142238 content : 'employees' ,
143239 } ) ;
144240
@@ -163,16 +259,16 @@ failed attempt: reason for failure
163259 schema : originalSchema ,
164260 } as unknown as DbQueryState ;
165261
166- llmStub . onFirstCall ( ) . resolves ( {
262+ dumbllmStub . onFirstCall ( ) . resolves ( {
167263 content : 'non_existing_table' ,
168264 } ) ;
169- llmStub . onSecondCall ( ) . resolves ( {
265+ dumbllmStub . onSecondCall ( ) . resolves ( {
170266 content : 'employees' ,
171267 } ) ;
172268
173269 const result = await node . execute ( state , { } ) ;
174270
175- expect ( llmStub . callCount ) . to . equal ( 2 ) ;
271+ expect ( dumbllmStub . callCount ) . to . equal ( 2 ) ;
176272 expect ( result . schema ) . to . deepEqual (
177273 schemaStore . filteredSchema ( [ 'employees' ] ) ,
178274 ) ;
0 commit comments