Skip to content

Commit 2b5eb01

Browse files
authored
Merge pull request #1 from wille/batch
- Support request batching - endpoint.headers setting
2 parents 15cb10d + a5ffc00 commit 2b5eb01

File tree

10 files changed

+355
-49
lines changed

10 files changed

+355
-49
lines changed

internal/http.go

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ var defaultClient = &http.Client{
2525
},
2626
}
2727

28-
func ProxyHTTP(ctx context.Context, endpoint *Endpoint, req *rpc.Request, timing *servertiming.Header) (*rpc.Response, *Provider, error) {
28+
func ProxyHTTP(ctx context.Context, endpoint *Endpoint, req *rpc.BatchRequest, timing *servertiming.Header) (*rpc.BatchResponse, *Provider, error) {
2929
providers := endpoint.GetActiveProviders()
3030

3131
for _, provider := range providers {
@@ -66,6 +66,13 @@ func SendHTTPRequest(ctx context.Context, provider *Provider, body []byte) ([]by
6666
}
6767

6868
req.Header = make(http.Header)
69+
70+
if provider.Headers != nil {
71+
for k, v := range provider.Headers {
72+
req.Header.Set(k, v)
73+
}
74+
}
75+
6976
req.Header.Set("User-Agent", UserAgent)
7077
req.Header.Set("Content-Type", "application/json; charset=utf-8")
7178

@@ -93,15 +100,15 @@ func SendHTTPRequest(ctx context.Context, provider *Provider, body []byte) ([]by
93100
return b, nil
94101
}
95102

96-
func SendHTTPRPCRequest(ctx context.Context, p *Provider, rpcreq *rpc.Request) (*rpc.Response, error) {
97-
req := rpc.SerializeRequest(rpcreq)
103+
func SendHTTPRPCRequest(ctx context.Context, p *Provider, req *rpc.BatchRequest) (*rpc.BatchResponse, error) {
104+
body := rpc.SerializeBatchRequest(req)
98105

99-
b, err := SendHTTPRequest(ctx, p, req)
106+
b, err := SendHTTPRequest(ctx, p, body)
100107
if err != nil {
101108
return nil, err
102109
}
103110

104-
response, err := rpc.DecodeResponse(b)
111+
response, err := rpc.DecodeBatchResponse(b)
105112
if err != nil {
106113
return nil, fmt.Errorf("bad response: %w, raw: %s", err, string(b))
107114
}
@@ -127,19 +134,26 @@ func IncomingHttpHandler(ctx context.Context, endpoint *Endpoint, w http.Respons
127134
return
128135
}
129136

130-
rpcReq, err := rpc.DecodeRequest(body)
137+
req, err := rpc.DecodeBatchRequest(body)
131138
if err != nil {
132-
log.Error("http: bad request", "error", err, "msg", rpc.FormatRawBody(string(body)))
139+
log.Error("http: bad request", "error", err, "body", rpc.FormatRawBody(string(body)))
133140
http.Error(w, "bad request", http.StatusBadRequest)
134141
return
135142
}
136143

137-
log = log.With("rpc_id", rpc.GetRequestIDString(rpcReq.ID), "method", rpcReq.Method)
144+
if req.IsBatch {
145+
rpc.BatchIDCounter++
146+
log = log.With("batch_id", rpc.BatchIDCounter, "batch_size", len(req.Requests))
147+
} else {
148+
log = log.With("rpc_id", req.Requests[0].GetID(), "method", req.Requests[0].Method)
149+
}
138150

139-
res, provider, err := ProxyHTTP(ctx, endpoint, rpcReq, timing)
151+
res, provider, err := ProxyHTTP(ctx, endpoint, req, timing)
140152

141153
if err != nil {
142-
metrics.RecordRequest(endpoint.Name, provider.Name, "http", rpcReq.Method, time.Since(start).Seconds(), true)
154+
for _, req := range req.Requests {
155+
metrics.RecordFailedRequest(endpoint.Name, provider.Name, "http", req.Method)
156+
}
143157

144158
if err == ErrNoProvidersAvailable {
145159
log.Error("no providers available")
@@ -154,9 +168,33 @@ func IncomingHttpHandler(ctx context.Context, endpoint *Endpoint, w http.Respons
154168

155169
log = log.With("provider", provider.Name, "request_time", time.Since(start))
156170

157-
log.Debug("request")
171+
for i, res := range req.Requests {
172+
if req.IsBatch {
173+
log.Debug("request", "batch_index", i, "rpc_id", res.GetID(), "method", res.Method)
174+
} else {
175+
// id and method is already set for this single request
176+
log.Debug("request")
177+
}
178+
}
158179

159-
metrics.RecordRequest(endpoint.Name, provider.Name, "http", rpcReq.Method, time.Since(start).Seconds(), res.IsError())
180+
for i, res := range res.Responses {
181+
method := req.Requests[i].Method
182+
183+
if res.IsError() {
184+
log.Error("error", "error", res.Error)
185+
metrics.RecordFailedRequest(endpoint.Name, provider.Name, "http", method)
186+
} else {
187+
metrics.RecordRequest(endpoint.Name, provider.Name, "http", method, time.Since(start).Seconds())
188+
}
189+
}
190+
191+
if req.IsBatch {
192+
if len(res.Responses) != len(req.Requests) {
193+
log.Error("batch response size mismatch", "request_size", len(req.Requests), "response_size", len(res.Responses))
194+
http.Error(w, "batch response size mismatch", http.StatusInternalServerError)
195+
return
196+
}
197+
}
160198

161199
if !endpoint.Public {
162200
w.Header().Set("X-Provider", provider.Name)
@@ -168,7 +206,7 @@ func IncomingHttpHandler(ctx context.Context, endpoint *Endpoint, w http.Respons
168206

169207
w.Header().Set("Content-Type", "application/json; charset=utf-8")
170208

171-
_, err = w.Write(rpc.SerializeResponse(res))
209+
_, err = w.Write(rpc.SerializeBatchResponse(res))
172210
if err != nil {
173211
log.Error("error writing body", "error", err)
174212
return

internal/metrics/metrics.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,15 @@ func MetricsHandler() http.Handler {
7474
}
7575

7676
// RecordRequest records metrics for a request
77-
func RecordRequest(endpoint, provider, transport, method string, duration float64, failed bool) {
77+
func RecordRequest(endpoint, provider, transport, method string, duration float64) {
7878
requestsTotal.WithLabelValues(endpoint, provider, transport, method).Inc()
79-
if failed {
80-
failedRequestsTotal.WithLabelValues(endpoint, provider, transport, method).Inc()
81-
}
8279
requestDuration.WithLabelValues(endpoint, provider, transport, method).Observe(duration)
8380
}
8481

82+
func RecordFailedRequest(endpoint, provider, transport, method string) {
83+
failedRequestsTotal.WithLabelValues(endpoint, provider, transport, method).Inc()
84+
}
85+
8586
func RecordOpenConnection(endpoint, provider string) {
8687
openConnections.WithLabelValues(endpoint, provider, "ws").Inc()
8788
totalConnections.WithLabelValues(endpoint, provider, "ws").Inc()

internal/provider.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ type Provider struct {
2323
Http string `yaml:"http"`
2424
Ws string `yaml:"ws"`
2525

26+
Headers map[string]string `yaml:"headers,omitempty"`
27+
2628
// The timeout for the provider. Use Endpoint.GetTimeout() to get the actual timeout
2729
Timeout time.Duration `yaml:"timeout,omitempty"`
2830

@@ -130,17 +132,17 @@ func (e *Provider) Healthcheck(p *Endpoint) error {
130132
ctx := context.Background()
131133

132134
fn := func(ctx context.Context, req *rpc.Request, errRpcError bool) (*rpc.Response, error) {
133-
res, err := SendHTTPRPCRequest(ctx, e, req)
135+
res, err := SendHTTPRPCRequest(ctx, e, rpc.NewBatchRequest(req))
134136
if err != nil {
135137
return nil, err
136138
}
137139

138-
if errRpcError && res.IsError() {
139-
_, err := res.GetError()
140+
if errRpcError && res.Responses[0].IsError() {
141+
_, err := res.Responses[0].GetError()
140142
return nil, err
141143
}
142144

143-
return res, nil
145+
return res.Responses[0], nil
144146
}
145147

146148
switch p.Kind {

internal/provider_test.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ func TestRateLimitWithRetryAfter(t *testing.T) {
7070
}
7171

7272
// Send a request that will trigger rate limiting
73-
_, _, err := ProxyHTTP(context.Background(), provider, rpc.NewRequest("1", "eth_blockNumber", []interface{}{}), &servertiming.Header{})
73+
req := rpc.NewBatchRequest(
74+
rpc.NewRequest("1", "eth_blockNumber", []interface{}{}),
75+
)
76+
_, _, err := ProxyHTTP(context.Background(), provider, req, &servertiming.Header{})
7477

7578
// Verify the error and retry time was set correctly
7679
assert.Error(t, err)
@@ -126,7 +129,10 @@ func TestSlowProvider(t *testing.T) {
126129
timing := &servertiming.Header{}
127130

128131
// Attempt to proxy a request - it should use the fast provider because the first one is too slow
129-
resp, endpoint, err := ProxyHTTP(context.Background(), provider, rpc.NewRequest("1", "eth_blockNumber", []interface{}{}), timing)
132+
req := rpc.NewBatchRequest(
133+
rpc.NewRequest("1", "eth_blockNumber", []interface{}{}),
134+
)
135+
resp, endpoint, err := ProxyHTTP(context.Background(), provider, req, timing)
130136

131137
// Verify we got a response
132138
assert.NoError(t, err)
@@ -166,7 +172,10 @@ func TestNonRespondingProvider(t *testing.T) {
166172

167173
// Attempt to use the non-responding provider first
168174
timing := &servertiming.Header{}
169-
resp, endpoint, err := ProxyHTTP(context.Background(), provider, rpc.NewRequest("1", "eth_blockNumber", []interface{}{}), timing)
175+
req := rpc.NewBatchRequest(
176+
rpc.NewRequest("1", "eth_blockNumber", []interface{}{}),
177+
)
178+
resp, endpoint, err := ProxyHTTP(context.Background(), provider, req, timing)
170179

171180
// Verify we got a response from the working provider
172181
assert.NoError(t, err)

internal/rpc/batch.go

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
package rpc
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
)
7+
8+
type BatchResponse struct {
9+
Responses []*Response
10+
IsBatch bool
11+
}
12+
13+
var _ json.Unmarshaler = &BatchResponse{}
14+
var _ json.Marshaler = &BatchResponse{}
15+
16+
func NewBatchResponse(res ...*Response) *BatchResponse {
17+
return &BatchResponse{
18+
Responses: res,
19+
IsBatch: len(res) > 1,
20+
}
21+
}
22+
23+
// UnmarshalJSON for a batch request supports decoding a single request as well
24+
func (r *BatchResponse) UnmarshalJSON(b []byte) error {
25+
switch b[0] {
26+
case '[':
27+
var res []*Response
28+
err := json.Unmarshal(b, &res)
29+
if err != nil {
30+
return err
31+
}
32+
r.Responses = res
33+
r.IsBatch = true
34+
return nil
35+
case '{':
36+
var res Response
37+
err := json.Unmarshal(b, &res)
38+
if err != nil {
39+
return err
40+
}
41+
r.Responses = []*Response{&res}
42+
r.IsBatch = false
43+
return nil
44+
}
45+
46+
return fmt.Errorf("invalid request: %s", FormatRawBody(string(b)))
47+
}
48+
49+
// MarshalJSON will encode a single (non batched) request to a single object or multiple requests into an array
50+
func (r BatchResponse) MarshalJSON() ([]byte, error) {
51+
if len(r.Responses) == 0 {
52+
return nil, fmt.Errorf("empty batch response")
53+
}
54+
55+
// If the batch is just one single request then unbatch it
56+
if !r.IsBatch {
57+
return json.Marshal(r.Responses[0])
58+
}
59+
60+
return json.Marshal(r.Responses)
61+
}
62+
63+
type BatchRequest struct {
64+
Requests []*Request
65+
IsBatch bool
66+
}
67+
68+
var _ json.Unmarshaler = &BatchRequest{}
69+
var _ json.Marshaler = &BatchRequest{}
70+
71+
func NewBatchRequest(req ...*Request) *BatchRequest {
72+
return &BatchRequest{
73+
Requests: req,
74+
IsBatch: len(req) > 1,
75+
}
76+
}
77+
78+
// UnmarshalJSON for a batch request supports decoding a single request as well
79+
func (r *BatchRequest) UnmarshalJSON(b []byte) error {
80+
switch b[0] {
81+
case '[':
82+
var req []*Request
83+
err := json.Unmarshal(b, &req)
84+
if err != nil {
85+
return err
86+
}
87+
r.Requests = req
88+
r.IsBatch = true
89+
return nil
90+
case '{':
91+
var req Request
92+
err := json.Unmarshal(b, &req)
93+
if err != nil {
94+
return err
95+
}
96+
r.Requests = []*Request{&req}
97+
r.IsBatch = false
98+
return nil
99+
}
100+
101+
return fmt.Errorf("invalid request: %s", FormatRawBody(string(b)))
102+
}
103+
104+
// MarshalJSON will encode a single (non batched) request to a single object or multiple requests into an array
105+
func (r BatchRequest) MarshalJSON() ([]byte, error) {
106+
if len(r.Requests) == 0 {
107+
return nil, fmt.Errorf("empty batch request")
108+
}
109+
110+
// If the batch is just one single request then unbatch it
111+
if !r.IsBatch {
112+
return json.Marshal(r.Requests[0])
113+
}
114+
115+
return json.Marshal(r.Requests)
116+
}
117+
118+
func DecodeBatchRequest(b []byte) (*BatchRequest, error) {
119+
var batch BatchRequest
120+
err := json.Unmarshal(b, &batch)
121+
if err != nil {
122+
return nil, err
123+
}
124+
return &batch, nil
125+
}
126+
127+
func DecodeBatchResponse(b []byte) (*BatchResponse, error) {
128+
var batch BatchResponse
129+
err := json.Unmarshal(b, &batch)
130+
if err != nil {
131+
return nil, err
132+
}
133+
return &batch, nil
134+
}
135+
136+
func SerializeBatchRequest(req *BatchRequest) []byte {
137+
b, err := json.Marshal(req)
138+
if err != nil {
139+
panic(err)
140+
}
141+
return b
142+
}
143+
144+
func SerializeBatchResponse(res *BatchResponse) []byte {
145+
b, err := json.Marshal(res)
146+
if err != nil {
147+
panic(err)
148+
}
149+
return b
150+
}

0 commit comments

Comments
 (0)