Skip to content

Commit 14f2e96

Browse files
authored
Add "force downgrade" (client) and "prefer downgrade" (server) options (#4)
1 parent a88f6cc commit 14f2e96

File tree

7 files changed

+161
-22
lines changed

7 files changed

+161
-22
lines changed

_integration-tests/echo_service_test.go

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,15 @@ func newCtx(t *testing.T, checkHeaders bool, checkTrailers bool) (context.Contex
7373
}
7474

7575
func TestWithEchoService(t *testing.T) {
76-
testCfg := newTestConfig(t)
76+
for _, serverPreferGRPCWeb := range []bool{false, true} {
77+
t.Run(fmt.Sprintf("prefer-grpc-web=%t", serverPreferGRPCWeb), func(t *testing.T) {
78+
testWithEchoService(t, serverPreferGRPCWeb)
79+
})
80+
}
81+
}
82+
83+
func testWithEchoService(t *testing.T, serverPreferGRPCWeb bool) {
84+
testCfg := newTestConfig(t, serverPreferGRPCWeb)
7785
defer testCfg.TearDown()
7886

7987
cases := []testCase{
@@ -145,6 +153,44 @@ func TestWithEchoService(t *testing.T) {
145153
expectClientStreamOK: false,
146154
expectBidiStreamOK: false,
147155
},
156+
{
157+
targetID: "raw-grpc",
158+
useProxy: true,
159+
forceDowngrade: true,
160+
expectUnaryOK: true,
161+
expectServerStreamOK: true,
162+
expectClientStreamOK: true,
163+
expectBidiStreamOK: true,
164+
},
165+
{
166+
targetID: "raw-grpc",
167+
behindHTTP1ReverseProxy: true,
168+
useProxy: true,
169+
forceDowngrade: true,
170+
expectUnaryOK: false,
171+
expectServerStreamOK: false,
172+
expectClientStreamOK: false,
173+
expectBidiStreamOK: false,
174+
},
175+
{
176+
targetID: "downgrading-grpc",
177+
useProxy: true,
178+
forceDowngrade: true,
179+
expectUnaryOK: true,
180+
expectServerStreamOK: true,
181+
expectClientStreamOK: false,
182+
expectBidiStreamOK: false,
183+
},
184+
{
185+
targetID: "downgrading-grpc",
186+
behindHTTP1ReverseProxy: true,
187+
useProxy: true,
188+
forceDowngrade: true,
189+
expectUnaryOK: true,
190+
expectServerStreamOK: true,
191+
expectClientStreamOK: false,
192+
expectBidiStreamOK: false,
193+
},
148194
}
149195

150196
for _, c := range cases {
@@ -155,7 +201,7 @@ func TestWithEchoService(t *testing.T) {
155201
}
156202

157203
func TestWSWithEchoService(t *testing.T) {
158-
testCfg := newTestConfig(t)
204+
testCfg := newTestConfig(t, false)
159205
defer testCfg.TearDown()
160206

161207
cases := []testCase{
@@ -269,6 +315,7 @@ type testCase struct {
269315
behindHTTP1ReverseProxy bool
270316
useProxy bool
271317
useWebSocket bool
318+
forceDowngrade bool
272319

273320
expectUnaryOK bool
274321
expectClientStreamOK bool
@@ -282,6 +329,8 @@ func (c *testCase) Name() string {
282329

283330
if c.useWebSocket {
284331
sb.WriteString("-ws")
332+
} else if c.forceDowngrade {
333+
sb.WriteString("-forced-downgrade")
285334
}
286335

287336
if c.behindHTTP1ReverseProxy {
@@ -329,7 +378,7 @@ func (c *testCase) Run(t *testing.T, cfg *testConfig) {
329378
if !c.behindHTTP1ReverseProxy {
330379
opts = append(opts, client.ForceHTTP2())
331380
}
332-
opts = append(opts, client.UseWebSocket(c.useWebSocket))
381+
opts = append(opts, client.UseWebSocket(c.useWebSocket), client.ForceDowngrade(c.forceDowngrade))
333382

334383
cc, err = client.ConnectViaProxy(ctx, targetAddr, nil, opts...)
335384
} else {
@@ -704,7 +753,7 @@ type testConfig struct {
704753
targetAddrs map[string]string
705754
}
706755

707-
func newTestConfig(t *testing.T) *testConfig {
756+
func newTestConfig(t *testing.T, preferGRPCWeb bool) *testConfig {
708757
targetAddrs := make(map[string]string)
709758
grpcSrv := grpc.NewServer()
710759
echo.RegisterEchoServer(grpcSrv, echoService{})
@@ -713,11 +762,13 @@ func newTestConfig(t *testing.T) *testConfig {
713762
go grpcSrv.Serve(lis)
714763
targetAddrs["raw-grpc"] = lis.Addr().String()
715764

765+
opts := []server.Option{server.PreferGRPCWeb(preferGRPCWeb)}
766+
716767
downgradingSrv := &http.Server{}
717768
var h2Srv http2.Server
718769
require.NoError(t, http2.ConfigureServer(downgradingSrv, &h2Srv))
719770
downgradingSrv.Handler = h2c.NewHandler(
720-
server.CreateDowngradingHandler(grpcSrv, http.NotFoundHandler()),
771+
server.CreateDowngradingHandler(grpcSrv, http.NotFoundHandler(), opts...),
721772
&h2Srv)
722773

723774
lis = listenLocal(t)

client/options.go

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@ package client
1717
import "google.golang.org/grpc"
1818

1919
type connectOptions struct {
20-
dialOpts []grpc.DialOption
21-
extraH2ALPNs []string
22-
forceHTTP2 bool
23-
useWebSocket bool
20+
dialOpts []grpc.DialOption
21+
extraH2ALPNs []string
22+
forceHTTP2 bool
23+
forceDowngrade bool
24+
useWebSocket bool
2425
}
2526

2627
// ConnectOption is an option that can be passed to the `ConnectViaProxy` method.
@@ -57,6 +58,14 @@ func UseWebSocket(use bool) ConnectOption {
5758
return useWebSocketOption(use)
5859
}
5960

61+
// ForceDowngrade returns a connection option that instructs the
62+
// client to always force gRPC-Web downgrade for gRPC requests.
63+
// Client- or Bidi-streaming requests will not work.
64+
// This option has no effect if websockets are being used.
65+
func ForceDowngrade(force bool) ConnectOption {
66+
return forceDowngradeOption(force)
67+
}
68+
6069
type dialOptsOption []grpc.DialOption
6170

6271
func (o dialOptsOption) apply(opts *connectOptions) {
@@ -80,3 +89,9 @@ type useWebSocketOption bool
8089
func (o useWebSocketOption) apply(opts *connectOptions) {
8190
opts.useWebSocket = bool(o)
8291
}
92+
93+
type forceDowngradeOption bool
94+
95+
func (o forceDowngradeOption) apply(opts *connectOptions) {
96+
opts.forceDowngrade = bool(o)
97+
}

client/proxy.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,21 @@ func writeError(w http.ResponseWriter, err error) {
6969
w.Header().Set("Grpc-Message", errors.Wrap(err, "transport").Error())
7070
}
7171

72-
func createReverseProxy(endpoint string, transport http.RoundTripper, insecure bool) *httputil.ReverseProxy {
72+
func createReverseProxy(endpoint string, transport http.RoundTripper, insecure, forceDowngrade bool) *httputil.ReverseProxy {
7373
scheme := "https"
7474
if insecure {
7575
scheme = "http"
7676
}
7777
return &httputil.ReverseProxy{
7878
Director: func(req *http.Request) {
79-
req.Header.Add("Accept", "application/grpc")
79+
if forceDowngrade {
80+
req.ProtoMajor, req.ProtoMinor, req.Proto = 1, 1, "HTTP/1.1"
81+
req.Header.Del("TE")
82+
req.Header.Del("Accept")
83+
req.Header.Add(grpcweb.GRPCWebOnlyHeader, "true")
84+
} else {
85+
req.Header.Add("Accept", "application/grpc")
86+
}
8087
req.Header.Add("Accept", "application/grpc-web")
8188
req.URL.Scheme = scheme
8289
req.URL.Host = endpoint
@@ -124,12 +131,12 @@ func createTransport(tlsClientConf *tls.Config, forceHTTP2 bool, extraH2ALPNs []
124131
return transport, nil
125132
}
126133

127-
func createClientProxy(endpoint string, tlsClientConf *tls.Config, forceHTTP2 bool, extraH2ALPNs []string) (*http.Server, pipeconn.DialContextFunc, error) {
134+
func createClientProxy(endpoint string, tlsClientConf *tls.Config, forceHTTP2, forceDowngrade bool, extraH2ALPNs []string) (*http.Server, pipeconn.DialContextFunc, error) {
128135
transport, err := createTransport(tlsClientConf, forceHTTP2, extraH2ALPNs)
129136
if err != nil {
130137
return nil, nil, errors.Wrap(err, "creating transport")
131138
}
132-
proxy := createReverseProxy(endpoint, transport, tlsClientConf == nil)
139+
proxy := createReverseProxy(endpoint, transport, tlsClientConf == nil, forceDowngrade)
133140
return makeProxyServer(proxy)
134141
}
135142

@@ -153,7 +160,7 @@ func ConnectViaProxy(ctx context.Context, endpoint string, tlsClientConf *tls.Co
153160
if connectOpts.useWebSocket {
154161
proxy, dialCtx, err = createClientWSProxy(endpoint, tlsClientConf)
155162
} else {
156-
proxy, dialCtx, err = createClientProxy(endpoint, tlsClientConf, connectOpts.forceHTTP2, connectOpts.extraH2ALPNs)
163+
proxy, dialCtx, err = createClientProxy(endpoint, tlsClientConf, connectOpts.forceHTTP2, connectOpts.forceDowngrade, connectOpts.extraH2ALPNs)
157164
}
158165

159166
if err != nil {

internal/grpcweb/defs.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,10 @@ const (
1919
trailerMessageFlag byte = 1 << 7
2020

2121
completeHeaderLen = 5
22+
23+
// GRPCWebOnlyHeader is a header to indicate that the server should always return gRPC-Web
24+
// responses, regardless of detected client capabilities. The presence of the header alone
25+
// is sufficient, however it is recommended that a client chooses "true" as the only value
26+
// whenver the header is used.
27+
GRPCWebOnlyHeader = `Grpc-Web-Only`
2228
)

internal/grpcweb/defs_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package grpcweb
2+
3+
import (
4+
"net/http"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestGRPCWebOnlyHeaderNameIsCanonical(t *testing.T) {
11+
assert.Equal(t, GRPCWebOnlyHeader, http.CanonicalHeaderKey(GRPCWebOnlyHeader))
12+
}

server/options.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package server
2+
3+
type options struct {
4+
preferGRPCWeb bool
5+
}
6+
7+
// Option is an object that controls the behavior of the downgrading gRPC server.
8+
type Option interface {
9+
apply(o *options)
10+
}
11+
12+
type optionFunc func(o *options)
13+
14+
func (f optionFunc) apply(o *options) {
15+
f(o)
16+
}
17+
18+
// PreferGRPCWeb instructs the server to always send a downgraded gRPC-Web response
19+
// if the client indicates it accepts gRPC-Web responses.
20+
func PreferGRPCWeb(prefer bool) Option {
21+
return optionFunc(func(o *options) {
22+
o.preferGRPCWeb = prefer
23+
})
24+
}

server/server.go

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,31 +93,50 @@ func handleGRPCWS(w http.ResponseWriter, req *http.Request, grpcSrv *grpc.Server
9393
_ = conn.Close(websocket.StatusNormalClosure, "")
9494
}
9595

96-
func handleGRPCWeb(w http.ResponseWriter, req *http.Request, validPaths map[string]struct{}, grpcSrv *grpc.Server) {
96+
func handleGRPCWeb(w http.ResponseWriter, req *http.Request, validPaths map[string]struct{}, grpcSrv *grpc.Server, srvOpts *options) {
97+
_, isDowngradableMethod := validPaths[req.URL.Path]
98+
9799
// Check for HTTP/2.
98100
if req.ProtoMajor != 2 {
99-
if _, ok := validPaths[req.URL.Path]; !ok {
101+
if !isDowngradableMethod {
100102
// Client-streaming only works with HTTP/2.
101103
http.Error(w, "Method cannot be downgraded", http.StatusInternalServerError)
102104
return
103105
}
104106
req.ProtoMajor, req.ProtoMinor, req.Proto = 2, 0, "HTTP/2.0"
105107
}
106108

107-
if req.Header.Get("TE") == "trailers" {
108-
// Yay, client accepts trailers! Let the normal gRPC handler handle the request.
109+
acceptedContentTypes := strings.FieldsFunc(strings.Join(req.Header["Accept"], ","), spaceOrComma)
110+
acceptGRPCWeb := sliceutils.StringFind(acceptedContentTypes, "application/grpc-web") != -1
111+
// The standard gRPC client doesn't actually send an `Accept: application/grpc` header, so always assume
112+
// the client accepts gRPC _unless_ it explicitly specifies an `application/grpc-web` accept header
113+
// WITHOUT an `application/grpc` accept header.
114+
acceptGRPC := !acceptGRPCWeb || sliceutils.StringFind(acceptedContentTypes, "application/grpc") != -1
115+
116+
// Only consider sending a gRPC response if we are not told to prefer gRPC-Web or the client doesn't support
117+
// gRPC-Web.
118+
if srvOpts.preferGRPCWeb && isDowngradableMethod && acceptGRPCWeb {
119+
acceptGRPC = false
120+
}
121+
122+
// If the client accepts trailers, AND gRPC responses, AND did not set the "Grpc-Web-Only" header,
123+
// return the response as a normal gRPC response.
124+
if req.Header.Get("TE") == "trailers" && acceptGRPC && len(req.Header[grpcweb.GRPCWebOnlyHeader]) == 0 {
109125
grpcSrv.ServeHTTP(w, req)
110126
return
111127
}
112128

113-
acceptedContentTypes := strings.FieldsFunc(strings.Join(req.Header["Accept"], ","), spaceOrComma)
114-
acceptGRPCWeb := sliceutils.StringFind(acceptedContentTypes, "application/grpc-web") != -1
115129
if !acceptGRPCWeb {
116130
// Client doesn't support trailers and doesn't accept a response downgraded to gRPC web.
117131
http.Error(w, "Client neither supports trailers nor gRPC web responses", http.StatusInternalServerError)
118132
return
119133
}
120134

135+
if !isDowngradableMethod {
136+
http.Error(w, "Client requires a gRPC-Web response to a method that cannot be downgraded", http.StatusBadRequest)
137+
return
138+
}
139+
121140
// Tell the server we would accept trailers (the gRPC server currently (v1.29.1) doesn't check for this but it
122141
// really should, as the purpose of the TE header according to the gRPC spec is to detect incompatible proxies).
123142
req.Header.Set("TE", "trailers")
@@ -132,7 +151,7 @@ func handleGRPCWeb(w http.ResponseWriter, req *http.Request, validPaths map[stri
132151

133152
// CreateDowngradingHandler takes a gRPC server and a plain HTTP handler, and returns an HTTP handler that has the
134153
// capability of handling HTTP requests and gRPC requests that may require downgrading the response to gRPC-Web or gRPC-WebSocket.
135-
func CreateDowngradingHandler(grpcSrv *grpc.Server, httpHandler http.Handler) http.Handler {
154+
func CreateDowngradingHandler(grpcSrv *grpc.Server, httpHandler http.Handler, opts ...Option) http.Handler {
136155
// Only allow paths corresponding to gRPC methods that do not use client streaming for gRPC-Web.
137156
validGRPCWebPaths := make(map[string]struct{})
138157

@@ -148,6 +167,11 @@ func CreateDowngradingHandler(grpcSrv *grpc.Server, httpHandler http.Handler) ht
148167
}
149168
}
150169

170+
var serverOpts options
171+
for _, opt := range opts {
172+
opt.apply(&serverOpts)
173+
}
174+
151175
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
152176
if isWebSocketUpgrade(req.Header) {
153177
handleGRPCWS(w, req, grpcSrv)
@@ -160,7 +184,7 @@ func CreateDowngradingHandler(grpcSrv *grpc.Server, httpHandler http.Handler) ht
160184
return
161185
}
162186

163-
handleGRPCWeb(w, req, validGRPCWebPaths, grpcSrv)
187+
handleGRPCWeb(w, req, validGRPCWebPaths, grpcSrv, &serverOpts)
164188
})
165189
}
166190

0 commit comments

Comments
 (0)