Skip to content

Commit 8187e1d

Browse files
committed
Fetch total costs from OpenRouter after query
1 parent 9cd59b9 commit 8187e1d

File tree

6 files changed

+100
-10
lines changed

6 files changed

+100
-10
lines changed

evaluate/metrics/assessment.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ var (
5959
AssessmentKeyTokenInput = RegisterAssessmentKey("token-input")
6060
// AssessmentKeyTokenOutput collects the number of output token.
6161
AssessmentKeyTokenOutput = RegisterAssessmentKey("token-output")
62+
// AssessmentKeyNativeTokenInput collects the number of native input tokens.
63+
AssessmentKeyNativeTokenInput = RegisterAssessmentKey("native-token-input")
64+
// AssessmentKeyNativeTokenOutput collects the number of native output tokens.
65+
AssessmentKeyNativeTokenOutput = RegisterAssessmentKey("native-token-output")
66+
// AssessmentKeyCostsTokenActual collects the actual total costs of a query.
67+
AssessmentKeyCostsTokenActual = RegisterAssessmentKey("costs-total-actual")
6268
)
6369

6470
// Assessments holds a collection of numerical assessment metrics.

evaluate/metrics/assessment_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func TestAssessmentString(t *testing.T) {
8484

8585
Assessment: NewAssessments(),
8686

87-
ExpectedString: "coverage=0, files-executed=0, files-executed-maximum-reachable=0, generate-tests-for-file-character-count=0, processing-time=0, response-character-count=0, response-no-error=0, response-no-excess=0, response-with-code=0, tests-passing=0, token-input=0, token-output=0",
87+
ExpectedString: "costs-total-actual=0, coverage=0, files-executed=0, files-executed-maximum-reachable=0, generate-tests-for-file-character-count=0, native-token-input=0, native-token-output=0, processing-time=0, response-character-count=0, response-no-error=0, response-no-excess=0, response-with-code=0, tests-passing=0, token-input=0, token-output=0",
8888
})
8989

9090
validate(t, &testCase{
@@ -105,7 +105,7 @@ func TestAssessmentString(t *testing.T) {
105105
AssessmentKeyTokenOutput: 456,
106106
},
107107

108-
ExpectedString: "coverage=1, files-executed=2, files-executed-maximum-reachable=2, generate-tests-for-file-character-count=50, processing-time=200, response-character-count=100, response-no-error=3, response-no-excess=4, response-with-code=5, tests-passing=7, token-input=123, token-output=456",
108+
ExpectedString: "costs-total-actual=0, coverage=1, files-executed=2, files-executed-maximum-reachable=2, generate-tests-for-file-character-count=50, native-token-input=0, native-token-output=0, processing-time=200, response-character-count=100, response-no-error=3, response-no-excess=4, response-with-code=5, tests-passing=7, token-input=123, token-output=456",
109109
})
110110
}
111111

evaluate/report/csv_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func TestNewEvaluationFile(t *testing.T) {
2424
require.NoError(t, err)
2525

2626
expectedEvaluationFileContent := bytesutil.StringTrimIndentations(`
27-
model-id,language,repository,case,task,run,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing,token-input,token-output
27+
model-id,language,repository,case,task,run,costs-total-actual,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,native-token-input,native-token-output,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing,token-input,token-output
2828
`)
2929

3030
assert.Equal(t, expectedEvaluationFileContent, string(actualEvaluationFileContent))
@@ -65,8 +65,8 @@ func TestWriteEvaluationRecord(t *testing.T) {
6565
},
6666

6767
ExpectedCSV: `
68-
model-id,language,repository,case,task,run,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing,token-input,token-output
69-
mocked-model,golang,golang/plain,plain.go,write-tests,1,0,0,0,0,0,0,0,0,0,0,0,0
68+
model-id,language,repository,case,task,run,costs-total-actual,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,native-token-input,native-token-output,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing,token-input,token-output
69+
mocked-model,golang,golang/plain,plain.go,write-tests,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
7070
`,
7171
})
7272
validate(t, &testCase{
@@ -90,9 +90,9 @@ func TestWriteEvaluationRecord(t *testing.T) {
9090
},
9191

9292
ExpectedCSV: `
93-
model-id,language,repository,case,task,run,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing,token-input,token-output
94-
mocked-model,golang,golang/plain,plain.go,write-tests,1,0,1,1,0,0,0,1,0,0,0,0,0
95-
mocked-model,golang,golang/plain,plain.go,write-tests-symflower-fix,1,10,1,1,0,0,0,1,0,0,0,0,0
93+
model-id,language,repository,case,task,run,costs-total-actual,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,native-token-input,native-token-output,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing,token-input,token-output
94+
mocked-model,golang,golang/plain,plain.go,write-tests,1,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0
95+
mocked-model,golang,golang/plain,plain.go,write-tests-symflower-fix,1,0,10,1,1,0,0,0,0,0,1,0,0,0,0,0
9696
`,
9797
})
9898
}

