Skip to content

Commit 39ca4e9

Browse files
authored
Implement chat completion streaming (#101)
* Implement chat completion streaming * Optimize the implementation of chat completion stream * Fix linter error
1 parent 58d99eb commit 39ca4e9

File tree

4 files changed

+142
-14
lines changed

4 files changed

+142
-14
lines changed

api.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package gogpt
22

33
import (
4+
"bytes"
5+
"context"
46
"encoding/json"
57
"fmt"
68
"net/http"
@@ -79,3 +81,31 @@ func (c *Client) sendRequest(req *http.Request, v interface{}) error {
7981
func (c *Client) fullURL(suffix string) string {
8082
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
8183
}
84+
85+
func (c *Client) newStreamRequest(
86+
ctx context.Context,
87+
method string,
88+
urlSuffix string,
89+
body interface{}) (*http.Request, error) {
90+
var reqBody []byte
91+
if body != nil {
92+
var err error
93+
reqBody, err = json.Marshal(body)
94+
if err != nil {
95+
return nil, err
96+
}
97+
}
98+
99+
req, err := http.NewRequestWithContext(ctx, method, c.fullURL(urlSuffix), bytes.NewBuffer(reqBody))
100+
if err != nil {
101+
return nil, err
102+
}
103+
104+
req.Header.Set("Content-Type", "application/json")
105+
req.Header.Set("Accept", "text/event-stream")
106+
req.Header.Set("Cache-Control", "no-cache")
107+
req.Header.Set("Connection", "keep-alive")
108+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
109+
110+
return req, nil
111+
}

chat.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ type ChatCompletionResponse struct {
4949
Usage Usage `json:"usage"`
5050
}
5151

52-
// CreateChatCompletion — API call to Creates a completion for the chat message.
52+
// CreateChatCompletion — API call to Create a completion for the chat message.
5353
func (c *Client) CreateChatCompletion(
5454
ctx context.Context,
5555
request ChatCompletionRequest,

chat_stream.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package gogpt
2+
3+
import (
4+
"bufio"
5+
"bytes"
6+
"context"
7+
"encoding/json"
8+
"io"
9+
"net/http"
10+
)
11+
12+
type ChatCompletionStreamChoiceDelta struct {
13+
Content string `json:"content"`
14+
}
15+
16+
type ChatCompletionStreamChoice struct {
17+
Index int `json:"index"`
18+
Delta ChatCompletionStreamChoiceDelta `json:"delta"`
19+
FinishReason string `json:"finish_reason"`
20+
}
21+
22+
type ChatCompletionStreamResponse struct {
23+
ID string `json:"id"`
24+
Object string `json:"object"`
25+
Created int64 `json:"created"`
26+
Model string `json:"model"`
27+
Choices []ChatCompletionStreamChoice `json:"choices"`
28+
}
29+
30+
// ChatCompletionStream
31+
// Note: Perhaps it is more elegant to abstract Stream using generics.
32+
type ChatCompletionStream struct {
33+
emptyMessagesLimit uint
34+
isFinished bool
35+
36+
reader *bufio.Reader
37+
response *http.Response
38+
}
39+
40+
func (stream *ChatCompletionStream) Recv() (response ChatCompletionStreamResponse, err error) {
41+
if stream.isFinished {
42+
err = io.EOF
43+
return
44+
}
45+
46+
var emptyMessagesCount uint
47+
48+
waitForData:
49+
line, err := stream.reader.ReadBytes('\n')
50+
if err != nil {
51+
return
52+
}
53+
54+
var headerData = []byte("data: ")
55+
line = bytes.TrimSpace(line)
56+
if !bytes.HasPrefix(line, headerData) {
57+
emptyMessagesCount++
58+
if emptyMessagesCount > stream.emptyMessagesLimit {
59+
err = ErrTooManyEmptyStreamMessages
60+
return
61+
}
62+
63+
goto waitForData
64+
}
65+
66+
line = bytes.TrimPrefix(line, headerData)
67+
if string(line) == "[DONE]" {
68+
stream.isFinished = true
69+
err = io.EOF
70+
return
71+
}
72+
73+
err = json.Unmarshal(line, &response)
74+
return
75+
}
76+
77+
func (stream *ChatCompletionStream) Close() {
78+
stream.response.Body.Close()
79+
}
80+
81+
func (stream *ChatCompletionStream) GetResponse() *http.Response {
82+
return stream.response
83+
}
84+
85+
// CreateChatCompletionStream — API call to create a chat completion w/ streaming
86+
// support. It sets whether to stream back partial progress. If set, tokens will be
87+
// sent as data-only server-sent events as they become available, with the
88+
// stream terminated by a data: [DONE] message.
89+
func (c *Client) CreateChatCompletionStream(
90+
ctx context.Context,
91+
request ChatCompletionRequest,
92+
) (stream *ChatCompletionStream, err error) {
93+
request.Stream = true
94+
req, err := c.newStreamRequest(ctx, "POST", "/chat/completions", request)
95+
if err != nil {
96+
return
97+
}
98+
99+
resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
100+
if err != nil {
101+
return
102+
}
103+
104+
stream = &ChatCompletionStream{
105+
emptyMessagesLimit: c.config.EmptyMessagesLimit,
106+
reader: bufio.NewReader(resp.Body),
107+
response: resp,
108+
}
109+
return
110+
}

stream.go

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"context"
77
"encoding/json"
88
"errors"
9-
"fmt"
109
"io"
1110
"net/http"
1211
)
@@ -73,18 +72,7 @@ func (c *Client) CreateCompletionStream(
7372
request CompletionRequest,
7473
) (stream *CompletionStream, err error) {
7574
request.Stream = true
76-
reqBytes, err := json.Marshal(request)
77-
if err != nil {
78-
return
79-
}
80-
81-
urlSuffix := "/completions"
82-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
83-
req.Header.Set("Content-Type", "application/json")
84-
req.Header.Set("Accept", "text/event-stream")
85-
req.Header.Set("Cache-Control", "no-cache")
86-
req.Header.Set("Connection", "keep-alive")
87-
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
75+
req, err := c.newStreamRequest(ctx, "POST", "/completions", request)
8876
if err != nil {
8977
return
9078
}

0 commit comments

Comments
 (0)