Skip to content

Commit 9d86c10

Browse files
committed
Expose SetReadLimits for Websocket server to prevent OOM
1 parent b369a85 commit 9d86c10

File tree

3 files changed

+124
-1
lines changed

3 files changed

+124
-1
lines changed

rpc/server.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ type Server struct {
5454
batchItemLimit int
5555
batchResponseLimit int
5656
httpBodyLimit int
57+
readLimit int64
5758
}
5859

5960
// NewServer creates a new server instance with no registered handlers.
@@ -62,6 +63,7 @@ func NewServer() *Server {
6263
idgen: randomIDGenerator(),
6364
codecs: make(map[ServerCodec]struct{}),
6465
httpBodyLimit: defaultBodyLimit,
66+
readLimit: wsDefaultReadLimit,
6567
}
6668
server.run.Store(true)
6769
// Register the default service providing meta information about the RPC service such
@@ -89,6 +91,13 @@ func (s *Server) SetHTTPBodyLimit(limit int) {
8991
s.httpBodyLimit = limit
9092
}
9193

94+
// SetReadLimits sets the limit for max message size for Websocket requests.
95+
//
96+
// This method should be called before processing any requests via Websocket server.
97+
func (s *Server) SetReadLimits(limit int64) {
98+
s.readLimit = limit
99+
}
100+
92101
// RegisterName creates a service for the given receiver type under the given name. When no
93102
// methods on the given receiver match the criteria to be either an RPC method or a
94103
// subscription an error is returned. Otherwise a new service is created and added to the

rpc/server_test.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ package rpc
1919
import (
2020
"bufio"
2121
"bytes"
22+
"context"
2223
"io"
2324
"net"
25+
"net/http/httptest"
2426
"os"
2527
"path/filepath"
2628
"strings"
@@ -202,3 +204,115 @@ func TestServerBatchResponseSizeLimit(t *testing.T) {
202204
}
203205
}
204206
}
207+
208+
func TestServerSetReadLimits(t *testing.T) {
209+
t.Parallel()
210+
211+
// Test different read limits
212+
testCases := []struct {
213+
name string
214+
readLimit int64
215+
testSize int
216+
shouldFail bool
217+
}{
218+
{
219+
name: "small limit with small request - should succeed",
220+
readLimit: 2048,
221+
testSize: 500, // Small request data
222+
shouldFail: false,
223+
},
224+
{
225+
name: "small limit with large request - should fail",
226+
readLimit: 2048,
227+
testSize: 5000, // Large request data that should exceed limit
228+
shouldFail: true,
229+
},
230+
{
231+
name: "medium limit with medium request - should succeed",
232+
readLimit: 10240,
233+
testSize: 5000, // Medium request data
234+
shouldFail: false,
235+
},
236+
{
237+
name: "medium limit with large request - should fail",
238+
readLimit: 10240,
239+
testSize: 20000, // Large request data
240+
shouldFail: true,
241+
},
242+
{
243+
name: "large limit with large request - should succeed",
244+
readLimit: 50000,
245+
testSize: 20000, // Large request data that should fit
246+
shouldFail: false,
247+
},
248+
}
249+
250+
for _, tc := range testCases {
251+
t.Run(tc.name, func(t *testing.T) {
252+
// Create server and set read limits
253+
srv := newTestServer()
254+
srv.SetReadLimits(tc.readLimit)
255+
defer srv.Stop()
256+
257+
// Start HTTP server with WebSocket handler
258+
httpsrv := httptest.NewServer(srv.WebsocketHandler([]string{"*"}))
259+
defer httpsrv.Close()
260+
261+
wsURL := "ws:" + strings.TrimPrefix(httpsrv.URL, "http:")
262+
263+
// Connect WebSocket client
264+
client, err := DialOptions(context.Background(), wsURL)
265+
if err != nil {
266+
t.Fatalf("can't dial: %v", err)
267+
}
268+
defer client.Close()
269+
270+
// Create large request data - this is what will be limited
271+
largeString := strings.Repeat("A", tc.testSize)
272+
273+
// Send the large string as a parameter in the request
274+
var result echoResult
275+
err = client.Call(&result, "test_echo", largeString, 42, &echoArgs{S: "test"})
276+
277+
if tc.shouldFail {
278+
// Expecting an error due to read limit exceeded
279+
if err == nil {
280+
t.Fatalf("expected error for request size %d with limit %d, but got none", tc.testSize, tc.readLimit)
281+
}
282+
// Check if it's the expected message size limit error
283+
if !strings.Contains(err.Error(), "message too big") {
284+
t.Fatalf("expected 'message too big' error, got: %v", err)
285+
}
286+
} else {
287+
// Expecting success
288+
if err != nil {
289+
t.Fatalf("unexpected error for request size %d with limit %d: %v", tc.testSize, tc.readLimit, err)
290+
}
291+
// Verify the response is correct - the echo should return our string
292+
if result.String != largeString {
293+
t.Fatalf("expected echo result to match input")
294+
}
295+
}
296+
})
297+
}
298+
}
299+
300+
// Test that SetReadLimits properly updates the server's readerLimit field
301+
func TestServerSetReadLimitsField(t *testing.T) {
302+
server := NewServer()
303+
304+
// Test initial default value
305+
if server.readLimit != wsDefaultReadLimit {
306+
t.Errorf("expected initial readerLimit to be %d, got %d", wsDefaultReadLimit, server.readLimit)
307+
}
308+
309+
// Test setting different values
310+
testValues := []int64{1024, 10240, 102400, 1048576}
311+
312+
for _, expectedLimit := range testValues {
313+
server.SetReadLimits(expectedLimit)
314+
if server.readLimit != expectedLimit {
315+
t.Errorf("expected readerLimit to be %d after SetReadLimits, got %d", expectedLimit, server.readLimit)
316+
}
317+
}
318+
}

rpc/websocket.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
6060
log.Debug("WebSocket upgrade failed", "err", err)
6161
return
6262
}
63-
codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit)
63+
codec := newWebsocketCodec(conn, r.Host, r.Header, s.readLimit)
6464
s.ServeCodec(codec, 0)
6565
})
6666
}

0 commit comments

Comments
 (0)