model/llm/llm.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,11 @@ func (m *Model) query(logger *log.Logger, request string) (queryResult *provider
339339
return err
340340
}
341341
duration = time.Since(start)
342-
logger.Info("model responded", "model", m.ID(), "id", id, "duration", duration.Milliseconds(), "response-id", queryResult.ResponseID, "token-input", queryResult.Usage.PromptTokens, "token-output", queryResult.Usage.CompletionTokens, "response", string(bytesutil.PrefixLines([]byte(queryResult.Message), []byte("\t"))))
342+
totalCosts := float64(-1)
343+
if queryResult.GenerationInfo != nil {
344+
totalCosts = queryResult.GenerationInfo.TotalCost
345+
}
346+
logger.Info("model responded", "model", m.ID(), "id", id, "duration", duration.Milliseconds(), "response-id", queryResult.ResponseID, "costs-total", totalCosts, "token-input", queryResult.Usage.PromptTokens, "token-output", queryResult.Usage.CompletionTokens, "response", string(bytesutil.PrefixLines([]byte(queryResult.Message), []byte("\t"))))
343347

344348
return nil
345349
},
@@ -496,6 +500,11 @@ func handleQueryResult(queryResult *provider.QueryResult, filePathAbsolute strin
496500
assessment[metrics.AssessmentKeyGenerateTestsForFileCharacterCount] = float64(len(sourceFileContent))
497501
assessment[metrics.AssessmentKeyTokenInput] = float64(queryResult.Usage.PromptTokens)
498502
assessment[metrics.AssessmentKeyTokenOutput] = float64(queryResult.Usage.CompletionTokens)
503+
if queryResult.GenerationInfo != nil {
504+
assessment[metrics.AssessmentKeyNativeTokenInput] = float64(queryResult.GenerationInfo.NativeTokensPrompt)
505+
assessment[metrics.AssessmentKeyNativeTokenOutput] = float64(queryResult.GenerationInfo.NativeTokensCompletion)
506+
assessment[metrics.AssessmentKeyCostsTokenActual] = queryResult.GenerationInfo.TotalCost
507+
}
499508

500509
if err := os.MkdirAll(filepath.Dir(filePathAbsolute), 0755); err != nil {
501510
return nil, pkgerrors.WithStack(err)

provider/openrouter/openrouter.go

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,17 @@ var _ provider.Query = (*Provider)(nil)
138138

139139
// Query queries the provider with the given model name.
140140
func (p *Provider) Query(ctx context.Context, model model.Model, promptText string) (result *provider.QueryResult, err error) {
141-
return openaiapi.QueryOpenAIAPIModel(ctx, p.client(), model.ModelIDWithoutProvider(), model.Attributes(), promptText)
141+
queryResult, err := openaiapi.QueryOpenAIAPIModel(ctx, p.client(), model.ModelIDWithoutProvider(), model.Attributes(), promptText)
142+
if err != nil {
143+
return nil, pkgerrors.WithStack(err)
144+
}
145+
146+
queryResult.GenerationInfo, err = p.fetchGenerationInfo(queryResult.ResponseID)
147+
if err != nil {
148+
return nil, pkgerrors.WithStack(err)
149+
}
150+
151+
return queryResult, nil
142152
}
143153

144154
// client returns a new client with the current configuration.
@@ -148,3 +158,54 @@ func (p *Provider) client() (client *openai.Client) {
148158

149159
return openai.NewClientWithConfig(config)
150160
}
161+
162+
func (p *Provider) fetchGenerationInfo(generationID string) (generationInfo *provider.GenerationInfo, err error) {
163+
request, err := http.NewRequest("GET", "https://openrouter.ai/api/v1/generation?id="+generationID, nil)
164+
if err != nil {
165+
return nil, pkgerrors.WithStack(err)
166+
}
167+
request.Header.Set("Accept", "application/json")
168+
request.Header.Set("Authorization", "Bearer "+p.token)
169+
170+
client := &http.Client{}
171+
var responseBody []byte
172+
if err := retry.Do( // Query available models with a retry logic cause "openrouter.ai" has failed us in the past.
173+
func() error {
174+
response, err := client.Do(request)
175+
if err != nil {
176+
return pkgerrors.WithStack(err)
177+
}
178+
defer func() {
179+
if e := response.Body.Close(); e != nil {
180+
err = errors.Join(err, pkgerrors.WithStack(e))
181+
}
182+
}()
183+
184+
if response.StatusCode != http.StatusOK {
185+
return pkgerrors.Errorf("received status code %d when querying provider models", response.StatusCode)
186+
}
187+
188+
responseBody, err = io.ReadAll(response.Body)
189+
if err != nil {
190+
return pkgerrors.WithStack(err)
191+
}
192+
193+
return nil
194+
},
195+
retry.Attempts(3),
196+
retry.Delay(5*time.Second),
197+
retry.DelayType(retry.BackOffDelay),
198+
retry.LastErrorOnly(true),
199+
); err != nil {
200+
return nil, err
201+
}
202+
203+
var dataResponse struct {
204+
provider.GenerationInfo `json:"data"`
205+
}
206+
if err := json.Unmarshal(responseBody, &dataResponse); err != nil {
207+
return nil, err
208+
}
209+
210+
return &dataResponse.GenerationInfo, nil
211+
}

provider/provider.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,20 @@ type QueryResult struct {
5454
Duration time.Duration
5555
// Usage holds the usage metrics of the query.
5656
Usage openai.Usage
57+
// GenerationInfo holds information about a generation.
58+
GenerationInfo *GenerationInfo
59+
}
60+
61+
// GenerationInfo holds information about a generation.
62+
// See https://openrouter.ai/docs/api-reference/overview#querying-cost-and-stats for more details.
63+
type GenerationInfo struct {
64+
// ID holds the identifier of the generation.
ID string `json:"id"`
65+
// TotalCost holds the total cost of the generation (presumably in USD — confirm against the OpenRouter documentation).
TotalCost float64 `json:"total_cost"`
66+
// TokensPrompt holds the number of prompt tokens.
TokensPrompt int `json:"tokens_prompt"`
67+
// TokensCompletion holds the number of completion tokens.
TokensCompletion int `json:"tokens_completion"`
68+
// NativeTokensPrompt holds the number of native prompt tokens (presumably counted by the model's own tokenizer — confirm).
NativeTokensPrompt int `json:"native_tokens_prompt"`
69+
// NativeTokensCompletion holds the number of native completion tokens (presumably counted by the model's own tokenizer — confirm).
NativeTokensCompletion int `json:"native_tokens_completion"`
70+
// NativeTokensReasoning holds the number of native reasoning tokens (presumably counted by the model's own tokenizer — confirm).
NativeTokensReasoning int `json:"native_tokens_reasoning"`
5771
}
5872

5973
// Query is a provider that allows to query a model directly.

0 commit comments

Comments
 (0)