@@ -2,10 +2,14 @@ package openai_test
2
2
3
3
import (
4
4
. "github.com/sashabaranov/go-openai"
5
+ "github.com/sashabaranov/go-openai/internal/test"
5
6
"github.com/sashabaranov/go-openai/internal/test/checks"
6
7
7
8
"bytes"
9
+ "context"
8
10
"encoding/json"
11
+ "fmt"
12
+ "net/http"
9
13
"testing"
10
14
)
11
15
@@ -45,3 +49,43 @@ func TestEmbedding(t *testing.T) {
45
49
}
46
50
}
47
51
}
52
+
53
+ func TestEmbeddingModel (t * testing.T ) {
54
+ var em EmbeddingModel
55
+ err := em .UnmarshalText ([]byte ("text-similarity-ada-001" ))
56
+ checks .NoError (t , err , "Could not marshal embedding model" )
57
+
58
+ if em != AdaSimilarity {
59
+ t .Errorf ("Model is not equal to AdaSimilarity" )
60
+ }
61
+
62
+ err = em .UnmarshalText ([]byte ("some-non-existent-model" ))
63
+ checks .NoError (t , err , "Could not marshal embedding model" )
64
+ if em != Unknown {
65
+ t .Errorf ("Model is not equal to Unknown" )
66
+ }
67
+ }
68
+
69
+ func TestEmbeddingEndpoint (t * testing.T ) {
70
+ server := test .NewTestServer ()
71
+ server .RegisterHandler (
72
+ "/v1/embeddings" ,
73
+ func (w http.ResponseWriter , r * http.Request ) {
74
+ resBytes , _ := json .Marshal (EmbeddingResponse {})
75
+ fmt .Fprintln (w , string (resBytes ))
76
+ },
77
+ )
78
+ // create the test server
79
+ var err error
80
+ ts := server .OpenAITestServer ()
81
+ ts .Start ()
82
+ defer ts .Close ()
83
+
84
+ config := DefaultConfig (test .GetTestToken ())
85
+ config .BaseURL = ts .URL + "/v1"
86
+ client := NewClientWithConfig (config )
87
+ ctx := context .Background ()
88
+
89
+ _ , err = client .CreateEmbeddings (ctx , EmbeddingRequest {})
90
+ checks .NoError (t , err , "CreateEmbeddings error" )
91
+ }
0 commit comments