@@ -18,6 +18,7 @@ package websocket
18
18
19
19
import (
20
20
"context"
21
+ "errors"
21
22
"io"
22
23
"net/http"
23
24
"net/http/httptest"
@@ -28,6 +29,9 @@ import (
28
29
"github.com/stretchr/testify/assert"
29
30
"github.com/stretchr/testify/require"
30
31
32
+ apierrors "k8s.io/apimachinery/pkg/api/errors"
33
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
34
+ "k8s.io/apimachinery/pkg/runtime"
31
35
"k8s.io/apimachinery/pkg/util/httpstream"
32
36
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
33
37
"k8s.io/apimachinery/pkg/util/remotecommand"
@@ -64,31 +68,86 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) {
64
68
}
65
69
66
70
func TestWebSocketRoundTripper_RoundTripperFails (t * testing.T ) {
67
- // Create fake WebSocket server.
68
- websocketServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , req * http.Request ) {
69
- // Bad handshake means websocket server will not completely initialize.
70
- _ , err := webSocketServerStreams (req , w )
71
- require .Error (t , err )
72
- assert .ErrorContains (t , err , "websocket server finished before becoming ready" )
73
- }))
74
- defer websocketServer .Close ()
75
-
76
- // Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()".
77
- websocketLocation , err := url .Parse (websocketServer .URL )
78
- require .NoError (t , err )
79
- req , err := http .NewRequestWithContext (context .Background (), "GET" , websocketServer .URL , nil )
80
- require .NoError (t , err )
81
- rt , _ , err := RoundTripperFor (& restclient.Config {Host : websocketLocation .Host })
82
- require .NoError (t , err )
83
- // Requested subprotocol version 1 is not supported by test websocket server.
84
- requestedProtocol := remotecommand .StreamProtocolV1Name
85
- req .Header [wsstream .WebSocketProtocolHeader ] = []string {requestedProtocol }
86
- _ , err = rt .RoundTrip (req )
87
- // Ensure a "bad handshake" error is returned, since requested protocol is not supported.
88
- require .Error (t , err )
89
- assert .ErrorContains (t , err , "websocket: bad handshake" )
90
- assert .ErrorContains (t , err , "403 Forbidden" )
91
- assert .True (t , httpstream .IsUpgradeFailure (err ))
71
+ testCases := map [string ]struct {
72
+ statusCode int
73
+ body string
74
+ status * metav1.Status
75
+ expectedError string
76
+ }{
77
+ "Empty response status still returns basic websocket error" : {
78
+ statusCode : - 1 ,
79
+ body : "" ,
80
+ expectedError : "websocket: bad handshake" ,
81
+ },
82
+ "Empty response body still returns status" : {
83
+ statusCode : http .StatusForbidden ,
84
+ body : "" ,
85
+ expectedError : "(403 Forbidden)" ,
86
+ },
87
+ "Error response body returned as string when can not be cast as metav1.Status" : {
88
+ statusCode : http .StatusForbidden ,
89
+ body : "RBAC violated" ,
90
+ expectedError : "(403 Forbidden): RBAC violated" ,
91
+ },
92
+ "Error returned as metav1.Status within response body" : {
93
+ statusCode : http .StatusBadRequest ,
94
+ body : "" ,
95
+ status : & metav1.Status {
96
+ TypeMeta : metav1.TypeMeta {
97
+ APIVersion : "meta.k8s.io/v1" ,
98
+ Kind : "Status" ,
99
+ },
100
+ Status : "Failure" ,
101
+ Reason : "Unable to negotiate sub-protocol" ,
102
+ Code : http .StatusBadRequest ,
103
+ },
104
+ },
105
+ }
106
+ encoder := statusCodecs .LegacyCodec (metav1 .SchemeGroupVersion )
107
+ for testName , testCase := range testCases {
108
+ t .Run (testName , func (t * testing.T ) {
109
+ // Create fake WebSocket server.
110
+ websocketServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , req * http.Request ) {
111
+ if testCase .statusCode > 0 {
112
+ w .WriteHeader (testCase .statusCode )
113
+ }
114
+ if testCase .status != nil {
115
+ statusBytes , err := runtime .Encode (encoder , testCase .status )
116
+ require .NoError (t , err )
117
+ _ , err = w .Write (statusBytes )
118
+ require .NoError (t , err )
119
+ } else if len (testCase .body ) > 0 {
120
+ _ , err := w .Write ([]byte (testCase .body ))
121
+ require .NoError (t , err )
122
+ }
123
+ }))
124
+ defer websocketServer .Close ()
125
+
126
+ // Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()".
127
+ websocketLocation , err := url .Parse (websocketServer .URL )
128
+ require .NoError (t , err )
129
+ req , err := http .NewRequestWithContext (context .Background (), "GET" , websocketServer .URL , nil )
130
+ require .NoError (t , err )
131
+ rt , _ , err := RoundTripperFor (& restclient.Config {Host : websocketLocation .Host })
132
+ require .NoError (t , err )
133
+ _ , err = rt .RoundTrip (req )
134
+ require .Error (t , err )
135
+ assert .True (t , httpstream .IsUpgradeFailure (err ))
136
+ if testCase .status != nil {
137
+ upgradeErr := & httpstream.UpgradeFailureError {}
138
+ validErr := errors .As (err , & upgradeErr )
139
+ assert .True (t , validErr , "could not cast error as httpstream.UpgradeFailureError" )
140
+ statusErr := upgradeErr .Cause
141
+ apiErr := & apierrors.StatusError {}
142
+ validErr = errors .As (statusErr , & apiErr )
143
+ assert .True (t , validErr , "could not cast error as apierrors.StatusError" )
144
+ assert .Equal (t , * testCase .status , apiErr .ErrStatus )
145
+ } else {
146
+ assert .Contains (t , err .Error (), testCase .expectedError ,
147
+ "expected (%s), got (%s)" , testCase .expectedError , err .Error ())
148
+ }
149
+ })
150
+ }
92
151
}
93
152
94
153
func TestWebSocketRoundTripper_NegotiateCreatesConnection (t * testing.T ) {
0 commit comments