@@ -3,6 +3,8 @@ package llm
33import (
44 "os"
55 "path/filepath"
6+ "regexp"
7+ "slices"
68 "strings"
79 "testing"
810 "time"
@@ -164,6 +166,119 @@ func TestModelGenerateTestsForFile(t *testing.T) {
164166 })
165167}
166168
169+ func TestModelQuery (t * testing.T ) {
170+ type testCase struct {
171+ Name string
172+
173+ SetupMock func (mockedProvider * providertesting.MockQuery )
174+
175+ QueryAttempts uint
176+ Request string
177+
178+ ExpectedResponse * provider.QueryResult
179+ ExpectedError string
180+
181+ ValidateLogs func (t * testing.T , logs string )
182+ }
183+
184+ validate := func (t * testing.T , tc * testCase ) {
185+ t .Run (tc .Name , func (t * testing.T ) {
186+ logOutput , logger := log .Buffer ()
187+ defer func () {
188+ if t .Failed () {
189+ t .Log (logOutput .String ())
190+ }
191+ }()
192+
193+ mock := providertesting .NewMockQuery (t )
194+ if tc .SetupMock != nil {
195+ tc .SetupMock (mock )
196+ }
197+ llm := NewModel (mock , "some-model" )
198+ llm .SetQueryAttempts (tc .QueryAttempts )
199+
200+ queryResult , actualError := llm .query (logger , tc .Request )
201+
202+ if tc .ExpectedError != "" {
203+ assert .ErrorContains (t , actualError , tc .ExpectedError )
204+ assert .Nil (t , queryResult )
205+ } else {
206+ assert .NoError (t , actualError )
207+
208+ queryResult .Duration = 0
209+
210+ assert .Equal (t , tc .ExpectedResponse , queryResult )
211+ }
212+
213+ if tc .ValidateLogs != nil {
214+ tc .ValidateLogs (t , logOutput .String ())
215+ }
216+ })
217+ }
218+
219+ reLogID := regexp .MustCompile (`query-id=([a-z0-9-]*)` )
220+ parseLogIDs := func (logs string ) (ids []string ) {
221+ for _ , match := range reLogID .FindAllStringSubmatch (logs , - 1 ) {
222+ ids = append (ids , match [1 ])
223+ }
224+
225+ return ids
226+ }
227+ assertAllIDsMatch := func (t * testing.T , logs string ) {
228+ ids := parseLogIDs (logs )
229+ assert .Len (t ,
230+ slices .CompactFunc (ids , func (e1 string , e2 string ) bool {
231+ return e1 == e2
232+ }),
233+ 1 ,
234+ )
235+ }
236+
237+ validate (t , & testCase {
238+ Name : "Successful" ,
239+ SetupMock : func (mockedProvider * providertesting.MockQuery ) {
240+ queryResult := & provider.QueryResult {
241+ Message : "test response" ,
242+ }
243+ mockedProvider .On ("Query" , mock .Anything , mock .Anything , "test request" ).Return (queryResult , nil )
244+ },
245+ QueryAttempts : 1 ,
246+ Request : "test request" ,
247+ ExpectedResponse : & provider.QueryResult {
248+ Message : "test response" ,
249+ },
250+
251+ ValidateLogs : assertAllIDsMatch ,
252+ })
253+
254+ validate (t , & testCase {
255+ Name : "Failed query no retry" ,
256+ SetupMock : func (mockedProvider * providertesting.MockQuery ) {
257+ mockedProvider .On ("Query" , mock .Anything , mock .Anything , "test request" ).Return (nil , assert .AnError )
258+ },
259+ QueryAttempts : 1 ,
260+ Request : "test request" ,
261+ ExpectedError : assert .AnError .Error (),
262+ })
263+
264+ validate (t , & testCase {
265+ Name : "Failed query with retry" ,
266+ SetupMock : func (mockedProvider * providertesting.MockQuery ) {
267+ mockedProvider .On ("Query" , mock .Anything , mock .Anything , "test request" ).Return (nil , assert .AnError ).Once ()
268+ mockedProvider .On ("Query" , mock .Anything , mock .Anything , "test request" ).Return (& provider.QueryResult {
269+ Message : "test response" ,
270+ }, nil ).Once ()
271+ },
272+ QueryAttempts : 1 + 1 ,
273+ Request : "test request" ,
274+ ExpectedResponse : & provider.QueryResult {
275+ Message : "test response" ,
276+ },
277+
278+ ValidateLogs : assertAllIDsMatch ,
279+ })
280+ }
281+
167282func TestModelRepairSourceCodeFile (t * testing.T ) {
168283 type testCase struct {
169284 Name string
0 commit comments