Skip to content

Commit 8677fb4

Browse files
authored
feat: add azure openai support (#214)
* feat: add azure openai support * chore: refine config * chore: make config options like the python one * chore: adjust config struct field order * test: fix tests * style: make the linter happy * fix: support Azure API Key authentication in sendRequest * chore: check error in CreateChatCompletionStream * chore: pass tests * chore: try pass tests again * chore: change ClientConfig back due to this lib does not like WithXxx config style * chore: revert fix to CreateChatCompletionStream() due to cause tests not pass * chore: at least add some comment about the required fields * chore: re order ClientConfig fields * chore: add DefaultAzure() * chore: set default api_version the same as py one "2023-03-15-preview" * style: fixup typo * test: add api_internal_test.go * style: make lint happy * chore: add constant AzureAPIKeyHeader * chore: use AzureAPIKeyHeader for api-key header, fix azure base url auto trim suffix / * test: add TestAzureFullURL, TestRequestAuthHeader and TestOpenAIFullURL * test: simplify TestRequestAuthHeader * test: refine TestOpenAIFullURL * chore: refine comments * feat: DefaultAzureConfig
1 parent bee0656 commit 8677fb4

File tree

3 files changed

+197
-8
lines changed

3 files changed

+197
-8
lines changed

api.go

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"fmt"
77
"net/http"
8+
"strings"
89
)
910

1011
// Client is OpenAI GPT-3 API client.
@@ -39,7 +40,13 @@ func NewOrgClient(authToken, org string) *Client {
3940

4041
func (c *Client) sendRequest(req *http.Request, v interface{}) error {
4142
req.Header.Set("Accept", "application/json; charset=utf-8")
42-
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
43+
// Azure API Key authentication
44+
if c.config.APIType == APITypeAzure {
45+
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
46+
} else {
47+
// OpenAI or Azure AD authentication
48+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
49+
}
4350

4451
// Check whether Content-Type is already set, Upload Files API requires
4552
// Content-Type == multipart/form-data
@@ -83,6 +90,15 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
8390
}
8491

8592
func (c *Client) fullURL(suffix string) string {
93+
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
94+
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
95+
baseURL := c.config.BaseURL
96+
baseURL = strings.TrimRight(baseURL, "/")
97+
return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s",
98+
baseURL, azureAPIPrefix, azureDeploymentsPrefix, c.config.Engine, suffix, c.config.APIVersion)
99+
}
100+
101+
// c.config.APIType == APITypeOpenAI || c.config.APIType == ""
86102
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
87103
}
88104

@@ -100,7 +116,14 @@ func (c *Client) newStreamRequest(
100116
req.Header.Set("Accept", "text/event-stream")
101117
req.Header.Set("Cache-Control", "no-cache")
102118
req.Header.Set("Connection", "keep-alive")
103-
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
104119

120+
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
121+
// Azure API Key authentication
122+
if c.config.APIType == APITypeAzure {
123+
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
124+
} else {
125+
// OpenAI or Azure AD authentication
126+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
127+
}
105128
return req, nil
106129
}

