Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions router-tests/testenv/testenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -2345,6 +2345,44 @@ func (e *Environment) GraphQLWebsocketDialWithRetry(header http.Header, query ur
return nil, nil, err
}

// GraphQLWebsocketDialWithCompressionRetry is like GraphQLWebsocketDialWithRetry but enables
// permessage-deflate compression negotiation on the client side.
func (e *Environment) GraphQLWebsocketDialWithCompressionRetry(header http.Header, query url.Values) (*websocket.Conn, *http.Response, error) {
dialer := websocket.Dialer{
Subprotocols: []string{"graphql-transport-ws"},
EnableCompression: true,
}

waitBetweenRetriesInMs := rand.Intn(10)
timeToSleep := time.Duration(waitBetweenRetriesInMs) * time.Millisecond

var err error

for i := 0; i <= maxSocketRetries; i++ {
urlStr := e.GraphQLWebSocketSubscriptionURL()
if query != nil {
urlStr += "?" + query.Encode()
}
conn, resp, err := dialer.Dial(urlStr, header)

if resp != nil && err == nil {
return conn, resp, err
}

if errors.Is(err, websocket.ErrBadHandshake) {
return conn, resp, err
}

// Make sure that on the final attempt we won't wait
if i != maxSocketRetries {
time.Sleep(timeToSleep)
timeToSleep *= 2
}
}

return nil, nil, err
}

func (e *Environment) InitGraphQLWebSocketConnection(header http.Header, query url.Values, initialPayload json.RawMessage) *websocket.Conn {
conn, _, err := e.GraphQLWebsocketDialWithRetry(header, query)
require.NoError(e.t, err)
Expand All @@ -2362,6 +2400,25 @@ func (e *Environment) InitGraphQLWebSocketConnection(header http.Header, query u
return conn
}

// InitGraphQLWebSocketConnectionWithCompression initializes a WebSocket connection with
// permessage-deflate compression negotiation enabled on the client side.
func (e *Environment) InitGraphQLWebSocketConnectionWithCompression(header http.Header, query url.Values, initialPayload json.RawMessage) (*websocket.Conn, *http.Response) {
conn, resp, err := e.GraphQLWebsocketDialWithCompressionRetry(header, query)
require.NoError(e.t, err)
e.t.Cleanup(func() {
_ = conn.Close()
})
err = conn.WriteJSON(WebSocketMessage{
Type: "connection_init",
Payload: initialPayload,
})
require.NoError(e.t, err)
var ack WebSocketMessage
require.NoError(e.t, ReadAndCheckJSON(e.t, conn, &ack))
require.Equal(e.t, "connection_ack", ack.Type)
return conn, resp
}

func (e *Environment) GraphQLSubscriptionOverSSE(ctx context.Context, request GraphQLRequest, handler func(data string)) {
req, err := e.newGraphQLRequestOverGET(e.GraphQLRequestURL(), request)
if err != nil {
Expand Down
121 changes: 121 additions & 0 deletions router-tests/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2392,6 +2392,127 @@ func TestWebSockets(t *testing.T) {
})
})

t.Run("compression enabled on server and client", func(t *testing.T) {
t.Parallel()

testenv.Run(t, &testenv.Config{
ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) {
cfg.Compression.Enabled = true
cfg.Compression.Level = 6
},
}, func(t *testing.T, xEnv *testenv.Environment) {
// Use the compression-enabled dialer
conn, resp := xEnv.InitGraphQLWebSocketConnectionWithCompression(nil, nil, nil)

// Check that compression was negotiated via the Sec-WebSocket-Extensions header
extensions := resp.Header.Get("Sec-WebSocket-Extensions")
require.Contains(t, extensions, "permessage-deflate", "Expected compression to be negotiated")

// Verify the connection works correctly with compression
err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{
ID: "1",
Type: "subscribe",
Payload: []byte(`{"query":"{ employees { id } }"}`),
})
require.NoError(t, err)

var res testenv.WebSocketMessage
err = testenv.WSReadJSON(t, conn, &res)
require.NoError(t, err)
require.Equal(t, "next", res.Type)
require.Equal(t, "1", res.ID)
require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, string(res.Payload))

var complete testenv.WebSocketMessage
err = testenv.WSReadJSON(t, conn, &complete)
require.NoError(t, err)
require.Equal(t, "complete", complete.Type)
require.Equal(t, "1", complete.ID)

xEnv.WaitForSubscriptionCount(0, time.Second*5)
})
})

t.Run("compression disabled on server but enabled on client", func(t *testing.T) {
t.Parallel()

testenv.Run(t, &testenv.Config{
ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) {
cfg.Compression.Enabled = false
},
}, func(t *testing.T, xEnv *testenv.Environment) {
// Use the compression-enabled dialer, but server has compression disabled
conn, resp := xEnv.InitGraphQLWebSocketConnectionWithCompression(nil, nil, nil)

// Check that compression was NOT negotiated
extensions := resp.Header.Get("Sec-WebSocket-Extensions")
require.NotContains(t, extensions, "permessage-deflate", "Expected compression NOT to be negotiated when disabled on server")

// Verify the connection still works correctly without compression
err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{
ID: "1",
Type: "subscribe",
Payload: []byte(`{"query":"{ employees { id } }"}`),
})
require.NoError(t, err)

var res testenv.WebSocketMessage
err = testenv.WSReadJSON(t, conn, &res)
require.NoError(t, err)
require.Equal(t, "next", res.Type)
require.Equal(t, "1", res.ID)
require.JSONEq(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, string(res.Payload))

var complete testenv.WebSocketMessage
err = testenv.WSReadJSON(t, conn, &complete)
require.NoError(t, err)
require.Equal(t, "complete", complete.Type)
require.Equal(t, "1", complete.ID)

xEnv.WaitForSubscriptionCount(0, time.Second*5)
})
})

t.Run("compression with custom level", func(t *testing.T) {
t.Parallel()

testenv.Run(t, &testenv.Config{
ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) {
cfg.Compression.Enabled = true
cfg.Compression.Level = 9 // Best compression
},
}, func(t *testing.T, xEnv *testenv.Environment) {
conn, resp := xEnv.InitGraphQLWebSocketConnectionWithCompression(nil, nil, nil)

// Check that compression was negotiated
extensions := resp.Header.Get("Sec-WebSocket-Extensions")
require.Contains(t, extensions, "permessage-deflate", "Expected compression to be negotiated")

// Run a subscription query to verify it works with max compression
err := testenv.WSWriteJSON(t, conn, testenv.WebSocketMessage{
ID: "1",
Type: "subscribe",
Payload: []byte(`{"query":"{ employees { id details { forename surname } } }"}`),
})
require.NoError(t, err)

var res testenv.WebSocketMessage
err = testenv.WSReadJSON(t, conn, &res)
require.NoError(t, err)
require.Equal(t, "next", res.Type)
require.Equal(t, "1", res.ID)
require.Contains(t, string(res.Payload), "forename")
require.Contains(t, string(res.Payload), "surname")

var complete testenv.WebSocketMessage
err = testenv.WSReadJSON(t, conn, &complete)
require.NoError(t, err)
require.Equal(t, "complete", complete.Type)

xEnv.WaitForSubscriptionCount(0, time.Second*5)
})
})

}

func TestFlakyWebSockets(t *testing.T) {
Expand Down
Loading