Skip to content

Commit 77f0d86

Browse files
authored
Generate and display more useful error message (#5)
1 parent 14f2e96 commit 77f0d86

File tree

6 files changed

+169
-10
lines changed

6 files changed

+169
-10
lines changed

client/proxy.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ import (
2626
"github.com/pkg/errors"
2727
"golang.org/x/net/http2"
2828
"golang.org/x/net/http2/h2c"
29+
"golang.stackrox.io/grpc-http1/internal/grpcproto"
2930
"golang.stackrox.io/grpc-http1/internal/grpcweb"
31+
"golang.stackrox.io/grpc-http1/internal/httputils"
3032
"golang.stackrox.io/grpc-http1/internal/pipeconn"
3133
"golang.stackrox.io/grpc-http1/internal/stringutils"
3234
"google.golang.org/grpc"
@@ -36,6 +38,13 @@ import (
3638
)
3739

3840
func modifyResponse(resp *http.Response) error {
41+
// Check if the response is an error response right away, and attempt to display a more useful
42+
// message than gRPC does by default. We still delegate to the default gRPC behavior for 200 responses
43+
// which are otherwise invalid.
44+
if err := httputils.ExtractResponseError(resp); err != nil {
45+
return errors.Wrap(err, "receiving gRPC response from remote endpoint")
46+
}
47+
3948
if resp.ContentLength == 0 {
4049
// Make sure headers do not get flushed, as otherwise the gRPC client will complain about missing trailers.
4150
resp.Header.Set(dontFlushHeadersHeaderKey, "true")
@@ -66,7 +75,8 @@ func writeError(w http.ResponseWriter, err error) {
6675
w.WriteHeader(http.StatusOK)
6776

6877
w.Header().Set("Grpc-Status", fmt.Sprintf("%d", codes.Unavailable))
69-
w.Header().Set("Grpc-Message", errors.Wrap(err, "transport").Error())
78+
errMsg := errors.Wrap(err, "transport").Error()
79+
w.Header().Set("Grpc-Message", grpcproto.EncodeGrpcMessage(errMsg))
7080
}
7181

7282
func createReverseProxy(endpoint string, transport http.RoundTripper, insecure, forceDowngrade bool) *httputil.ReverseProxy {

client/ws_proxy.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"github.com/pkg/errors"
3232
"golang.stackrox.io/grpc-http1/internal/grpcproto"
3333
"golang.stackrox.io/grpc-http1/internal/grpcwebsocket"
34+
"golang.stackrox.io/grpc-http1/internal/httputils"
3435
"golang.stackrox.io/grpc-http1/internal/pipeconn"
3536
"golang.stackrox.io/grpc-http1/internal/size"
3637
"google.golang.org/grpc/codes"
@@ -42,7 +43,7 @@ const (
4243
)
4344

4445
var (
45-
subprotocols = []string{"grpc-ws"}
46+
subprotocols = []string{grpcwebsocket.SubprotocolName}
4647
)
4748

4849
type http2WebSocketProxy struct {
@@ -209,7 +210,8 @@ func (c *websocketConn) writeErrorIfNecessary() {
209210
c.w.WriteHeader(http.StatusOK)
210211

211212
c.w.Header().Set("Trailer:Grpc-Status", fmt.Sprintf("%d", codes.Unavailable))
212-
c.w.Header().Set("Trailer:Grpc-Message", errors.Wrap(c.err, "transport").Error())
213+
errMsg := errors.Wrap(c.err, "transport").Error()
214+
c.w.Header().Set("Trailer:Grpc-Message", grpcproto.EncodeGrpcMessage(errMsg))
213215
}
214216

215217
// ServeHTTP handles gRPC-WebSocket traffic.
@@ -228,16 +230,26 @@ func (h *http2WebSocketProxy) ServeHTTP(w http.ResponseWriter, req *http.Request
228230
url := *req.URL // Copy the value, so we do not overwrite the URL.
229231
url.Scheme = scheme
230232
url.Host = h.endpoint
231-
conn, _, err := websocket.Dial(req.Context(), url.String(), &websocket.DialOptions{
233+
conn, resp, err := websocket.Dial(req.Context(), url.String(), &websocket.DialOptions{
232234
// Add the gRPC headers to the WebSocket handshake request.
233235
HTTPHeader: req.Header,
234236
HTTPClient: h.httpClient,
235237
Subprotocols: subprotocols,
236238
// gRPC already performs compression, so no need for WebSocket to add compression as well.
237239
CompressionMode: websocket.CompressionDisabled,
238240
})
241+
if resp != nil {
242+
// Not strictly necessary because the library already replaces resp.Body with a NopCloser,
243+
// but seems too easy to miss should we switch to a different library.
244+
defer func() { _ = resp.Body.Close() }()
245+
}
239246
if err != nil {
240-
writeError(w, errors.Wrapf(err, "connecting to gRPC server %q", url.String()))
247+
if resp != nil {
248+
if respErr := httputils.ExtractResponseError(resp); respErr != nil {
249+
err = fmt.Errorf("%w; response error: %v", err, respErr)
250+
}
251+
}
252+
writeError(w, errors.Wrapf(err, "connecting to gRPC endpoint %q", url.String()))
241253
return
242254
}
243255
conn.SetReadLimit(64 * size.MB)

internal/grpcproto/message.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package grpcproto
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"unicode/utf8"
7+
)
8+
9+
// This code is copied from google.golang.org/[email protected]/internal/transport/http_util.go, ll.443-494,
10+
// and has been adjusted to make the `EncodeGrpcMessage` function exported.
11+
// The original code is Copyright (c) by the gRPC authors and was distributed under the
12+
// Apache License, version 2.0.
13+
14+
const (
15+
spaceByte = ' '
16+
tildeByte = '~'
17+
percentByte = '%'
18+
)
19+
20+
// EncodeGrpcMessage is used to encode status code in header field
21+
// "grpc-message". It does percent encoding and also replaces invalid utf-8
22+
// characters with Unicode replacement character.
23+
//
24+
// It checks to see if each individual byte in msg is an allowable byte, and
25+
// then either percent encoding or passing it through. When percent encoding,
26+
// the byte is converted into hexadecimal notation with a '%' prepended.
27+
func EncodeGrpcMessage(msg string) string {
28+
if msg == "" {
29+
return ""
30+
}
31+
lenMsg := len(msg)
32+
for i := 0; i < lenMsg; i++ {
33+
c := msg[i]
34+
if !(c >= spaceByte && c <= tildeByte && c != percentByte) {
35+
return encodeGrpcMessageUnchecked(msg)
36+
}
37+
}
38+
return msg
39+
}
40+
41+
func encodeGrpcMessageUnchecked(msg string) string {
42+
var buf bytes.Buffer
43+
for len(msg) > 0 {
44+
r, size := utf8.DecodeRuneInString(msg)
45+
for _, b := range []byte(string(r)) {
46+
if size > 1 {
47+
// If size > 1, r is not ascii. Always do percent encoding.
48+
buf.WriteString(fmt.Sprintf("%%%02X", b))
49+
continue
50+
}
51+
52+
// The for loop is necessary even if size == 1. r could be
53+
// utf8.RuneError.
54+
//
55+
// fmt.Sprintf("%%%02X", utf8.RuneError) gives "%FFFD".
56+
if b >= spaceByte && b <= tildeByte && b != percentByte {
57+
buf.WriteByte(b)
58+
} else {
59+
buf.WriteString(fmt.Sprintf("%%%02X", b))
60+
}
61+
}
62+
msg = msg[size:]
63+
}
64+
return buf.String()
65+
}

internal/grpcwebsocket/consts.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package grpcwebsocket
2+
3+
const (
4+
// SubprotocolName is the subprotocol for gRPC-websocket specified in the Sec-Websocket-Protocol
5+
// header.
6+
SubprotocolName = "grpc-ws"
7+
)

internal/httputils/error_message.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package httputils
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"io"
7+
"io/ioutil"
8+
"net/http"
9+
"regexp"
10+
"strings"
11+
"unicode/utf8"
12+
)
13+
14+
const (
15+
maxBodyBytes = 1024
16+
)
17+
18+
var (
19+
httpHeaderOptSeparatorRegex = regexp.MustCompile(`;\s*`)
20+
)
21+
22+
// ExtractResponseError extracts an error from an HTTP response, reading at most 1024 bytes of the
23+
// response body.
24+
func ExtractResponseError(resp *http.Response) error {
25+
if resp.StatusCode < 400 {
26+
return nil
27+
}
28+
contentTypeFields := httpHeaderOptSeparatorRegex.Split(resp.Header.Get("Content-Type"), 2)
29+
if len(contentTypeFields) == 0 {
30+
return errors.New(resp.Status)
31+
}
32+
33+
if contentTypeFields[0] != "text/plain" {
34+
return fmt.Errorf("%s, content-type %s", resp.Status, contentTypeFields[0])
35+
}
36+
37+
bodyReader := io.LimitReader(resp.Body, maxBodyBytes)
38+
contents, err := ioutil.ReadAll(bodyReader)
39+
contentsStr := strings.TrimSpace(string(contents))
40+
if !utf8.Valid(contents) {
41+
contentsStr = "invalid UTF-8 characters in response"
42+
}
43+
if err != nil {
44+
if contentsStr == "" {
45+
return fmt.Errorf("%s, error reading response body: %v", resp.Status, err)
46+
}
47+
return fmt.Errorf("%s: %s, error reading response body after %d bytes: %v", resp.Status, contentsStr, len(contents), err)
48+
}
49+
50+
if contentsStr == "" {
51+
return errors.New(resp.Status)
52+
}
53+
return fmt.Errorf("%s: %s", resp.Status, contentsStr)
54+
}

server/server.go

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package server
1616

1717
import (
18+
"errors"
1819
"fmt"
1920
"net/http"
2021
"strings"
@@ -173,7 +174,10 @@ func CreateDowngradingHandler(grpcSrv *grpc.Server, httpHandler http.Handler, op
173174
}
174175

175176
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
176-
if isWebSocketUpgrade(req.Header) {
177+
if isUpgrade, err := isWebSocketUpgrade(req.Header); err != nil {
178+
http.Error(w, err.Error(), http.StatusBadRequest)
179+
return
180+
} else if isUpgrade {
177181
handleGRPCWS(w, req, grpcSrv)
178182
return
179183
}
@@ -188,10 +192,17 @@ func CreateDowngradingHandler(grpcSrv *grpc.Server, httpHandler http.Handler, op
188192
})
189193
}
190194

191-
func isWebSocketUpgrade(header http.Header) bool {
192-
return header.Get("Connection") == "Upgrade" &&
193-
header.Get("Upgrade") == "websocket" &&
194-
header.Get("Sec-Websocket-Protocol") == "grpc-ws"
195+
func isWebSocketUpgrade(header http.Header) (bool, error) {
196+
if header.Get("Sec-Websocket-Protocol") != grpcwebsocket.SubprotocolName {
197+
return false, nil
198+
}
199+
if header.Get("Connection") != "Upgrade" {
200+
return false, errors.New("missing 'Connection: Upgrade' header in gRPC-websocket request (this usually means your proxy or load balancer does not support websockets)")
201+
}
202+
if header.Get("Upgrade") != "websocket" {
203+
return false, errors.New("missing 'Upgrade: websocket' header in gRPC-websocket request (this usually means your proxy or load balancer does not support websockets)")
204+
}
205+
return true, nil
195206
}
196207

197208
func spaceOrComma(r rune) bool {

0 commit comments

Comments
 (0)