api_internal_test.go

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package openai
2+
3+
import (
4+
"context"
5+
"testing"
6+
)
7+
8+
func TestOpenAIFullURL(t *testing.T) {
9+
cases := []struct {
10+
Name string
11+
Suffix string
12+
Expect string
13+
}{
14+
{
15+
"ChatCompletionsURL",
16+
"/chat/completions",
17+
"https://api.openai.com/v1/chat/completions",
18+
},
19+
{
20+
"CompletionsURL",
21+
"/completions",
22+
"https://api.openai.com/v1/completions",
23+
},
24+
}
25+
26+
for _, c := range cases {
27+
t.Run(c.Name, func(t *testing.T) {
28+
az := DefaultConfig("dummy")
29+
cli := NewClientWithConfig(az)
30+
actual := cli.fullURL(c.Suffix)
31+
if actual != c.Expect {
32+
t.Errorf("Expected %s, got %s", c.Expect, actual)
33+
}
34+
t.Logf("Full URL: %s", actual)
35+
})
36+
}
37+
}
38+
39+
func TestRequestAuthHeader(t *testing.T) {
40+
cases := []struct {
41+
Name string
42+
APIType APIType
43+
HeaderKey string
44+
Token string
45+
Expect string
46+
}{
47+
{
48+
"OpenAIDefault",
49+
"",
50+
"Authorization",
51+
"dummy-token-openai",
52+
"Bearer dummy-token-openai",
53+
},
54+
{
55+
"OpenAI",
56+
APITypeOpenAI,
57+
"Authorization",
58+
"dummy-token-openai",
59+
"Bearer dummy-token-openai",
60+
},
61+
{
62+
"AzureAD",
63+
APITypeAzureAD,
64+
"Authorization",
65+
"dummy-token-azure",
66+
"Bearer dummy-token-azure",
67+
},
68+
{
69+
"Azure",
70+
APITypeAzure,
71+
AzureAPIKeyHeader,
72+
"dummy-api-key-here",
73+
"dummy-api-key-here",
74+
},
75+
}
76+
77+
for _, c := range cases {
78+
t.Run(c.Name, func(t *testing.T) {
79+
az := DefaultConfig(c.Token)
80+
az.APIType = c.APIType
81+
82+
cli := NewClientWithConfig(az)
83+
req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil)
84+
if err != nil {
85+
t.Errorf("Failed to create request: %v", err)
86+
}
87+
actual := req.Header.Get(c.HeaderKey)
88+
if actual != c.Expect {
89+
t.Errorf("Expected %s, got %s", c.Expect, actual)
90+
}
91+
t.Logf("%s: %s", c.HeaderKey, actual)
92+
})
93+
}
94+
}
95+
96+
func TestAzureFullURL(t *testing.T) {
97+
cases := []struct {
98+
Name string
99+
BaseURL string
100+
Engine string
101+
Expect string
102+
}{
103+
{
104+
"AzureBaseURLWithSlashAutoStrip",
105+
"https://httpbin.org/",
106+
"chatgpt-demo",
107+
"https://httpbin.org/" +
108+
"openai/deployments/chatgpt-demo" +
109+
"/chat/completions?api-version=2023-03-15-preview",
110+
},
111+
{
112+
"AzureBaseURLWithoutSlashOK",
113+
"https://httpbin.org",
114+
"chatgpt-demo",
115+
"https://httpbin.org/" +
116+
"openai/deployments/chatgpt-demo" +
117+
"/chat/completions?api-version=2023-03-15-preview",
118+
},
119+
}
120+
121+
for _, c := range cases {
122+
t.Run(c.Name, func(t *testing.T) {
123+
az := DefaultAzureConfig("dummy", c.BaseURL, c.Engine)
124+
cli := NewClientWithConfig(az)
125+
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
126+
actual := cli.fullURL("/chat/completions")
127+
if actual != c.Expect {
128+
t.Errorf("Expected %s, got %s", c.Expect, actual)
129+
}
130+
t.Logf("Full URL: %s", actual)
131+
})
132+
}
133+
}

config.go

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,61 @@ import (
55
)
66

77
const (
8-
apiURLv1 = "https://api.openai.com/v1"
8+
openaiAPIURLv1 = "https://api.openai.com/v1"
99
defaultEmptyMessagesLimit uint = 300
10+
11+
azureAPIPrefix = "openai"
12+
azureDeploymentsPrefix = "deployments"
13+
)
14+
15+
type APIType string
16+
17+
const (
18+
APITypeOpenAI APIType = "OPEN_AI"
19+
APITypeAzure APIType = "AZURE"
20+
APITypeAzureAD APIType = "AZURE_AD"
1021
)
1122

23+
const AzureAPIKeyHeader = "api-key"
24+
1225
// ClientConfig is a configuration of a client.
1326
type ClientConfig struct {
1427
authToken string
1528

16-
HTTPClient *http.Client
29+
BaseURL string
30+
OrgID string
31+
APIType APIType
32+
APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD
33+
Engine string // required when APIType is APITypeAzure or APITypeAzureAD
1734

18-
BaseURL string
19-
OrgID string
35+
HTTPClient *http.Client
2036

2137
EmptyMessagesLimit uint
2238
}
2339

2440
func DefaultConfig(authToken string) ClientConfig {
2541
return ClientConfig{
42+
authToken: authToken,
43+
BaseURL: openaiAPIURLv1,
44+
APIType: APITypeOpenAI,
45+
OrgID: "",
46+
2647
HTTPClient: &http.Client{},
27-
BaseURL: apiURLv1,
48+
49+
EmptyMessagesLimit: defaultEmptyMessagesLimit,
50+
}
51+
}
52+
53+
func DefaultAzureConfig(apiKey, baseURL, engine string) ClientConfig {
54+
return ClientConfig{
55+
authToken: apiKey,
56+
BaseURL: baseURL,
2857
OrgID: "",
29-
authToken: authToken,
58+
APIType: APITypeAzure,
59+
APIVersion: "2023-03-15-preview",
60+
Engine: engine,
61+
62+
HTTPClient: &http.Client{},
3063

3164
EmptyMessagesLimit: defaultEmptyMessagesLimit,
3265
}

0 commit comments

Comments
 (0)