Skip to content

Commit d30ddaf

Browse files
authored
New streaming api (#91)
This adds functions StreamPredictionText() and StreamPredictionFiles() that stream output from models that return iterators of strings or files, respectively.
1 parent 6158992 commit d30ddaf

File tree

10 files changed

+826
-1
lines changed

10 files changed

+826
-1
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ toolchain go1.22.0
66

77
require (
88
github.com/stretchr/testify v1.8.4
9+
github.com/vincent-petithory/dataurl v1.0.0
910
golang.org/x/sync v0.6.0
1011
)
1112

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
1111
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
1212
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
1313
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
14+
github.com/vincent-petithory/dataurl v1.0.0 h1:cXw+kPto8NLuJtlMsI152irrVw9fRDX8AbShPRpg2CI=
15+
github.com/vincent-petithory/dataurl v1.0.0/go.mod h1:FHafX5vmDzyP+1CQATJn7WFKc9CvnvxyvZy6I1MrG/U=
1416
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
1517
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
1618
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

identifier_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ package replicate_test
33
import (
44
"testing"
55

6-
"github.com/replicate/replicate-go"
76
"github.com/stretchr/testify/assert"
7+
8+
"github.com/replicate/replicate-go"
89
)
910

1011
func TestValidWithVersion(t *testing.T) {

internal/sse/decoder.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package sse
2+
3+
import (
4+
"bufio"
5+
"bytes"
6+
"io"
7+
"strings"
8+
)
9+
10+
type Event struct {
11+
Type string
12+
ID string
13+
Data string
14+
}
15+
16+
type Decoder struct {
17+
r *bufio.Reader
18+
}
19+
20+
func NewDecoder(r io.Reader) *Decoder {
21+
return &Decoder{r: bufio.NewReader(r)}
22+
}
23+
24+
var (
25+
eventField = []byte("event:")
26+
dataField = []byte("data:")
27+
idField = []byte("id:")
28+
retryField = []byte("retry:")
29+
space = []byte{' '}
30+
)
31+
32+
func buildEvent(t, id string, data *strings.Builder) Event {
33+
return Event{
34+
Type: t,
35+
ID: id,
36+
Data: data.String(),
37+
}
38+
}
39+
40+
func (d *Decoder) Next() (Event, error) {
41+
var t, id string
42+
var data strings.Builder
43+
for {
44+
line, err := d.r.ReadBytes('\n')
45+
if err == io.EOF {
46+
return buildEvent(t, id, &data), io.ErrUnexpectedEOF
47+
}
48+
if err != nil {
49+
return buildEvent(t, id, &data), err
50+
}
51+
52+
switch {
53+
case line[0] == '\n':
54+
// a blank line finishes the event, so we return it
55+
return buildEvent(t, id, &data), nil
56+
case bytes.HasPrefix(line, eventField):
57+
t = string(bytes.TrimPrefix(line[6:len(line)-1], space))
58+
case bytes.HasPrefix(line, dataField):
59+
// strings.Builder.Write() always returns nil error, so we don't
60+
// need to handle it
61+
data.Write(bytes.TrimPrefix(line[5:], space))
62+
case bytes.HasPrefix(line, idField):
63+
id = string(bytes.TrimPrefix(line[3:len(line)-1], space))
64+
case bytes.HasPrefix(line, retryField):
65+
default:
66+
// ignore the line
67+
}
68+
69+
}
70+
}

internal/sse/decoder_test.go

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package sse_test
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"strings"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
12+
"github.com/replicate/replicate-go/internal/sse"
13+
)
14+
15+
func TestDecodeOneEventNoSpace(t *testing.T) {
16+
input := `event:output
17+
id:123abc
18+
data:giraffe
19+
20+
`
21+
d := sse.NewDecoder(strings.NewReader(input))
22+
23+
e, err := d.Next()
24+
25+
require.NoError(t, err)
26+
27+
assert.Equal(t, "output", e.Type)
28+
assert.Equal(t, "123abc", e.ID)
29+
assert.Equal(t, "giraffe\n", e.Data)
30+
}
31+
32+
func TestDecodeOneEventWithSpace(t *testing.T) {
33+
input := `event: output
34+
id: 123abc
35+
data: giraffe
36+
37+
`
38+
d := sse.NewDecoder(strings.NewReader(input))
39+
40+
e, err := d.Next()
41+
42+
require.NoError(t, err)
43+
44+
assert.Equal(t, "output", e.Type)
45+
assert.Equal(t, "123abc", e.ID)
46+
// only one space should be trimmed
47+
assert.Equal(t, " giraffe\n", e.Data)
48+
}
49+
50+
func TestDecodeOneEventMultipleData(t *testing.T) {
51+
input := `event:output
52+
data:giraffe
53+
data:rhino
54+
data:wombat
55+
56+
`
57+
d := sse.NewDecoder(strings.NewReader(input))
58+
59+
e, err := d.Next()
60+
61+
require.NoError(t, err)
62+
63+
assert.Equal(t, "output", e.Type)
64+
assert.Equal(t, "giraffe\nrhino\nwombat\n", e.Data)
65+
}
66+
67+
func TestDecodeOneEventHugeData(t *testing.T) {
68+
// this test is mainly to make sure we're not constrained by the
69+
// bufio.Reader buffer size
70+
input := fmt.Sprintf(`event:output
71+
data:%s
72+
73+
`, strings.Repeat("0123456789abcdef", 1_000_000))
74+
d := sse.NewDecoder(strings.NewReader(input))
75+
76+
e, err := d.Next()
77+
78+
require.NoError(t, err)
79+
80+
assert.Equal(t, "output", e.Type)
81+
// 16_000_000 data bytes and the terminal LF character
82+
assert.Equal(t, 16_000_001, len(e.Data))
83+
}
84+
85+
func TestDecodeManyEvents(t *testing.T) {
86+
input := `event:output
87+
id:alpha1
88+
data:giraffe
89+
90+
event:output
91+
id:bravo2
92+
data:rhino
93+
94+
event:output
95+
id:gamma3
96+
data:pine marten
97+
98+
`
99+
d := sse.NewDecoder(strings.NewReader(input))
100+
101+
e, err := d.Next()
102+
103+
require.NoError(t, err)
104+
105+
assert.Equal(t, "output", e.Type)
106+
assert.Equal(t, "alpha1", e.ID)
107+
assert.Equal(t, "giraffe\n", e.Data)
108+
109+
e, err = d.Next()
110+
111+
require.NoError(t, err)
112+
113+
assert.Equal(t, "output", e.Type)
114+
assert.Equal(t, "bravo2", e.ID)
115+
assert.Equal(t, "rhino\n", e.Data)
116+
117+
e, err = d.Next()
118+
119+
require.NoError(t, err)
120+
121+
assert.Equal(t, "output", e.Type)
122+
assert.Equal(t, "gamma3", e.ID)
123+
assert.Equal(t, "pine marten\n", e.Data)
124+
}
125+
126+
func TestDecodeEarlyEOF(t *testing.T) {
127+
input := ``
128+
d := sse.NewDecoder(strings.NewReader(input))
129+
130+
_, err := d.Next()
131+
132+
assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
133+
}

internal/sse/stream.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package sse
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"time"
10+
)
11+
12+
// Backoff is a copy of replicate.Backoff to avoid import cycles
13+
type Backoff interface {
14+
NextDelay(retries int) time.Duration
15+
}
16+
17+
type Streamer struct {
18+
c *http.Client
19+
url string
20+
maxRetries int
21+
backoff Backoff
22+
23+
attempt int
24+
lastEventID string
25+
26+
decoder *Decoder
27+
currentStream io.ReadCloser
28+
}
29+
30+
func NewStreamer(c *http.Client, url string, maxRetries int, backoff Backoff) *Streamer {
31+
return &Streamer{
32+
c: c,
33+
url: url,
34+
maxRetries: maxRetries,
35+
backoff: backoff,
36+
}
37+
}
38+
39+
var ErrMaximumRetries = errors.New("Exceeded maximum retries")
40+
41+
// connect (re-)establishes the connection to the SSE server. It only returns an
42+
// error if it cannot recover through retries.
43+
func (s *Streamer) connect(ctx context.Context) error {
44+
for {
45+
if s.attempt > s.maxRetries {
46+
return ErrMaximumRetries
47+
}
48+
49+
delay := 0 * time.Second
50+
if s.attempt > 0 {
51+
// delay on connection retry
52+
delay = s.backoff.NextDelay(s.attempt - 1)
53+
}
54+
s.attempt++
55+
reconnectDelay := time.NewTimer(delay)
56+
// once we only support go 1.23+, we can use time.After() here and simplify
57+
defer reconnectDelay.Stop()
58+
select {
59+
case <-ctx.Done():
60+
return ctx.Err()
61+
case <-reconnectDelay.C:
62+
}
63+
64+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.url, nil)
65+
if err != nil {
66+
return err
67+
}
68+
req.Header.Set("Accept", "text/event-stream")
69+
70+
if s.lastEventID != "" {
71+
req.Header.Set("Last-Event-ID", s.lastEventID)
72+
}
73+
74+
//nolint:bodyclose
75+
resp, err := s.c.Do(req)
76+
if err != nil {
77+
// try again
78+
continue
79+
}
80+
81+
if resp.StatusCode != http.StatusOK {
82+
return fmt.Errorf("received invalid status code: %d", resp.StatusCode)
83+
}
84+
85+
if s.currentStream != nil {
86+
err = s.currentStream.Close()
87+
if err != nil {
88+
return err
89+
}
90+
}
91+
s.currentStream = resp.Body
92+
s.decoder = NewDecoder(s.currentStream)
93+
return nil
94+
}
95+
}
96+
97+
func (s *Streamer) NextEvent(ctx context.Context) (*Event, error) {
98+
if s.decoder == nil {
99+
if err := s.connect(ctx); err != nil {
100+
return nil, err
101+
}
102+
}
103+
for {
104+
e, err := s.decoder.Next()
105+
if err != nil {
106+
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
107+
if err = s.connect(ctx); err != nil {
108+
return nil, err
109+
}
110+
continue
111+
}
112+
return nil, err
113+
}
114+
s.lastEventID = e.ID
115+
return &e, nil
116+
}
117+
}
118+
119+
func (s *Streamer) Close() error {
120+
if s.currentStream != nil {
121+
return s.currentStream.Close()
122+
}
123+
return nil
124+
}

0 commit comments

Comments
 (0)