Skip to content

Commit d9b0914

Browse files
authored
Merge pull request #424 from symflower/keep-query-usage
Collect usage metrics of each query to be able to calculate costs
2 parents 377f295 + 2489a6d commit d9b0914

File tree

13 files changed

+191
-155
lines changed

13 files changed

+191
-155
lines changed

evaluate/evaluate_test.go

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"testing"
1111
"time"
1212

13+
"github.com/sashabaranov/go-openai"
1314
"github.com/stretchr/testify/assert"
1415
"github.com/stretchr/testify/mock"
1516
"github.com/stretchr/testify/require"
@@ -244,7 +245,7 @@ func TestEvaluate(t *testing.T) {
244245

245246
Before: func(t *testing.T, logger *log.Logger, resultPath string) {
246247
// Set up mocks, when test is running.
247-
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("", ErrEmptyResponseFromModel)
248+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(nil, ErrEmptyResponseFromModel)
248249
},
249250
After: func(t *testing.T, logger *log.Logger, resultPath string) {
250251
mockedQuery.AssertNumberOfCalls(t, "Query", 2)
@@ -324,11 +325,14 @@ func TestEvaluate(t *testing.T) {
324325
Name: "Success after retry",
325326

326327
Before: func(t *testing.T, logger *log.Logger, resultPath string) {
328+
queryResult := &provider.QueryResult{
329+
Message: "model-response",
330+
}
327331
// Set up mocks, when test is running.
328-
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("", ErrEmptyResponseFromModel).Once()
329-
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("model-response", nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
330-
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("", ErrEmptyResponseFromModel).Once()
331-
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("model-response", nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
332+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(nil, ErrEmptyResponseFromModel).Once()
333+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
334+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(nil, ErrEmptyResponseFromModel).Once()
335+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil).Once().After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
332336
},
333337
After: func(t *testing.T, logger *log.Logger, resultPath string) {
334338
mockedQuery.AssertNumberOfCalls(t, "Query", 4)
@@ -423,8 +427,15 @@ func TestEvaluate(t *testing.T) {
423427
Name: "Immediate success",
424428

425429
Before: func(t *testing.T, logger *log.Logger, resultPath string) {
430+
queryResult := &provider.QueryResult{
431+
Message: "model-response",
432+
Usage: openai.Usage{
433+
PromptTokens: 123,
434+
CompletionTokens: 456,
435+
},
436+
}
426437
// Set up mocks, when test is running.
427-
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return("model-response", nil).After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
438+
mockedQuery.On("Query", mock.Anything, mock.Anything, mock.Anything).Return(queryResult, nil).After(10 * time.Millisecond) // Simulate a model response delay because our internal safety measures trigger when a query is done in 0 milliseconds.
428439
},
429440
After: func(t *testing.T, logger *log.Logger, resultPath string) {
430441
mockedQuery.AssertNumberOfCalls(t, "Query", 2)
@@ -457,6 +468,8 @@ func TestEvaluate(t *testing.T) {
457468
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 14,
458469
metrics.AssessmentKeyResponseCharacterCount: 14,
459470
metrics.AssessmentKeyResponseNoError: 1,
471+
metrics.AssessmentKeyTokenInput: 123,
472+
metrics.AssessmentKeyTokenOutput: 456,
460473
},
461474
},
462475
&metricstesting.AssessmentTuple{
@@ -470,6 +483,8 @@ func TestEvaluate(t *testing.T) {
470483
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 14,
471484
metrics.AssessmentKeyResponseCharacterCount: 14,
472485
metrics.AssessmentKeyResponseNoError: 1,
486+
metrics.AssessmentKeyTokenInput: 123,
487+
metrics.AssessmentKeyTokenOutput: 456,
473488
},
474489
},
475490
&metricstesting.AssessmentTuple{
@@ -483,6 +498,8 @@ func TestEvaluate(t *testing.T) {
483498
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 14,
484499
metrics.AssessmentKeyResponseCharacterCount: 14,
485500
metrics.AssessmentKeyResponseNoError: 1,
501+
metrics.AssessmentKeyTokenInput: 123,
502+
metrics.AssessmentKeyTokenOutput: 456,
486503
},
487504
},
488505
&metricstesting.AssessmentTuple{
@@ -496,6 +513,8 @@ func TestEvaluate(t *testing.T) {
496513
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: 14,
497514
metrics.AssessmentKeyResponseCharacterCount: 14,
498515
metrics.AssessmentKeyResponseNoError: 1,
516+
metrics.AssessmentKeyTokenInput: 123,
517+
metrics.AssessmentKeyTokenOutput: 456,
499518
},
500519
},
501520
},

evaluate/metrics/assessment.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ import (
1111
type AssessmentKey string
1212

1313
var (
14-
// allAssessmentKeys holds all registered assessment keys.
15-
allAssessmentKeys []AssessmentKey
14+
// AllAssessmentKeys holds all registered assessment keys.
15+
AllAssessmentKeys []AssessmentKey
1616
// AllAssessmentKeysStrings returns all registered assessment keys as strings.
1717
AllAssessmentKeysStrings []string
1818
)
@@ -22,7 +22,7 @@ func RegisterAssessmentKey(key string) AssessmentKey {
2222
assessment := AssessmentKey(key)
2323
i := sort.SearchStrings(AllAssessmentKeysStrings, key)
2424

25-
allAssessmentKeys = slices.Insert(allAssessmentKeys, i, assessment)
25+
AllAssessmentKeys = slices.Insert(AllAssessmentKeys, i, assessment)
2626
AllAssessmentKeysStrings = slices.Insert(AllAssessmentKeysStrings, i, key)
2727

2828
return assessment
@@ -54,6 +54,11 @@ var (
5454
// AssessmentKeyResponseNoExcess indicates that a model did not produce more content as requested.
5555
// TODO Infer if a model produced "too much" code. https://github.com/symflower/eval-dev-quality/issues/44
5656
AssessmentKeyResponseNoExcess = RegisterAssessmentKey("response-no-excess")
57+
58+
// AssessmentKeyTokenInput collects the number of input token.
59+
AssessmentKeyTokenInput = RegisterAssessmentKey("token-input")
60+
// AssessmentKeyTokenOutput collects the number of output token.
61+
AssessmentKeyTokenOutput = RegisterAssessmentKey("token-output")
5762
)
5863

5964
// Assessments holds a collection of numerical assessment metrics.
@@ -77,7 +82,7 @@ func (a Assessments) Equal(x Assessments) bool {
7782
return a == nil && x == nil
7883
}
7984

80-
for _, key := range allAssessmentKeys {
85+
for _, key := range AllAssessmentKeys {
8186
if a[key] != x[key] {
8287
return false
8388
}
@@ -101,9 +106,9 @@ func (a Assessments) String() string {
101106
if a == nil {
102107
a = NewAssessments()
103108
}
104-
entries := make([]string, len(allAssessmentKeys))
109+
entries := make([]string, len(AllAssessmentKeys))
105110

106-
for i, key := range allAssessmentKeys {
111+
for i, key := range AllAssessmentKeys {
107112
entries[i] = fmt.Sprintf("%s=%d", key, a[key])
108113
}
109114

@@ -116,8 +121,8 @@ func (a Assessments) StringCSV() (row []string) {
116121
a = NewAssessments()
117122
}
118123

119-
row = make([]string, len(allAssessmentKeys))
120-
for i, key := range allAssessmentKeys {
124+
row = make([]string, len(AllAssessmentKeys))
125+
for i, key := range AllAssessmentKeys {
121126
row[i] = fmt.Sprintf("%d", a[key])
122127
}
123128

evaluate/metrics/assessment_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func TestAssessmentString(t *testing.T) {
8484

8585
Assessment: NewAssessments(),
8686

87-
ExpectedString: "coverage=0, files-executed=0, files-executed-maximum-reachable=0, generate-tests-for-file-character-count=0, processing-time=0, response-character-count=0, response-no-error=0, response-no-excess=0, response-with-code=0, tests-passing=0",
87+
ExpectedString: "coverage=0, files-executed=0, files-executed-maximum-reachable=0, generate-tests-for-file-character-count=0, processing-time=0, response-character-count=0, response-no-error=0, response-no-excess=0, response-with-code=0, tests-passing=0, token-input=0, token-output=0",
8888
})
8989

9090
validate(t, &testCase{
@@ -101,9 +101,11 @@ func TestAssessmentString(t *testing.T) {
101101
AssessmentKeyResponseWithCode: 5,
102102
AssessmentKeyProcessingTime: 200,
103103
AssessmentKeyTestsPassing: 7,
104+
AssessmentKeyTokenInput: 123,
105+
AssessmentKeyTokenOutput: 456,
104106
},
105107

106-
ExpectedString: "coverage=1, files-executed=2, files-executed-maximum-reachable=2, generate-tests-for-file-character-count=50, processing-time=200, response-character-count=100, response-no-error=3, response-no-excess=4, response-with-code=5, tests-passing=7",
108+
ExpectedString: "coverage=1, files-executed=2, files-executed-maximum-reachable=2, generate-tests-for-file-character-count=50, processing-time=200, response-character-count=100, response-no-error=3, response-no-excess=4, response-with-code=5, tests-passing=7, token-input=123, token-output=456",
107109
})
108110
}
109111

evaluate/report/csv_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func TestNewEvaluationFile(t *testing.T) {
2424
require.NoError(t, err)
2525

2626
expectedEvaluationFileContent := bytesutil.StringTrimIndentations(`
27-
model-id,language,repository,case,task,run,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing
27+
model-id,language,repository,case,task,run,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing,token-input,token-output
2828
`)
2929

3030
assert.Equal(t, expectedEvaluationFileContent, string(actualEvaluationFileContent))
@@ -65,8 +65,8 @@ func TestWriteEvaluationRecord(t *testing.T) {
6565
},
6666

6767
ExpectedCSV: `
68-
model-id,language,repository,case,task,run,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing
69-
mocked-model,golang,golang/plain,plain.go,write-tests,1,0,0,0,0,0,0,0,0,0,0
68+
model-id,language,repository,case,task,run,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing,token-input,token-output
69+
mocked-model,golang,golang/plain,plain.go,write-tests,1,0,0,0,0,0,0,0,0,0,0,0,0
7070
`,
7171
})
7272
validate(t, &testCase{
@@ -90,9 +90,9 @@ func TestWriteEvaluationRecord(t *testing.T) {
9090
},
9191

9292
ExpectedCSV: `
93-
model-id,language,repository,case,task,run,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing
94-
mocked-model,golang,golang/plain,plain.go,write-tests,1,0,1,1,0,0,0,1,0,0,0
95-
mocked-model,golang,golang/plain,plain.go,write-tests-symflower-fix,1,10,1,1,0,0,0,1,0,0,0
93+
model-id,language,repository,case,task,run,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing,token-input,token-output
94+
mocked-model,golang,golang/plain,plain.go,write-tests,1,0,1,1,0,0,0,1,0,0,0,0,0
95+
mocked-model,golang,golang/plain,plain.go,write-tests-symflower-fix,1,10,1,1,0,0,0,1,0,0,0,0,0
9696
`,
9797
})
9898
}

evaluate/report/testing/csv.go

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
package testing
22

33
import (
4-
"regexp"
54
"strconv"
5+
"strings"
66
"testing"
77

88
"github.com/stretchr/testify/assert"
@@ -18,32 +18,29 @@ func atoiUint64(t *testing.T, s string) uint64 {
1818
return uint64(value)
1919
}
2020

21-
// extractMetricsCSVMatch is a regular expression to extract metrics from CSV rows.
22-
var extractMetricsCSVMatch = regexp.MustCompile(`(\S+),(\S+),(\S+),(\S+),(\S+),\d+,(\d+),(\d+),(\d+),(\d+),(\d+),(\d+),(\d+),(\d+),(\d+),(\d+)`)
23-
2421
// ParseMetrics extracts multiple assessment metrics from the given string.
2522
func ParseMetrics(t *testing.T, data string) (assessments metricstesting.AssessmentTuples) {
26-
matches := extractMetricsCSVMatch.FindAllStringSubmatch(data, -1)
27-
28-
for _, match := range matches {
29-
assessments = append(assessments, &metricstesting.AssessmentTuple{
30-
Model: match[1],
31-
Language: match[2],
32-
RepositoryPath: match[3],
33-
Case: match[4],
34-
Task: task.Identifier(match[5]),
35-
Assessment: metrics.Assessments{
36-
metrics.AssessmentKeyCoverage: atoiUint64(t, match[6]),
37-
metrics.AssessmentKeyFilesExecuted: atoiUint64(t, match[7]),
38-
metrics.AssessmentKeyFilesExecutedMaximumReachable: atoiUint64(t, match[8]),
39-
metrics.AssessmentKeyGenerateTestsForFileCharacterCount: atoiUint64(t, match[9]),
40-
metrics.AssessmentKeyProcessingTime: atoiUint64(t, match[10]),
41-
metrics.AssessmentKeyResponseCharacterCount: atoiUint64(t, match[11]),
42-
metrics.AssessmentKeyResponseNoError: atoiUint64(t, match[12]),
43-
metrics.AssessmentKeyResponseNoExcess: atoiUint64(t, match[13]),
44-
metrics.AssessmentKeyResponseWithCode: atoiUint64(t, match[14]),
45-
},
46-
})
23+
lines := strings.Split(strings.TrimSpace(data), "\n")
24+
if len(lines) < 2 {
25+
return assessments
26+
}
27+
28+
for _, line := range lines[1:] {
29+
cells := strings.Split(line, ",")
30+
31+
tuple := &metricstesting.AssessmentTuple{
32+
Model: cells[0],
33+
Language: cells[1],
34+
RepositoryPath: cells[2],
35+
Case: cells[3],
36+
Task: task.Identifier(cells[4]),
37+
Assessment: metrics.Assessments{},
38+
}
39+
for i, key := range metrics.AllAssessmentKeys {
40+
tuple.Assessment[key] = atoiUint64(t, cells[i+6])
41+
}
42+
43+
assessments = append(assessments, tuple)
4744
}
4845

4946
return assessments

0 commit comments

Comments
 (0)