Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 127 additions & 9 deletions router/core/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package core

import (
"bytes"
"compress/flate"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"regexp"
Expand All @@ -17,6 +19,7 @@ import (
"github.com/buger/jsonparser"
"github.com/go-chi/chi/v5/middleware"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsflate"
"github.com/gobwas/ws/wsutil"
"github.com/gorilla/websocket"
"github.com/tidwall/gjson"
Expand Down Expand Up @@ -67,6 +70,13 @@ type WebsocketMiddlewareOptions struct {
ApolloCompatibilityFlags config.ApolloCompatibilityFlags
}

// NewWebsocketMiddleware creates an HTTP middleware that upgrades eligible requests to WebSocket
// connections and dispatches them to an internal WebsocketHandler configured by opts.
//
// The returned middleware delegates non-WebSocket requests to the next handler. Options in
// WebsocketMiddlewareOptions control timeouts, compression and protocol features, access control,
// header/query param forwarding allow-lists, net-poll integration, and the components used to
// process GraphQL operations over WebSocket.
func NewWebsocketMiddleware(ctx context.Context, opts WebsocketMiddlewareOptions) func(http.Handler) http.Handler {

handler := &WebsocketHandler{
Expand All @@ -87,6 +97,13 @@ func NewWebsocketMiddleware(ctx context.Context, opts WebsocketMiddlewareOptions
disableVariablesRemapping: opts.DisableVariablesRemapping,
apolloCompatibilityFlags: opts.ApolloCompatibilityFlags,
}
if opts.WebSocketConfiguration != nil && opts.WebSocketConfiguration.Compression.Enabled {
handler.compressionEnabled = true
handler.compressionLevel = opts.WebSocketConfiguration.Compression.Level
if handler.compressionLevel < 1 || handler.compressionLevel > 9 {
handler.compressionLevel = flate.DefaultCompression
}
}
if opts.WebSocketConfiguration != nil && opts.WebSocketConfiguration.AbsintheProtocol.Enabled {
handler.absintheHandlerEnabled = true
handler.absintheHandlerPath = opts.WebSocketConfiguration.AbsintheProtocol.HandlerPath
Expand Down Expand Up @@ -156,13 +173,20 @@ type wsConnectionWrapper struct {
mu sync.Mutex
readTimeout time.Duration
writeTimeout time.Duration

// Compression fields
compressionEnabled bool
compressionLevel int
}

func newWSConnectionWrapper(conn net.Conn, readTimeout, writeTimeout time.Duration) *wsConnectionWrapper {
// deflate compression level when enabled (typically 1–9 or flate.DefaultCompression).
func newWSConnectionWrapper(conn net.Conn, readTimeout, writeTimeout time.Duration, compressionEnabled bool, compressionLevel int) *wsConnectionWrapper {
return &wsConnectionWrapper{
conn: conn,
readTimeout: readTimeout,
writeTimeout: writeTimeout,
conn: conn,
readTimeout: readTimeout,
writeTimeout: writeTimeout,
compressionEnabled: compressionEnabled,
compressionLevel: compressionLevel,
}
}

Expand All @@ -175,9 +199,51 @@ func (c *wsConnectionWrapper) ReadJSON(v any) error {
}
}

text, err := wsutil.ReadClientText(c.conn)
if err != nil {
return err
var text []byte
var err error

if c.compressionEnabled {
// Read frames directly and handle compression
controlHandler := wsutil.ControlFrameHandler(c.conn, ws.StateServerSide)
for {
frame, err := ws.ReadFrame(c.conn)
if err != nil {
return err
}

// Unmask client frames
if frame.Header.Masked {
ws.Cipher(frame.Payload, frame.Header.Mask, 0)
}

if frame.Header.OpCode.IsControl() {
if err := controlHandler(frame.Header, bytes.NewReader(frame.Payload)); err != nil {
return err
}
continue
}

if frame.Header.OpCode == ws.OpText || frame.Header.OpCode == ws.OpBinary {
// Check if frame is compressed (RSV1 bit set)
isCompressed, err := wsflate.IsCompressed(frame.Header)
if err != nil {
return err
}
if isCompressed {
frame, err = wsflate.DecompressFrame(frame)
if err != nil {
return err
}
}
text = frame.Payload
break
}
}
} else {
text, err = wsutil.ReadClientText(c.conn)
if err != nil {
return err
}
}

return json.Unmarshal(text, v)
Expand All @@ -195,6 +261,10 @@ func (c *wsConnectionWrapper) WriteText(text string) error {
}
}

if c.compressionEnabled {
return c.writeCompressed([]byte(text))
}

return wsutil.WriteServerText(c.conn, []byte(text))
}

Expand All @@ -213,9 +283,32 @@ func (c *wsConnectionWrapper) WriteJSON(v any) error {
}
}

if c.compressionEnabled {
return c.writeCompressed(data)
}

return wsutil.WriteServerText(c.conn, data)
}

// writeCompressed writes data with compression. Must be called with the mutex held.
func (c *wsConnectionWrapper) writeCompressed(data []byte) error {
var buf bytes.Buffer
writer := wsflate.NewWriter(&buf, func(w io.Writer) wsflate.Compressor {
fw, _ := flate.NewWriter(w, c.compressionLevel)
return fw
})
if _, err := writer.Write(data); err != nil {
return err
}
if err := writer.Flush(); err != nil {
return err
}

frame := ws.NewFrame(ws.OpText, true, buf.Bytes())
frame.Header.Rsv = ws.Rsv(true, false, false) // Set RSV1 bit for compression
return ws.WriteFrame(c.conn, frame)
}

func (c *wsConnectionWrapper) WriteCloseFrame(code ws.StatusCode, reason string) error {
c.mu.Lock()
defer c.mu.Unlock()
Expand Down Expand Up @@ -267,6 +360,9 @@ type WebsocketHandler struct {
disableVariablesRemapping bool

apolloCompatibilityFlags config.ApolloCompatibilityFlags

compressionEnabled bool
compressionLevel int
}

func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -309,7 +405,29 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R
return false
},
}

// Configure permessage-deflate compression if enabled
var compressionNegotiated bool
var ext wsflate.Extension
if h.compressionEnabled {
ext = wsflate.Extension{
Parameters: wsflate.Parameters{
ServerNoContextTakeover: true,
ClientNoContextTakeover: true,
},
}
upgrader.Negotiate = ext.Negotiate
}

c, _, _, err := upgrader.Upgrade(r, w)

// Check if compression was negotiated
if h.compressionEnabled && err == nil {
if _, accepted := ext.Accepted(); accepted {
compressionNegotiated = true
}
}

if err != nil {
requestLogger.Warn("Websocket upgrade", zap.Error(err))
_ = c.Close()
Expand All @@ -325,7 +443,7 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R
// After successful upgrade, we can't write to the response writer anymore
// because it's hijacked by the websocket connection

conn := newWSConnectionWrapper(c, h.readTimeout, h.writeTimeout)
conn := newWSConnectionWrapper(c, h.readTimeout, h.writeTimeout, compressionNegotiated, h.compressionLevel)
protocol, err := wsproto.NewProtocol(subProtocol, conn)
if err != nil {
requestLogger.Error("Create websocket protocol", zap.Error(err))
Expand Down Expand Up @@ -1282,4 +1400,4 @@ func (h *WebSocketConnectionHandler) Close(unsubscribe bool) {
if err != nil {
h.logger.Debug("Closing websocket connection", zap.Error(err))
}
}
}
Loading