Skip to content

Commit 716c7b7

Browse files
authored
Merge pull request containerd#10575 from dims/borrow-latest-wsstream-from-k8s-v1.31.x-to-1.7
[release/1.7] Borrow latest wsstream from k8s v1.31.x to 1.7
2 parents e242c6a + 9269d97 commit 716c7b7

File tree

12 files changed

+910
-45
lines changed

12 files changed

+910
-45
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ jobs:
3333
- uses: ./.github/actions/install-go
3434
- uses: golangci/golangci-lint-action@v4
3535
with:
36+
only-new-issues: true
3637
version: v1.56.1
3738
skip-cache: true
3839
args: --timeout=8m

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ require (
7171
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0
7272
go.opentelemetry.io/otel/sdk v1.21.0
7373
go.opentelemetry.io/otel/trace v1.21.0
74+
golang.org/x/net v0.23.0
7475
golang.org/x/sync v0.5.0
7576
golang.org/x/sys v0.18.0
7677
google.golang.org/genproto v0.0.0-20231211222908-989df2bf70f3
@@ -133,7 +134,6 @@ require (
133134
go.opentelemetry.io/proto/otlp v1.0.0 // indirect
134135
golang.org/x/crypto v0.21.0 // indirect
135136
golang.org/x/mod v0.12.0 // indirect
136-
golang.org/x/net v0.23.0 // indirect
137137
golang.org/x/oauth2 v0.11.0 // indirect
138138
golang.org/x/term v0.18.0 // indirect
139139
golang.org/x/text v0.14.0 // indirect

vendor/k8s.io/apiserver/pkg/util/wsstream/conn.go renamed to pkg/cri/streaming/internal/wsstream/conn.go

Lines changed: 141 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
/*
2+
Copyright The containerd Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
117
/*
218
Copyright 2015 The Kubernetes Authors.
319
@@ -21,16 +37,18 @@ import (
2137
"fmt"
2238
"io"
2339
"net/http"
24-
"regexp"
2540
"strings"
2641
"time"
2742

2843
"golang.org/x/net/websocket"
29-
"k8s.io/klog/v2"
3044

45+
"k8s.io/apimachinery/pkg/util/httpstream"
3146
"k8s.io/apimachinery/pkg/util/runtime"
47+
"k8s.io/klog/v2"
3248
)
3349

50+
const WebSocketProtocolHeader = "Sec-Websocket-Protocol"
51+
3452
// The Websocket subprotocol "channel.k8s.io" prepends each binary message with a byte indicating
3553
// the channel number (zero indexed) the message was sent on. Messages in both directions should
3654
// prefix their messages with this channel byte. When used for remote execution, the channel numbers
@@ -77,18 +95,47 @@ const (
7795
ReadWriteChannel
7896
)
7997

80-
var (
81-
// connectionUpgradeRegex matches any Connection header value that includes upgrade
82-
connectionUpgradeRegex = regexp.MustCompile("(^|.*,\\s*)upgrade($|\\s*,)")
83-
)
84-
8598
// IsWebSocketRequest returns true if the incoming request contains connection upgrade headers
8699
// for WebSockets.
87100
func IsWebSocketRequest(req *http.Request) bool {
88101
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
89102
return false
90103
}
91-
return connectionUpgradeRegex.MatchString(strings.ToLower(req.Header.Get("Connection")))
104+
return httpstream.IsUpgradeRequest(req)
105+
}
106+
107+
// IsWebSocketRequestWithStreamCloseProtocol returns true if the request contains headers
108+
// identifying that it is requesting a websocket upgrade with a remotecommand protocol
109+
// version that supports the "CLOSE" signal; false otherwise.
110+
func IsWebSocketRequestWithStreamCloseProtocol(req *http.Request) bool {
111+
if !IsWebSocketRequest(req) {
112+
return false
113+
}
114+
requestedProtocols := strings.TrimSpace(req.Header.Get(WebSocketProtocolHeader))
115+
for _, requestedProtocol := range strings.Split(requestedProtocols, ",") {
116+
if protocolSupportsStreamClose(strings.TrimSpace(requestedProtocol)) {
117+
return true
118+
}
119+
}
120+
121+
return false
122+
}
123+
124+
// IsWebSocketRequestWithTunnelingProtocol returns true if the request contains headers
125+
// identifying that it is requesting a websocket upgrade with a tunneling protocol;
126+
// false otherwise.
127+
func IsWebSocketRequestWithTunnelingProtocol(req *http.Request) bool {
128+
if !IsWebSocketRequest(req) {
129+
return false
130+
}
131+
requestedProtocols := strings.TrimSpace(req.Header.Get(WebSocketProtocolHeader))
132+
for _, requestedProtocol := range strings.Split(requestedProtocols, ",") {
133+
if protocolSupportsWebsocketTunneling(strings.TrimSpace(requestedProtocol)) {
134+
return true
135+
}
136+
}
137+
138+
return false
92139
}
93140

94141
// IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the
@@ -172,15 +219,46 @@ func (conn *Conn) SetIdleTimeout(duration time.Duration) {
172219
conn.timeout = duration
173220
}
174221

222+
// SetWriteDeadline sets a timeout on writing to the websocket connection. The
223+
// passed "duration" identifies how far into the future the write must complete
224+
// by before the timeout fires.
225+
func (conn *Conn) SetWriteDeadline(duration time.Duration) {
226+
conn.ws.SetWriteDeadline(time.Now().Add(duration)) //nolint:errcheck
227+
}
228+
175229
// Open the connection and create channels for reading and writing. It returns
176230
// the selected subprotocol, a slice of channels and an error.
177231
func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) {
232+
// serveHTTPComplete is channel that is closed/selected when "websocket#ServeHTTP" finishes.
233+
serveHTTPComplete := make(chan struct{})
234+
// Ensure panic in spawned goroutine is propagated into the parent goroutine.
235+
panicChan := make(chan any, 1)
178236
go func() {
179-
defer runtime.HandleCrash()
180-
defer conn.Close()
237+
// If websocket server returns, propagate panic if necessary. Otherwise,
238+
// signal HTTPServe finished by closing "serveHTTPComplete".
239+
defer func() {
240+
if p := recover(); p != nil {
241+
panicChan <- p
242+
} else {
243+
close(serveHTTPComplete)
244+
}
245+
}()
181246
websocket.Server{Handshake: conn.handshake, Handler: conn.handle}.ServeHTTP(w, req)
182247
}()
183-
<-conn.ready
248+
249+
// In normal circumstances, "websocket.Server#ServeHTTP" calls "initialize" which closes
250+
// "conn.ready" and then blocks until serving is complete.
251+
select {
252+
case <-conn.ready:
253+
klog.V(8).Infof("websocket server initialized--serving")
254+
case <-serveHTTPComplete:
255+
// websocket server returned before completing initialization; cleanup and return error.
256+
conn.closeNonThreadSafe() //nolint:errcheck
257+
return "", nil, fmt.Errorf("websocket server finished before becoming ready")
258+
case p := <-panicChan:
259+
panic(p)
260+
}
261+
184262
rwc := make([]io.ReadWriteCloser, len(conn.channels))
185263
for i := range conn.channels {
186264
rwc[i] = conn.channels[i]
@@ -229,20 +307,50 @@ func (conn *Conn) resetTimeout() {
229307
}
230308
}
231309

232-
// Close is only valid after Open has been called
233-
func (conn *Conn) Close() error {
234-
<-conn.ready
310+
// closeNonThreadSafe cleans up by closing streams and the websocket
311+
// connection *without* waiting for the "ready" channel.
312+
func (conn *Conn) closeNonThreadSafe() error {
235313
for _, s := range conn.channels {
236314
s.Close()
237315
}
238-
conn.ws.Close()
239-
return nil
316+
var err error
317+
if conn.ws != nil {
318+
err = conn.ws.Close()
319+
}
320+
return err
321+
}
322+
323+
// Close is only valid after Open has been called
324+
func (conn *Conn) Close() error {
325+
<-conn.ready
326+
return conn.closeNonThreadSafe()
327+
}
328+
329+
const (
330+
StreamProtocolV5Name = "v5.channel.k8s.io"
331+
WebsocketsSPDYTunnelingPrefix = "SPDY/3.1+"
332+
KubernetesSuffix = ".k8s.io"
333+
StreamClose = 255
334+
)
335+
336+
// protocolSupportsStreamClose returns true if the passed protocol
337+
// supports the stream close signal (currently only V5 remotecommand);
338+
// false otherwise.
339+
func protocolSupportsStreamClose(protocol string) bool {
340+
return protocol == StreamProtocolV5Name
341+
}
342+
343+
// protocolSupportsWebsocketTunneling returns true if the passed protocol
344+
// is a tunneled Kubernetes spdy protocol; false otherwise.
345+
func protocolSupportsWebsocketTunneling(protocol string) bool {
346+
return strings.HasPrefix(protocol, WebsocketsSPDYTunnelingPrefix) && strings.HasSuffix(protocol, KubernetesSuffix)
240347
}
241348

242349
// handle implements a websocket handler.
243350
func (conn *Conn) handle(ws *websocket.Conn) {
244-
defer conn.Close()
245351
conn.initialize(ws)
352+
defer conn.Close()
353+
supportsStreamClose := protocolSupportsStreamClose(conn.selectedProtocol)
246354

247355
for {
248356
conn.resetTimeout()
@@ -256,6 +364,21 @@ func (conn *Conn) handle(ws *websocket.Conn) {
256364
if len(data) == 0 {
257365
continue
258366
}
367+
if supportsStreamClose && data[0] == StreamClose {
368+
if len(data) != 2 {
369+
klog.Errorf("Single channel byte should follow stream close signal. Got %d bytes", len(data)-1)
370+
break
371+
} else {
372+
channel := data[1]
373+
if int(channel) >= len(conn.channels) {
374+
klog.Errorf("Close is targeted for a channel %d that is not valid, possible protocol error", channel)
375+
break
376+
}
377+
klog.V(4).Infof("Received half-close signal from client; close %d stream", channel)
378+
conn.channels[channel].Close() // After first Close, other closes are noop.
379+
}
380+
continue
381+
}
259382
channel := data[0]
260383
if conn.codec == base64Codec {
261384
channel = channel - '0'
@@ -266,7 +389,7 @@ func (conn *Conn) handle(ws *websocket.Conn) {
266389
continue
267390
}
268391
if _, err := conn.channels[channel].DataFromSocket(data); err != nil {
269-
klog.Errorf("Unable to write frame to %d: %v\n%s", channel, err, string(data))
392+
klog.Errorf("Unable to write frame (%d bytes) to %d: %v", len(data), channel, err)
270393
continue
271394
}
272395
}

0 commit comments

Comments
 (0)