Skip to content

Commit 53212c7

Browse files
authored
Migrate From Old Completions + Embedding Endpoint (#28)
* migrate away from deprecated OpenAI endpoints Signed-off-by: Oleg <[email protected]> * test embedding correctness Signed-off-by: Oleg <[email protected]>
1 parent 51f94a6 commit 53212c7

File tree

5 files changed

+69
-37
lines changed

5 files changed

+69
-37
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@ func main() {
2727
ctx := context.Background()
2828

2929
req := gogpt.CompletionRequest{
30+
Model: "ada",
3031
MaxTokens: 5,
3132
Prompt: "Lorem ipsum",
3233
}
33-
resp, err := c.CreateCompletion(ctx, "ada", req)
34+
resp, err := c.CreateCompletion(ctx, req)
3435
if err != nil {
3536
return
3637
}

api_test.go

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package gogpt
22

33
import (
4+
"bytes"
45
"context"
6+
"encoding/json"
57
"io/ioutil"
68
"testing"
79
)
@@ -36,9 +38,12 @@ func TestAPI(t *testing.T) {
3638
}
3739
} // else skip
3840

39-
req := CompletionRequest{MaxTokens: 5}
41+
req := CompletionRequest{
42+
MaxTokens: 5,
43+
Model: "ada",
44+
}
4045
req.Prompt = "Lorem ipsum"
41-
_, err = c.CreateCompletion(ctx, "ada", req)
46+
_, err = c.CreateCompletion(ctx, req)
4247
if err != nil {
4348
t.Fatalf("CreateCompletion error: %v", err)
4449
}
@@ -57,9 +62,49 @@ func TestAPI(t *testing.T) {
5762
"The food was delicious and the waiter",
5863
"Other examples of embedding request",
5964
},
65+
Model: AdaSearchQuery,
6066
}
61-
_, err = c.CreateEmbeddings(ctx, embeddingReq, AdaSearchQuery)
67+
_, err = c.CreateEmbeddings(ctx, embeddingReq)
6268
if err != nil {
6369
t.Fatalf("Embedding error: %v", err)
6470
}
6571
}
72+
73+
func TestEmbedding(t *testing.T) {
74+
embeddedModels := []EmbeddingModel{
75+
AdaSimilarity,
76+
BabbageSimilarity,
77+
CurieSimilarity,
78+
DavinciSimilarity,
79+
AdaSearchDocument,
80+
AdaSearchQuery,
81+
BabbageSearchDocument,
82+
BabbageSearchQuery,
83+
CurieSearchDocument,
84+
CurieSearchQuery,
85+
DavinciSearchDocument,
86+
DavinciSearchQuery,
87+
AdaCodeSearchCode,
88+
AdaCodeSearchText,
89+
BabbageCodeSearchCode,
90+
BabbageCodeSearchText,
91+
}
92+
for _, model := range embeddedModels {
93+
embeddingReq := EmbeddingRequest{
94+
Input: []string{
95+
"The food was delicious and the waiter",
96+
"Other examples of embedding request",
97+
},
98+
Model: model,
99+
}
100+
// marshal embeddingReq to JSON and confirm that the model field equals
101+
// the AdaSearchQuery type
102+
marshaled, err := json.Marshal(embeddingReq)
103+
if err != nil {
104+
t.Fatalf("Could not marshal embedding request: %v", err)
105+
}
106+
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
107+
t.Fatalf("Expected embedding request to contain model field")
108+
}
109+
}
110+
}

completion.go

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@ import (
44
"bytes"
55
"context"
66
"encoding/json"
7-
"fmt"
87
"net/http"
98
)
109

1110
// CompletionRequest represents a request structure for completion API
1211
type CompletionRequest struct {
13-
Model *string `json:"model,omitempty"`
12+
Model string `json:"model"`
1413
Prompt string `json:"prompt,omitempty"`
1514
MaxTokens int `json:"max_tokens,omitempty"`
1615
Temperature float32 `json:"temperature,omitempty"`
@@ -60,29 +59,12 @@ type CompletionResponse struct {
6059
Usage CompletionUsage `json:"usage"`
6160
}
6261

63-
// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well as, if requested, the probabilities over each alternative token at each position.
64-
func (c *Client) CreateCompletion(ctx context.Context, engineID string, request CompletionRequest) (response CompletionResponse, err error) {
65-
var reqBytes []byte
66-
reqBytes, err = json.Marshal(request)
67-
if err != nil {
68-
return
69-
}
70-
71-
urlSuffix := fmt.Sprintf("/engines/%s/completions", engineID)
72-
req, err := http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
73-
if err != nil {
74-
return
75-
}
76-
77-
req = req.WithContext(ctx)
78-
err = c.sendRequest(req, &response)
79-
return
80-
}
81-
82-
// CreateCompletionWithFineTunedModel - API call to create a completion with a fine tuned model
83-
// See https://beta.openai.com/docs/guides/fine-tuning/use-a-fine-tuned-model
84-
// In this case, the model is specified in the CompletionRequest object.
85-
func (c *Client) CreateCompletionWithFineTunedModel(ctx context.Context, request CompletionRequest) (response CompletionResponse, err error) {
62+
// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well
63+
// as, if requested, the probabilities over each alternative token at each position.
64+
//
65+
// If using a fine-tuned model, simply provide the model's ID in the CompletionRequest object,
66+
// and the server will use the model's parameters to generate the completion.
67+
func (c *Client) CreateCompletion(ctx context.Context, request CompletionRequest) (response CompletionResponse, err error) {
8668
var reqBytes []byte
8769
reqBytes, err = json.Marshal(request)
8870
if err != nil {

embeddings.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"bytes"
55
"context"
66
"encoding/json"
7-
"fmt"
87
"net/http"
98
)
109

@@ -120,18 +119,23 @@ type EmbeddingRequest struct {
120119
// E.g.
121120
// "The food was delicious and the waiter..."
122121
Input []string `json:"input"`
122+
// ID of the model to use. You can use the List models API to see all of your available models,
123+
// or see our Model overview for descriptions of them.
124+
Model EmbeddingModel `json:"model"`
125+
// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
126+
User string `json:"user"`
123127
}
124128

125129
// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|.
126130
// https://beta.openai.com/docs/api-reference/embeddings/create
127-
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest, model EmbeddingModel) (resp EmbeddingResponse, err error) {
131+
func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) {
128132
var reqBytes []byte
129133
reqBytes, err = json.Marshal(request)
130134
if err != nil {
131135
return
132136
}
133137

134-
urlSuffix := fmt.Sprintf("/engines/%s/embeddings", model)
138+
urlSuffix := "/embeddings"
135139
req, err := http.NewRequest(http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
136140
if err != nil {
137141
return

search.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ import (
1717
*/
1818
type SearchRequest struct {
1919
Query string `json:"query"`
20-
Documents []string `json:"documents"` // 1*
21-
FileID string `json:"file"` // 1*
22-
MaxRerank int `json:"max_rerank"` // 2*
23-
ReturnMetadata bool `json:"return_metadata"`
24-
User string `json:"user"`
20+
Documents []string `json:"documents"` // 1*
21+
FileID string `json:"file,omitempty"` // 1*
22+
MaxRerank int `json:"max_rerank,omitempty"` // 2*
23+
ReturnMetadata bool `json:"return_metadata,omitempty"`
24+
User string `json:"user,omitempty"`
2525
}
2626

2727
// SearchResult represents single result from search API

0 commit comments

Comments
 (0)