Skip to content

Commit 2210755

Browse files
authored
Merge pull request #428 from symflower/handle-empty-responses
fix, Collect assessments if a model responds with an empty message
2 parents eb112a7 + 6c6bd6b commit 2210755

File tree

7 files changed

+288
-36
lines changed

7 files changed

+288
-36
lines changed

evaluate/evaluate_test.go

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -161,15 +161,39 @@ func TestEvaluate(t *testing.T) {
161161

162162
{
163163
languageGolang := &golang.Language{}
164-
mockedModel := modeltesting.NewMockCapabilityWriteTestsNamed(t, "empty-response-model")
164+
mockedModelID := "testing-provider/empty-response-model"
165+
mockedQuery := providertesting.NewMockQuery(t)
166+
mockedModel := llm.NewModel(mockedQuery, mockedModelID)
165167
repositoryPath := filepath.Join("golang", "plain")
166168

167169
validate(t, &testCase{
168-
Name: "Empty model responses are errors",
170+
Name: "Empty model response",
169171

170172
Before: func(t *testing.T, logger *log.Logger, resultPath string) {
173+
queryResult1 := &provider.QueryResult{
174+
Message: "",
175+
GenerationInfo: &provider.GenerationInfo{
176+
TotalCost: 0.111111111,
177+
NativeTokensPrompt: 111,
178+
NativeTokensCompletion: 222,
179+
},
180+
}
181+
// Set up mocks when the test is running.
182+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(queryResult1, nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
183+
184+
queryResult2 := &provider.QueryResult{
185+
Message: "",
186+
GenerationInfo: &provider.GenerationInfo{
187+
TotalCost: 0.222222222,
188+
NativeTokensPrompt: 333,
189+
NativeTokensCompletion: 444,
190+
},
191+
}
171192
// Set up mocks when the test is running.
172-
mockedModel.MockCapabilityWriteTests.On("WriteTests", mock.Anything).Return(nil, ErrEmptyResponseFromModel)
193+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(queryResult2, nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
194+
},
195+
After: func(t *testing.T, logger *log.Logger, resultPath string) {
196+
mockedQuery.AssertNumberOfCalls(t, "Query", 2)
173197
},
174198

175199
Context: &Context{
@@ -180,6 +204,11 @@ func TestEvaluate(t *testing.T) {
180204
Models: []evalmodel.Model{
181205
mockedModel,
182206
},
207+
QueryAttempts: 3,
208+
209+
RepositoryPaths: []string{
210+
repositoryPath,
211+
},
183212
},
184213

185214
ExpectedAssessments: []*metricstesting.AssessmentTuple{
@@ -189,8 +218,12 @@ func TestEvaluate(t *testing.T) {
189218
RepositoryPath: repositoryPath,
190219
Case: "plain.go",
191220
Task: evaluatetask.IdentifierWriteTests,
192-
Assessment: metrics.Assessments{
221+
Assessment: map[metrics.AssessmentKey]float64{
193222
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
223+
metrics.AssessmentKeyResponseNoError: 1,
224+
metrics.AssessmentKeyCostsTokenActual: 0.111111111,
225+
metrics.AssessmentKeyNativeTokenInput: 111,
226+
metrics.AssessmentKeyNativeTokenOutput: 222,
194227
},
195228
},
196229
&metricstesting.AssessmentTuple{
@@ -199,8 +232,12 @@ func TestEvaluate(t *testing.T) {
199232
RepositoryPath: repositoryPath,
200233
Case: "plain.go",
201234
Task: evaluatetask.IdentifierWriteTestsSymflowerFix,
202-
Assessment: metrics.Assessments{
235+
Assessment: map[metrics.AssessmentKey]float64{
203236
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
237+
metrics.AssessmentKeyResponseNoError: 1,
238+
metrics.AssessmentKeyCostsTokenActual: 0.111111111,
239+
metrics.AssessmentKeyNativeTokenInput: 111,
240+
metrics.AssessmentKeyNativeTokenOutput: 222,
204241
},
205242
},
206243
&metricstesting.AssessmentTuple{
@@ -209,8 +246,12 @@ func TestEvaluate(t *testing.T) {
209246
RepositoryPath: repositoryPath,
210247
Case: "plain.go",
211248
Task: evaluatetask.IdentifierWriteTestsSymflowerTemplate,
212-
Assessment: metrics.Assessments{
249+
Assessment: map[metrics.AssessmentKey]float64{
213250
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
251+
metrics.AssessmentKeyResponseNoError: 1,
252+
metrics.AssessmentKeyCostsTokenActual: 0.222222222,
253+
metrics.AssessmentKeyNativeTokenInput: 333,
254+
metrics.AssessmentKeyNativeTokenOutput: 444,
214255
},
215256
},
216257
&metricstesting.AssessmentTuple{
@@ -219,15 +260,23 @@ func TestEvaluate(t *testing.T) {
219260
RepositoryPath: repositoryPath,
220261
Case: "plain.go",
221262
Task: evaluatetask.IdentifierWriteTestsSymflowerTemplateSymflowerFix,
222-
Assessment: metrics.Assessments{
263+
Assessment: map[metrics.AssessmentKey]float64{
223264
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
265+
metrics.AssessmentKeyResponseNoError: 1,
266+
metrics.AssessmentKeyCostsTokenActual: 0.222222222,
267+
metrics.AssessmentKeyNativeTokenInput: 333,
268+
metrics.AssessmentKeyNativeTokenOutput: 444,
224269
},
225270
},
226271
},
227272
ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){
228273
"evaluation.log": nil,
229-
filepath.Join(string(evaluatetask.IdentifierWriteTests), mockedModel.ID(), "golang", "golang", "plain", "evaluation.log"): nil,
230-
"evaluation.csv": nil,
274+
filepath.Join(string(evaluatetask.IdentifierWriteTests), log.CleanModelNameForFileSystem(mockedModelID), "golang", "golang", "plain", "evaluation.log"): func(t *testing.T, filePath, data string) {
275+
assert.Equal(t, 4, strings.Count(data, "no test files found"), "number of ocurrences of \"no test files found\" not matched")
276+
},
277+
"evaluation.csv": func(t *testing.T, filePath, data string) {
278+
assert.Lenf(t, strings.Split(data, "\n"), 6, "expected 6 lines: header, 4x entries and final new line:\n%s", data)
279+
},
231280
},
232281
})
233282
}

evaluate/task/write-test.go

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -120,28 +120,17 @@ func (t *WriteTests) Run(ctx evaltask.Context) (repositoryAssessment map[string]
120120
ctx.Logger.Panicf("ERROR: unable to reset temporary repository path: %s", err)
121121
}
122122

123-
_, err = symflowerTemplate(taskLogger.Logger, dataPath, ctx.Language, filePath) // TODO Incorporate template processing time. https://github.com/symflower/eval-dev-quality/issues/350
123+
testTemplate, err := symflowerTemplateAsString(ctx, taskLogger, dataPath, filePath)
124124
if err != nil {
125-
problems = append(problems, pkgerrors.WithMessage(err, "generating Symflower template"))
125+
problems = append(problems, err)
126126

127127
withSymflowerTemplateAssessment.Add(modelAssessmentFile)
128128
withSymflowerTemplateAndFixAssessment.Add(withSymflowerFixAssessmentFile)
129129

130130
continue
131131
}
132132

133-
testTemplateFilePath := filepath.Join(dataPath, ctx.Language.TestFilePath(dataPath, filePath))
134-
testTemplate, err := os.ReadFile(testTemplateFilePath)
135-
if err != nil {
136-
problems = append(problems, pkgerrors.WithMessagef(err, "reading Symflower template from %q", testTemplateFilePath))
137-
138-
withSymflowerTemplateAssessment.Add(modelAssessmentFile)
139-
withSymflowerTemplateAndFixAssessment.Add(withSymflowerFixAssessmentFile)
140-
141-
continue
142-
}
143-
144-
arguments.Template = string(testTemplate)
133+
arguments.Template = testTemplate
145134
modelTemplateAssessmentFile, templateWithSymflowerFixAssessmentFile, ps, err := runModelAndSymflowerFix(ctx, modelContext, modelCapability.WriteTests)
146135
problems = append(problems, ps...)
147136
if err != nil {
@@ -155,6 +144,25 @@ func (t *WriteTests) Run(ctx evaltask.Context) (repositoryAssessment map[string]
155144
return repositoryAssessment, problems, nil
156145
}
157146

147+
// symflowerTemplateAsString generates a Symflower test template for the given file, returns the template's content, and then resets the repository so that it is in the same state as before.
148+
func symflowerTemplateAsString(ctx evaltask.Context, taskLogger *taskLogger, dataPath string, filePath string) (testTemplate string, err error) {
149+
_, err = symflowerTemplate(taskLogger.Logger, dataPath, ctx.Language, filePath) // TODO Incorporate template processing time. https://github.com/symflower/eval-dev-quality/issues/350
150+
if err != nil {
151+
return "", pkgerrors.WithMessage(err, "generating Symflower template")
152+
}
153+
testTemplateFilePath := filepath.Join(dataPath, ctx.Language.TestFilePath(dataPath, filePath))
154+
testTemplateData, err := os.ReadFile(testTemplateFilePath)
155+
if err != nil {
156+
return "", pkgerrors.WithMessagef(err, "reading Symflower template from %q", testTemplateFilePath)
157+
}
158+
159+
if err := ctx.Repository.Reset(ctx.Logger); err != nil {
160+
ctx.Logger.Panicf("ERROR: unable to reset temporary repository path: %s", err)
161+
}
162+
163+
return string(testTemplateData), nil
164+
}
165+
158166
// validateWriteTestsRepository checks if the repository for the "write-tests" task is well-formed.
159167
func validateWriteTestsRepository(logger *log.Logger, repositoryPath string, language language.Language) (err error) {
160168
logger.Info("validating repository", "path", repositoryPath)

evaluate/task/write-test_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"testing"
99

1010
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/mock"
1112
"github.com/stretchr/testify/require"
1213
"github.com/symflower/eval-dev-quality/evaluate/metrics"
1314
metricstesting "github.com/symflower/eval-dev-quality/evaluate/metrics/testing"
@@ -123,6 +124,52 @@ func TestWriteTestsRun(t *testing.T) {
123124
})
124125
})
125126

127+
{
128+
temporaryDirectoryPath := t.TempDir()
129+
repositoryPath := filepath.Join(temporaryDirectoryPath, "golang", "plain")
130+
require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "golang", "plain"), repositoryPath))
131+
132+
modelMock := modeltesting.NewMockCapabilityWriteTestsNamed(t, "mocked-model")
133+
// Simulate that a model does not generate anything.
134+
modelMock.MockCapabilityWriteTests.On("WriteTests", mock.Anything).Return(metricstesting.AssessmentsWithProcessingTime, nil)
135+
136+
validate(t, &tasktesting.TestCaseTask{
137+
Name: "Reset symflower template so it's not mistaken for model solution",
138+
139+
Model: modelMock,
140+
Language: &golang.Language{},
141+
TestDataPath: temporaryDirectoryPath,
142+
RepositoryPath: filepath.Join("golang", "plain"),
143+
144+
ExpectedRepositoryAssessment: map[string]map[evaltask.Identifier]metrics.Assessments{
145+
"plain.go": map[evaltask.Identifier]metrics.Assessments{
146+
IdentifierWriteTests: metrics.Assessments{
147+
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
148+
metrics.AssessmentKeyResponseNoError: 1,
149+
},
150+
IdentifierWriteTestsSymflowerFix: metrics.Assessments{
151+
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
152+
metrics.AssessmentKeyResponseNoError: 1,
153+
},
154+
IdentifierWriteTestsSymflowerTemplate: metrics.Assessments{
155+
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
156+
metrics.AssessmentKeyResponseNoError: 1,
157+
},
158+
IdentifierWriteTestsSymflowerTemplateSymflowerFix: metrics.Assessments{
159+
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
160+
metrics.AssessmentKeyResponseNoError: 1,
161+
},
162+
},
163+
},
164+
ExpectedProblemContains: []string{
165+
"ERROR: no test files found",
166+
"ERROR: no test files found",
167+
"ERROR: no test files found",
168+
"ERROR: no test files found",
169+
},
170+
})
171+
}
172+
126173
t.Run("Symflower Fix", func(t *testing.T) {
127174
t.Run("Go", func(t *testing.T) {
128175
validateGo := func(t *testing.T, testName string, language language.Language, testFileContent string, expectedAssessments map[string]map[evaltask.Identifier]metrics.Assessments, expectedProblems []string, assertTestsPass bool) {

model/llm/llm.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,10 +329,10 @@ func (m *Model) WriteTests(ctx model.Context) (assessment metrics.Assessments, e
329329

330330
func (m *Model) query(logger *log.Logger, request string) (queryResult *provider.QueryResult, err error) {
331331
var duration time.Duration
332+
id := uuid.NewString()
332333
if err := retry.Do(
333334
func() error {
334-
id := uuid.NewString
335-
logger.Info("querying model", "model", m.ID(), "id", id, "prompt", string(bytesutil.PrefixLines([]byte(request), []byte("\t"))))
335+
logger.Info("querying model", "model", m.ID(), "query-id", id, "prompt", string(bytesutil.PrefixLines([]byte(request), []byte("\t"))))
336336
start := time.Now()
337337
queryResult, err = m.provider.Query(context.Background(), m, request)
338338
if err != nil {
@@ -343,7 +343,7 @@ func (m *Model) query(logger *log.Logger, request string) (queryResult *provider
343343
if queryResult.GenerationInfo != nil {
344344
totalCosts = queryResult.GenerationInfo.TotalCost
345345
}
346-
logger.Info("model responded", "model", m.ID(), "id", id, "duration", duration.Milliseconds(), "response-id", queryResult.ResponseID, "costs-total", totalCosts, "token-input", queryResult.Usage.PromptTokens, "token-output", queryResult.Usage.CompletionTokens, "response", string(bytesutil.PrefixLines([]byte(queryResult.Message), []byte("\t"))))
346+
logger.Info("model responded", "model", m.ID(), "query-id", id, "duration", duration.Milliseconds(), "response-id", queryResult.ResponseID, "costs-total", totalCosts, "token-input", queryResult.Usage.PromptTokens, "token-output", queryResult.Usage.CompletionTokens, "response", string(bytesutil.PrefixLines([]byte(queryResult.Message), []byte("\t"))))
347347

348348
return nil
349349
},
@@ -506,6 +506,10 @@ func handleQueryResult(queryResult *provider.QueryResult, filePathAbsolute strin
506506
assessment[metrics.AssessmentKeyCostsTokenActual] = queryResult.GenerationInfo.TotalCost
507507
}
508508

509+
if sourceFileContent == "" {
510+
return assessment, nil
511+
}
512+
509513
if err := os.MkdirAll(filepath.Dir(filePathAbsolute), 0755); err != nil {
510514
return nil, pkgerrors.WithStack(err)
511515
}

0 commit comments

Comments
 (0)