diff --git a/router/core/websocket.go b/router/core/websocket.go index 94aa65f75e..560a1b844c 100644 --- a/router/core/websocket.go +++ b/router/core/websocket.go @@ -2,10 +2,12 @@ package core import ( "bytes" + "compress/flate" "context" "encoding/json" "errors" "fmt" + "io" "net" "net/http" "regexp" @@ -17,6 +19,7 @@ import ( "github.com/buger/jsonparser" "github.com/go-chi/chi/v5/middleware" "github.com/gobwas/ws" + "github.com/gobwas/ws/wsflate" "github.com/gobwas/ws/wsutil" "github.com/gorilla/websocket" "github.com/tidwall/gjson" @@ -67,6 +70,13 @@ type WebsocketMiddlewareOptions struct { ApolloCompatibilityFlags config.ApolloCompatibilityFlags } +// NewWebsocketMiddleware creates an HTTP middleware that upgrades eligible requests to WebSocket +// connections and dispatches them to an internal WebsocketHandler configured by opts. +// +// The returned middleware delegates non-WebSocket requests to the next handler. Options in +// WebsocketMiddlewareOptions control timeouts, compression and protocol features, access control, +// header/query param forwarding allow-lists, net-poll integration, and the components used to +// process GraphQL operations over WebSocket. func NewWebsocketMiddleware(ctx context.Context, opts WebsocketMiddlewareOptions) func(http.Handler) http.Handler { handler := &WebsocketHandler{ @@ -87,6 +97,13 @@ func NewWebsocketMiddleware(ctx context.Context, opts WebsocketMiddlewareOptions disableVariablesRemapping: opts.DisableVariablesRemapping, apolloCompatibilityFlags: opts.ApolloCompatibilityFlags, } + if opts.WebSocketConfiguration != nil && opts.WebSocketConfiguration.Compression.Enabled { + handler.compressionEnabled = true + handler.compressionLevel = opts.WebSocketConfiguration.Compression.Level + if handler.compressionLevel < 1 || handler.compressionLevel > 9 { + handler.compressionLevel = flate.DefaultCompression + } + } if opts.WebSocketConfiguration != nil && opts.WebSocketConfiguration.AbsintheProtocol.Enabled { handler.absintheHandlerEnabled = true handler.absintheHandlerPath = opts.WebSocketConfiguration.AbsintheProtocol.HandlerPath @@ -156,13 +173,20 @@ type wsConnectionWrapper struct { mu sync.Mutex readTimeout time.Duration writeTimeout time.Duration + + // Compression fields + compressionEnabled bool + compressionLevel int } -func newWSConnectionWrapper(conn net.Conn, readTimeout, writeTimeout time.Duration) *wsConnectionWrapper { +// deflate compression level when enabled (typically 1–9 or flate.DefaultCompression). +func newWSConnectionWrapper(conn net.Conn, readTimeout, writeTimeout time.Duration, compressionEnabled bool, compressionLevel int) *wsConnectionWrapper { return &wsConnectionWrapper{ - conn: conn, - readTimeout: readTimeout, - writeTimeout: writeTimeout, + conn: conn, + readTimeout: readTimeout, + writeTimeout: writeTimeout, + compressionEnabled: compressionEnabled, + compressionLevel: compressionLevel, } } @@ -175,9 +199,51 @@ func (c *wsConnectionWrapper) ReadJSON(v any) error { } } - text, err := wsutil.ReadClientText(c.conn) - if err != nil { - return err + var text []byte + var err error + + if c.compressionEnabled { + // Read frames directly and handle compression + controlHandler := wsutil.ControlFrameHandler(c.conn, ws.StateServerSide) + for { + frame, err := ws.ReadFrame(c.conn) + if err != nil { + return err + } + + // Unmask client frames + if frame.Header.Masked { + ws.Cipher(frame.Payload, frame.Header.Mask, 0) + } + + if frame.Header.OpCode.IsControl() { + if err := controlHandler(frame.Header, bytes.NewReader(frame.Payload)); err != nil { + return err + } + continue + } + + if frame.Header.OpCode == ws.OpText || frame.Header.OpCode == ws.OpBinary { + // Check if frame is compressed (RSV1 bit set) + isCompressed, err := wsflate.IsCompressed(frame.Header) + if err != nil { + return err + } + if isCompressed { + frame, err = wsflate.DecompressFrame(frame) + if err != nil { + return err + } + } + text = frame.Payload + break + } + } + } else { + text, err = wsutil.ReadClientText(c.conn) + if err != nil { + return err + } } return json.Unmarshal(text, v) @@ -195,6 +261,10 @@ func (c *wsConnectionWrapper) WriteText(text string) error { } } + if c.compressionEnabled { + return c.writeCompressed([]byte(text)) + } + return wsutil.WriteServerText(c.conn, []byte(text)) } @@ -213,9 +283,32 @@ func (c *wsConnectionWrapper) WriteJSON(v any) error { } } + if c.compressionEnabled { + return c.writeCompressed(data) + } + return wsutil.WriteServerText(c.conn, data) } +// writeCompressed writes data with compression. Must be called with the mutex held. +func (c *wsConnectionWrapper) writeCompressed(data []byte) error { + var buf bytes.Buffer + writer := wsflate.NewWriter(&buf, func(w io.Writer) wsflate.Compressor { + fw, _ := flate.NewWriter(w, c.compressionLevel) + return fw + }) + if _, err := writer.Write(data); err != nil { + return err + } + if err := writer.Flush(); err != nil { + return err + } + + frame := ws.NewFrame(ws.OpText, true, buf.Bytes()) + frame.Header.Rsv = ws.Rsv(true, false, false) // Set RSV1 bit for compression + return ws.WriteFrame(c.conn, frame) +} + func (c *wsConnectionWrapper) WriteCloseFrame(code ws.StatusCode, reason string) error { c.mu.Lock() defer c.mu.Unlock() @@ -267,6 +360,9 @@ type WebsocketHandler struct { disableVariablesRemapping bool apolloCompatibilityFlags config.ApolloCompatibilityFlags + + compressionEnabled bool + compressionLevel int } func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.Request) { @@ -309,7 +405,29 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R return false }, } + + // Configure permessage-deflate compression if enabled + var compressionNegotiated bool + var ext wsflate.Extension + if h.compressionEnabled { + ext = wsflate.Extension{ + Parameters: wsflate.Parameters{ + ServerNoContextTakeover: true, + ClientNoContextTakeover: true, + }, + } + upgrader.Negotiate = ext.Negotiate + } + c, _, _, err := upgrader.Upgrade(r, w) + + // Check if compression was negotiated + if h.compressionEnabled && err == nil { + if _, accepted := ext.Accepted(); accepted { + compressionNegotiated = true + } + } + if err != nil { requestLogger.Warn("Websocket upgrade", zap.Error(err)) _ = c.Close() @@ -325,7 +443,7 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R // After successful upgrade, we can't write to the response writer anymore // because it's hijacked by the websocket connection - conn := newWSConnectionWrapper(c, h.readTimeout, h.writeTimeout) + conn := newWSConnectionWrapper(c, h.readTimeout, h.writeTimeout, compressionNegotiated, h.compressionLevel) protocol, err := wsproto.NewProtocol(subProtocol, conn) if err != nil { requestLogger.Error("Create websocket protocol", zap.Error(err)) @@ -1282,4 +1400,4 @@ func (h *WebSocketConnectionHandler) Close(unsubscribe bool) { if err != nil { h.logger.Debug("Closing websocket connection", zap.Error(err)) } -} +} \ No newline at end of file