Skip to content

Commit 4dc1eda

Browse files
authored
add embeddings tests (#237)
1 parent 89219e3 commit 4dc1eda

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

embeddings_test.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@ package openai_test
22

33
import (
44
. "github.com/sashabaranov/go-openai"
5+
"github.com/sashabaranov/go-openai/internal/test"
56
"github.com/sashabaranov/go-openai/internal/test/checks"
67

78
"bytes"
9+
"context"
810
"encoding/json"
11+
"fmt"
12+
"net/http"
913
"testing"
1014
)
1115

@@ -45,3 +49,43 @@ func TestEmbedding(t *testing.T) {
4549
}
4650
}
4751
}
52+
53+
func TestEmbeddingModel(t *testing.T) {
54+
var em EmbeddingModel
55+
err := em.UnmarshalText([]byte("text-similarity-ada-001"))
56+
checks.NoError(t, err, "Could not marshal embedding model")
57+
58+
if em != AdaSimilarity {
59+
t.Errorf("Model is not equal to AdaSimilarity")
60+
}
61+
62+
err = em.UnmarshalText([]byte("some-non-existent-model"))
63+
checks.NoError(t, err, "Could not marshal embedding model")
64+
if em != Unknown {
65+
t.Errorf("Model is not equal to Unknown")
66+
}
67+
}
68+
69+
func TestEmbeddingEndpoint(t *testing.T) {
70+
server := test.NewTestServer()
71+
server.RegisterHandler(
72+
"/v1/embeddings",
73+
func(w http.ResponseWriter, r *http.Request) {
74+
resBytes, _ := json.Marshal(EmbeddingResponse{})
75+
fmt.Fprintln(w, string(resBytes))
76+
},
77+
)
78+
// create the test server
79+
var err error
80+
ts := server.OpenAITestServer()
81+
ts.Start()
82+
defer ts.Close()
83+
84+
config := DefaultConfig(test.GetTestToken())
85+
config.BaseURL = ts.URL + "/v1"
86+
client := NewClientWithConfig(config)
87+
ctx := context.Background()
88+
89+
_, err = client.CreateEmbeddings(ctx, EmbeddingRequest{})
90+
checks.NoError(t, err, "CreateEmbeddings error")
91+
}

0 commit comments

Comments
 (0)