Skip to content

Commit eec2be8

Browse files
committed
Upgrade websocket failure add extra error info
1 parent c15581b commit eec2be8

File tree

2 files changed

+125
-28
lines changed

2 files changed

+125
-28
lines changed

staging/src/k8s.io/client-go/transport/websocket/roundtripper.go

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,17 @@ import (
2020
"crypto/tls"
2121
"errors"
2222
"fmt"
23+
"io"
2324
"net/http"
2425
"net/url"
26+
"strings"
2527

2628
gwebsocket "github.com/gorilla/websocket"
2729

30+
apierrors "k8s.io/apimachinery/pkg/api/errors"
31+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
32+
"k8s.io/apimachinery/pkg/runtime"
33+
"k8s.io/apimachinery/pkg/runtime/serializer"
2834
"k8s.io/apimachinery/pkg/util/httpstream"
2935
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
3036
utilnet "k8s.io/apimachinery/pkg/util/net"
@@ -37,6 +43,17 @@ var (
3743
_ http.RoundTripper = &RoundTripper{}
3844
)
3945

46+
var (
47+
statusScheme = runtime.NewScheme()
48+
statusCodecs = serializer.NewCodecFactory(statusScheme)
49+
)
50+
51+
func init() {
52+
statusScheme.AddUnversionedTypes(metav1.SchemeGroupVersion,
53+
&metav1.Status{},
54+
)
55+
}
56+
4057
// ConnectionHolder defines functions for structure providing
4158
// access to the websocket connection.
4259
type ConnectionHolder interface {
@@ -110,12 +127,33 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response
110127
}
111128
wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header)
112129
if err != nil {
130+
// BadHandshake error becomes an "UpgradeFailureError" (used for streaming fallback).
113131
if errors.Is(err, gwebsocket.ErrBadHandshake) {
114-
// Enhance the error message with the response status if possible.
132+
cause := err
133+
// Enhance the error message with the error response if possible.
115134
if resp != nil && len(resp.Status) > 0 {
116-
err = fmt.Errorf("%w (%s)", err, resp.Status)
135+
defer resp.Body.Close() //nolint:errcheck
136+
cause = fmt.Errorf("%w (%s)", err, resp.Status) // Always add the response status
137+
responseError := ""
138+
responseErrorBytes, readErr := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
139+
if readErr != nil {
140+
cause = fmt.Errorf("%w: unable to read error from server response", cause)
141+
} else {
142+
// If returned error can be decoded as "metav1.Status", return a "StatusError".
143+
responseError = strings.TrimSpace(string(responseErrorBytes))
144+
if len(responseError) > 0 {
145+
if obj, _, decodeErr := statusCodecs.UniversalDecoder().Decode(responseErrorBytes, nil, &metav1.Status{}); decodeErr == nil {
146+
if status, ok := obj.(*metav1.Status); ok {
147+
cause = &apierrors.StatusError{ErrStatus: *status}
148+
}
149+
} else {
150+
// Otherwise, append the responseError string.
151+
cause = fmt.Errorf("%w: %s", cause, responseError)
152+
}
153+
}
154+
}
117155
}
118-
return nil, &httpstream.UpgradeFailureError{Cause: err}
156+
return nil, &httpstream.UpgradeFailureError{Cause: cause}
119157
}
120158
return nil, err
121159
}

staging/src/k8s.io/client-go/transport/websocket/roundtripper_test.go

Lines changed: 84 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package websocket
1818

