This repository was archived by the owner on Oct 22, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path stream_reader.go
More file actions
100 lines (89 loc) · 1.82 KB
/
stream_reader.go
File metadata and controls
100 lines (89 loc) · 1.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
package azopenai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
)
var (
DefaultStreamReaderPrefix = []byte("data: ")
DefaultStreamReaderStop = []byte("[DONE]")
DefaultStreamReaderDelim = byte('\n')
)
// StreamReader is a stream reader for Azure OpenAI API when stream enabled.
type StreamReader[T any] struct {
Reader io.ReadCloser
Prefix []byte
Stop []byte
Delim byte
UnMarshaler func([]byte, any) error
}
// RecvChan returns a channel that receives the stream data.
func (s *StreamReader[T]) RecvChan(errCallback func(error)) (<-chan T, error) {
if err := s.defaults(); err != nil {
return nil, err
}
if errCallback == nil {
errCallback = func(error) {}
}
ch := make(chan T)
reader := bufio.NewReader(s.Reader)
var delta T
var err error
go func() {
defer close(ch)
defer func() {
if err != nil && err != io.EOF {
errCallback(err)
}
}()
defer s.Reader.Close()
for err == nil {
delta, err = s.read(reader)
if err != nil {
return
}
ch <- delta
}
}()
return ch, nil
}
func (s *StreamReader[T]) defaults() error {
if s.Reader == nil {
return fmt.Errorf("StreamReader.Reader is nil")
}
if s.Prefix == nil {
s.Prefix = DefaultStreamReaderPrefix
}
if s.Stop == nil {
s.Stop = DefaultStreamReaderStop
}
if s.Delim == 0 {
s.Delim = DefaultStreamReaderDelim
}
if s.UnMarshaler == nil {
s.UnMarshaler = json.Unmarshal
}
return nil
}
func (s *StreamReader[T]) read(reader *bufio.Reader) (T, error) {
var rv T
var line []byte
var err error
for len(line) == 0 {
line, err = reader.ReadBytes(s.Delim)
if err != nil {
return rv, err
}
line = bytes.TrimSpace(line)
}
line = bytes.TrimPrefix(line, s.Prefix)
if bytes.Equal(line, s.Stop) {
return rv, io.EOF
}
if err := s.UnMarshaler(line, &rv); err != nil {
return rv, err
}
return rv, nil
}