-
-
Notifications
You must be signed in to change notification settings - Fork 26
feat: add embeddings API support #38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
| package openrouter | ||
|
|
||
| import ( | ||
| "context" | ||
| "encoding/json" | ||
| "fmt" | ||
| "net/http" | ||
| ) | ||
|
|
||
| const embeddingsSuffix = "/embeddings" | ||
|
|
||
| // EmbeddingsEncodingFormat controls how embeddings are returned by the API. | ||
| // See: https://openrouter.ai/docs/api/api-reference/embeddings/create-embeddings | ||
| type EmbeddingsEncodingFormat string | ||
|
|
||
| const ( | ||
| EmbeddingsEncodingFormatFloat EmbeddingsEncodingFormat = "float" | ||
| EmbeddingsEncodingFormatBase64 EmbeddingsEncodingFormat = "base64" | ||
| ) | ||
|
|
||
| // EmbeddingsRequest represents a request to the /embeddings endpoint. | ||
| // | ||
| // The input field is intentionally typed as any to support the flexible input | ||
| // types accepted by the OpenRouter API: | ||
| // - string | ||
| // - []string | ||
| // - []float64 | ||
| // - [][]float64 | ||
| // - structured content blocks | ||
| // | ||
| // For examples, see: https://openrouter.ai/docs/api/api-reference/embeddings/create-embeddings | ||
| type EmbeddingsRequest struct { | ||
| // Model is the model slug to use for embeddings. | ||
| Model string `json:"model"` | ||
| // Input is the content to embed. See the API docs for supported formats. | ||
| Input any `json:"input"` | ||
|
|
||
| // EncodingFormat controls how the embedding is returned: "float" or "base64". | ||
| EncodingFormat EmbeddingsEncodingFormat `json:"encoding_format,omitempty"` | ||
| // Dimensions optionally truncates the embedding to the given number of dimensions. | ||
| Dimensions *int `json:"dimensions,omitempty"` | ||
| // User is an optional identifier for the end-user making the request. | ||
| User string `json:"user,omitempty"` | ||
| // Provider configuration for provider routing. This reuses the same structure | ||
| // as chat/completions provider routing, which is compatible with the embeddings API. | ||
| Provider *ChatProvider `json:"provider,omitempty"` | ||
| // InputType is an optional hint describing the type of input, e.g. "text" or "image". | ||
| InputType string `json:"input_type,omitempty"` | ||
| } | ||
|
|
||
| // EmbeddingValue represents a single embedding, which can be returned either as | ||
| // a vector of floats or as a base64 string depending on encoding_format. | ||
| type EmbeddingValue struct { | ||
| Vector []float64 | ||
| Base64 string | ||
| } | ||
|
|
||
| func (e *EmbeddingValue) UnmarshalJSON(data []byte) error { | ||
| // Try to unmarshal as []float64 first (encoding_format: "float"). | ||
| var vec []float64 | ||
| if err := json.Unmarshal(data, &vec); err == nil { | ||
| e.Vector = vec | ||
| e.Base64 = "" | ||
| return nil | ||
| } | ||
|
|
||
| // Fallback to string (encoding_format: "base64"). | ||
| var s string | ||
| if err := json.Unmarshal(data, &s); err == nil { | ||
| e.Base64 = s | ||
| e.Vector = nil | ||
| return nil | ||
| } | ||
|
|
||
| return fmt.Errorf("embedding: invalid format, expected []float64 or string") | ||
| } | ||
|
|
||
| // EmbeddingData represents a single embedding entry in the response. | ||
| type EmbeddingData struct { | ||
| Object string `json:"object"` | ||
| Embedding EmbeddingValue `json:"embedding"` | ||
| Index int `json:"index"` | ||
| } | ||
|
|
||
| // EmbeddingsUsage represents the token and cost statistics for an embeddings request. | ||
| type EmbeddingsUsage struct { | ||
| PromptTokens int `json:"prompt_tokens"` | ||
| TotalTokens int `json:"total_tokens"` | ||
| Cost float64 `json:"cost"` | ||
| } | ||
|
|
||
| // EmbeddingsResponse represents the response from the /embeddings endpoint. | ||
| type EmbeddingsResponse struct { | ||
| ID string `json:"id"` | ||
| Object string `json:"object"` | ||
| Data []EmbeddingData `json:"data"` | ||
| Model string `json:"model"` | ||
| Usage *EmbeddingsUsage `json:"usage,omitempty"` | ||
| } | ||
|
|
||
| // CreateEmbeddings submits an embedding request to the embeddings router. | ||
| // | ||
| // API reference: https://openrouter.ai/docs/api/api-reference/embeddings/create-embeddings | ||
| func (c *Client) CreateEmbeddings( | ||
| ctx context.Context, | ||
| request EmbeddingsRequest, | ||
| ) (response EmbeddingsResponse, err error) { | ||
| if !isSupportingModel(embeddingsSuffix, request.Model) { | ||
| // Keep behavior consistent with chat/completions: let the server return a | ||
| // proper API error until we implement local model validation. | ||
| } | ||
|
|
||
| req, err := c.newRequest( | ||
| ctx, | ||
| http.MethodPost, | ||
| c.fullURL(embeddingsSuffix), | ||
| withBody(request), | ||
| ) | ||
| if err != nil { | ||
| return | ||
|
||
| } | ||
|
|
||
| err = c.sendRequest(req, &response) | ||
| return | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| package openrouter | ||
|
|
||
| import ( | ||
| "context" | ||
| "io" | ||
| "net/http" | ||
| "strings" | ||
| "testing" | ||
|
|
||
| "github.com/stretchr/testify/require" | ||
| ) | ||
|
|
||
| type fakeHTTPClient struct { | ||
| lastRequest *http.Request | ||
| response *http.Response | ||
| err error | ||
| } | ||
|
|
||
| func (f *fakeHTTPClient) Do(req *http.Request) (*http.Response, error) { | ||
| f.lastRequest = req | ||
| if f.err != nil { | ||
| return nil, f.err | ||
| } | ||
| return f.response, nil | ||
| } | ||
|
|
||
| func TestCreateEmbeddings_Basic(t *testing.T) { | ||
| body := `{ | ||
| "id": "embd_123", | ||
| "object": "list", | ||
| "data": [ | ||
| { | ||
| "object": "embedding", | ||
| "embedding": [0.1, 0.2, 0.3], | ||
| "index": 0 | ||
| } | ||
| ], | ||
| "model": "test-embeddings-model", | ||
| "usage": { | ||
| "prompt_tokens": 5, | ||
| "total_tokens": 5, | ||
| "cost": 0.0001 | ||
| } | ||
| }` | ||
|
|
||
| fakeClient := &fakeHTTPClient{ | ||
| response: &http.Response{ | ||
| StatusCode: http.StatusOK, | ||
| Body: io.NopCloser(strings.NewReader(body)), | ||
| Header: make(http.Header), | ||
| }, | ||
| } | ||
|
|
||
| cfg := DefaultConfig("test-token") | ||
| cfg.BaseURL = "https://example.com/api/v1" | ||
| cfg.HTTPClient = fakeClient | ||
|
|
||
| client := NewClientWithConfig(*cfg) | ||
|
|
||
| req := EmbeddingsRequest{ | ||
| Model: "test-embeddings-model", | ||
| Input: "hello world", | ||
| } | ||
|
|
||
| resp, err := client.CreateEmbeddings(context.Background(), req) | ||
| require.NoError(t, err) | ||
|
|
||
| require.NotNil(t, fakeClient.lastRequest) | ||
| require.Equal(t, http.MethodPost, fakeClient.lastRequest.Method) | ||
| require.True(t, strings.HasSuffix(fakeClient.lastRequest.URL.Path, "/embeddings")) | ||
|
|
||
| require.Equal(t, "embd_123", resp.ID) | ||
| require.Equal(t, "list", resp.Object) | ||
| require.Equal(t, "test-embeddings-model", resp.Model) | ||
| require.NotNil(t, resp.Usage) | ||
| require.Equal(t, 5, resp.Usage.PromptTokens) | ||
| require.Len(t, resp.Data, 1) | ||
| require.Len(t, resp.Data[0].Embedding.Vector, 3) | ||
| } | ||
|
|
||
| func TestEmbeddingValue_UnmarshalJSON_Base64(t *testing.T) { | ||
| var v EmbeddingValue | ||
|
|
||
| err := v.UnmarshalJSON([]byte(`"dGVzdC1lbWJlZGRpbmc="`)) | ||
| require.NoError(t, err) | ||
| require.Nil(t, v.Vector) | ||
| require.Equal(t, "dGVzdC1lbWJlZGRpbmc=", v.Base64) | ||
| } | ||
|
|
||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| package main | ||
|
|
||
| import ( | ||
| "context" | ||
| "encoding/json" | ||
| "fmt" | ||
| "os" | ||
|
|
||
| "github.com/revrost/go-openrouter" | ||
| ) | ||
|
|
||
| func main() { | ||
| ctx := context.Background() | ||
| client := openrouter.NewClient(os.Getenv("OPENROUTER_API_KEY")) | ||
|
|
||
| // Basic text embedding example | ||
| request := openrouter.EmbeddingsRequest{ | ||
| Model: "openai/text-embedding-3-large", | ||
| Input: []string{ | ||
| "Hello world", | ||
| "OpenRouter embeddings example", | ||
| }, | ||
| EncodingFormat: openrouter.EmbeddingsEncodingFormatFloat, | ||
| } | ||
|
|
||
| res, err := client.CreateEmbeddings(ctx, request) | ||
| if err != nil { | ||
| fmt.Println("error", err) | ||
| return | ||
| } | ||
|
|
||
| b, _ := json.MarshalIndent(res, "", "\t") | ||
| fmt.Printf("response :\n %s\n", string(b)) | ||
| } | ||
|
|
||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we potentially missing a return err here? otherwise we could remove this if block if it is not currently implemented
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll remove this for now, the API always return an proper error if anyhow an invalid model is being used :)