Skip to content

Commit b0b31f3

Browse files
authored
feat(gateway): add custom middleware support with onion model (#5035)
1 parent 82a937d commit b0b31f3

File tree

2 files changed

+72
-2
lines changed

2 files changed

+72
-2
lines changed

gateway/server.go

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,16 @@ import (
2727
const defaultHttpScheme = "http"
2828

2929
type (
30+
// MiddlewareFunc defines the function signature for middleware.
31+
MiddlewareFunc func(next http.HandlerFunc) http.HandlerFunc
32+
3033
// Server is a gateway server.
3134
Server struct {
3235
*rest.Server
3336
upstreams []Upstream
3437
conns []zrpc.Client
3538
processHeader func(http.Header) []string
39+
middlewares []MiddlewareFunc
3640
dialer func(conf zrpc.RpcClientConf) zrpc.Client
3741
}
3842

@@ -105,7 +109,7 @@ func (s *Server) build() error {
105109

106110
func (s *Server) buildGrpcHandler(source grpcurl.DescriptorSource, resolver jsonpb.AnyResolver,
107111
cli zrpc.Client, rpcPath string) func(http.ResponseWriter, *http.Request) {
108-
return func(w http.ResponseWriter, r *http.Request) {
112+
handler := func(w http.ResponseWriter, r *http.Request) {
109113
parser, err := internal.NewRequestParser(r, resolver)
110114
if err != nil {
111115
httpx.ErrorCtx(r.Context(), w, err)
@@ -124,6 +128,8 @@ func (s *Server) buildGrpcHandler(source grpcurl.DescriptorSource, resolver json
124128
httpx.ErrorCtx(r.Context(), w, st.Err())
125129
}
126130
}
131+
132+
return s.buildChainHandler(handler)
127133
}
128134

129135
func (s *Server) buildGrpcRoute(up Upstream, writer mr.Writer[rest.Route], cancel func(error)) {
@@ -177,7 +183,7 @@ func (s *Server) buildGrpcRoute(up Upstream, writer mr.Writer[rest.Route], cance
177183
}
178184

179185
func (s *Server) buildHttpHandler(target *HttpClientConf) http.HandlerFunc {
180-
return func(w http.ResponseWriter, r *http.Request) {
186+
handler := func(w http.ResponseWriter, r *http.Request) {
181187
w.Header().Set(httpx.ContentType, httpx.JsonContentType)
182188
req, err := buildRequestWithNewTarget(r, target)
183189
if err != nil {
@@ -213,6 +219,8 @@ func (s *Server) buildHttpHandler(target *HttpClientConf) http.HandlerFunc {
213219
logc.Error(r.Context(), err)
214220
}
215221
}
222+
223+
return s.buildChainHandler(handler)
216224
}
217225

218226
func (s *Server) buildHttpRoute(up Upstream, writer mr.Writer[rest.Route]) {
@@ -263,6 +271,21 @@ func WithHeaderProcessor(processHeader func(http.Header) []string) func(*Server)
263271
}
264272
}
265273

274+
// WithMiddleware adds one or more middleware functions to process HTTP requests.
275+
// Multiple middlewares will be executed in the order they were passed (like an onion model).
276+
func WithMiddleware(middlewares ...MiddlewareFunc) func(*Server) {
277+
return func(s *Server) {
278+
s.middlewares = append(s.middlewares, middlewares...)
279+
}
280+
}
281+
282+
func (s *Server) buildChainHandler(handler http.HandlerFunc) http.HandlerFunc {
283+
for i := len(s.middlewares) - 1; i >= 0; i-- {
284+
handler = s.middlewares[i](handler)
285+
}
286+
return handler
287+
}
288+
266289
func buildRequestWithNewTarget(r *http.Request, target *HttpClientConf) (*http.Request, error) {
267290
u := *r.URL
268291
u.Host = target.Target

gateway/server_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,3 +325,50 @@ type badResponseWriter struct {
325325
func (w *badResponseWriter) Write([]byte) (int, error) {
326326
return 0, errors.New("bad writer")
327327
}
328+
329+
func TestWithMiddleware(t *testing.T) {
330+
var callOrder []string
331+
332+
firstMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
333+
return func(w http.ResponseWriter, r *http.Request) {
334+
callOrder = append(callOrder, "first-start")
335+
w.Header().Set("X-First-Middleware", "called")
336+
next(w, r)
337+
callOrder = append(callOrder, "first-end")
338+
}
339+
}
340+
341+
secondMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
342+
return func(w http.ResponseWriter, r *http.Request) {
343+
callOrder = append(callOrder, "second-start")
344+
w.Header().Set("X-Second-Middleware", "called")
345+
next(w, r)
346+
callOrder = append(callOrder, "second-end")
347+
}
348+
}
349+
350+
var c GatewayConf
351+
err := conf.FillDefault(&c)
352+
assert.Nil(t, err)
353+
// Test multiple middlewares in one call
354+
server1 := MustNewServer(c, WithMiddleware(firstMiddleware, secondMiddleware))
355+
assert.Len(t, server1.middlewares, 2, "Should have 2 middlewares from one call")
356+
// Test multiple middleware calls
357+
server2 := MustNewServer(c, WithMiddleware(firstMiddleware), WithMiddleware(secondMiddleware))
358+
assert.Len(t, server2.middlewares, 2, "Should have 2 middlewares from separate calls")
359+
// Test execution order (onion model)
360+
finalHandler := func(w http.ResponseWriter, r *http.Request) {
361+
callOrder = append(callOrder, "handler")
362+
w.WriteHeader(http.StatusOK)
363+
}
364+
365+
testHandler := server1.buildChainHandler(finalHandler)
366+
w := httptest.NewRecorder()
367+
r := httptest.NewRequest("GET", "/test", nil)
368+
testHandler(w, r)
369+
370+
expectedOrder := []string{"first-start", "second-start", "handler", "second-end", "first-end"}
371+
assert.Equal(t, expectedOrder, callOrder, "Middleware execution should follow onion model")
372+
assert.Equal(t, "called", w.Header().Get("X-First-Middleware"))
373+
assert.Equal(t, "called", w.Header().Get("X-Second-Middleware"))
374+
}

0 commit comments

Comments
 (0)