1919
import (
2020
"context"
21+
"errors"
2122
"io"
2223
"net/http"
2324
"net/http/httptest"
@@ -28,6 +29,9 @@ import (
2829
"github.com/stretchr/testify/assert"
2930
"github.com/stretchr/testify/require"
3031

32+
apierrors "k8s.io/apimachinery/pkg/api/errors"
33+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
34+
"k8s.io/apimachinery/pkg/runtime"
3135
"k8s.io/apimachinery/pkg/util/httpstream"
3236
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
3337
"k8s.io/apimachinery/pkg/util/remotecommand"
@@ -64,31 +68,86 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) {
6468
}
6569

6670
func TestWebSocketRoundTripper_RoundTripperFails(t *testing.T) {
67-
// Create fake WebSocket server.
68-
websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
69-
// Bad handshake means websocket server will not completely initialize.
70-
_, err := webSocketServerStreams(req, w)
71-
require.Error(t, err)
72-
assert.ErrorContains(t, err, "websocket server finished before becoming ready")
73-
}))
74-
defer websocketServer.Close()
75-
76-
// Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()".
77-
websocketLocation, err := url.Parse(websocketServer.URL)
78-
require.NoError(t, err)
79-
req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil)
80-
require.NoError(t, err)
81-
rt, _, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
82-
require.NoError(t, err)
83-
// Requested subprotocol version 1 is not supported by test websocket server.
84-
requestedProtocol := remotecommand.StreamProtocolV1Name
85-
req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol}
86-
_, err = rt.RoundTrip(req)
87-
// Ensure a "bad handshake" error is returned, since requested protocol is not supported.
88-
require.Error(t, err)
89-
assert.ErrorContains(t, err, "websocket: bad handshake")
90-
assert.ErrorContains(t, err, "403 Forbidden")
91-
assert.True(t, httpstream.IsUpgradeFailure(err))
71+
testCases := map[string]struct {
72+
statusCode int
73+
body string
74+
status *metav1.Status
75+
expectedError string
76+
}{
77+
"Empty response status still returns basic websocket error": {
78+
statusCode: -1,
79+
body: "",
80+
expectedError: "websocket: bad handshake",
81+
},
82+
"Empty response body still returns status": {
83+
statusCode: http.StatusForbidden,
84+
body: "",
85+
expectedError: "(403 Forbidden)",
86+
},
87+
"Error response body returned as string when can not be cast as metav1.Status": {
88+
statusCode: http.StatusForbidden,
89+
body: "RBAC violated",
90+
expectedError: "(403 Forbidden): RBAC violated",
91+
},
92+
"Error returned as metav1.Status within response body": {
93+
statusCode: http.StatusBadRequest,
94+
body: "",
95+
status: &metav1.Status{
96+
TypeMeta: metav1.TypeMeta{
97+
APIVersion: "meta.k8s.io/v1",
98+
Kind: "Status",
99+
},
100+
Status: "Failure",
101+
Reason: "Unable to negotiate sub-protocol",
102+
Code: http.StatusBadRequest,
103+
},
104+
},
105+
}
106+
encoder := statusCodecs.LegacyCodec(metav1.SchemeGroupVersion)
107+
for testName, testCase := range testCases {
108+
t.Run(testName, func(t *testing.T) {
109+
// Create fake WebSocket server.
110+
websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
111+
if testCase.statusCode > 0 {
112+
w.WriteHeader(testCase.statusCode)
113+
}
114+
if testCase.status != nil {
115+
statusBytes, err := runtime.Encode(encoder, testCase.status)
116+
require.NoError(t, err)
117+
_, err = w.Write(statusBytes)
118+
require.NoError(t, err)
119+
} else if len(testCase.body) > 0 {
120+
_, err := w.Write([]byte(testCase.body))
121+
require.NoError(t, err)
122+
}
123+
}))
124+
defer websocketServer.Close()
125+
126+
// Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()".
127+
websocketLocation, err := url.Parse(websocketServer.URL)
128+
require.NoError(t, err)
129+
req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil)
130+
require.NoError(t, err)
131+
rt, _, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
132+
require.NoError(t, err)
133+
_, err = rt.RoundTrip(req)
134+
require.Error(t, err)
135+
assert.True(t, httpstream.IsUpgradeFailure(err))
136+
if testCase.status != nil {
137+
upgradeErr := &httpstream.UpgradeFailureError{}
138+
validErr := errors.As(err, &upgradeErr)
139+
assert.True(t, validErr, "could not cast error as httpstream.UpgradeFailureError")
140+
statusErr := upgradeErr.Cause
141+
apiErr := &apierrors.StatusError{}
142+
validErr = errors.As(statusErr, &apiErr)
143+
assert.True(t, validErr, "could not cast error as apierrors.StatusError")
144+
assert.Equal(t, *testCase.status, apiErr.ErrStatus)
145+
} else {
146+
assert.Contains(t, err.Error(), testCase.expectedError,
147+
"expected (%s), got (%s)", testCase.expectedError, err.Error())
148+
}
149+
})
150+
}
92151
}
93152

94153
func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) {

0 commit comments

Comments
 (0)