Skip to content

Commit 75cebb6

Browse files
authored
fix: timeout 0s not working (#4932)
Signed-off-by: kevin <[email protected]>
1 parent 410f56e commit 75cebb6

File tree

5 files changed

+124
-25
lines changed

5 files changed

+124
-25
lines changed

rest/engine.go

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ var ErrSignatureConfig = errors.New("bad config for Signature")
2828
type engine struct {
2929
conf RestConf
3030
routes []featuredRoutes
31-
// timeout is the max timeout of all routes
31+
// timeout is the max timeout of all routes,
32+
// and is used to set http.Server.ReadTimeout and http.Server.WriteTimeout.
33+
// this network timeout is used to avoid DoS attacks by sending data slowly
34+
// or receiving data slowly with many connections to exhaust server resources.
3235
timeout time.Duration
3336
unauthorizedCallback handler.UnauthorizedCallback
3437
unsignedCallback handler.UnsignedCallback
@@ -60,11 +63,7 @@ func (ng *engine) addRoutes(r featuredRoutes) {
6063
}
6164
ng.routes = append(ng.routes, r)
6265

63-
// need to guarantee the timeout is the max of all routes
64-
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
65-
if r.timeout > ng.timeout {
66-
ng.timeout = r.timeout
67-
}
66+
ng.mightUpdateTimeout(r)
6867
}
6968

7069
func buildSSERoutes(routes []Route) []Route {
@@ -192,11 +191,12 @@ func (ng *engine) checkedMaxBytes(bytes int64) int64 {
192191
return ng.conf.MaxBytes
193192
}
194193

195-
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
196-
if timeout > 0 {
197-
return timeout
194+
func (ng *engine) checkedTimeout(timeout *time.Duration) time.Duration {
195+
if timeout != nil {
196+
return *timeout
198197
}
199198

199+
// if timeout not set in featured routes, use global timeout
200200
return time.Duration(ng.conf.Timeout) * time.Millisecond
201201
}
202202

@@ -232,6 +232,28 @@ func (ng *engine) hasTimeout() bool {
232232
return ng.conf.Middlewares.Timeout && ng.timeout > 0
233233
}
234234

235+
// mightUpdateTimeout checks if the route timeout is greater than the current,
236+
// and updates the engine's timeout accordingly.
237+
func (ng *engine) mightUpdateTimeout(r featuredRoutes) {
238+
// if global timeout is set to 0, it means no need to set read/write timeout
239+
// if route timeout is nil, no need to update ng.timeout
240+
if ng.timeout == 0 || r.timeout == nil {
241+
return
242+
}
243+
244+
// if route timeout is 0 (means no timeout), cannot set read/write timeout
245+
if *r.timeout == 0 {
246+
ng.timeout = 0
247+
return
248+
}
249+
250+
// need to guarantee the timeout is the max of all routes
251+
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
252+
if *r.timeout > ng.timeout {
253+
ng.timeout = *r.timeout
254+
}
255+
}
256+
235257
// notFoundHandler returns a middleware that handles 404 not found requests.
236258
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
237259
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -333,7 +355,7 @@ func (ng *engine) start(router httpx.Router, opts ...StartOption) error {
333355
}
334356

335357
// make sure user defined options overwrite default options
336-
opts = append([]StartOption{ng.withTimeout()}, opts...)
358+
opts = append([]StartOption{ng.withNetworkTimeout()}, opts...)
337359

338360
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
339361
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, opts...)
@@ -356,7 +378,7 @@ func (ng *engine) use(middleware Middleware) {
356378
ng.middlewares = append(ng.middlewares, middleware)
357379
}
358380

359-
func (ng *engine) withTimeout() internal.StartOption {
381+
func (ng *engine) withNetworkTimeout() internal.StartOption {
360382
return func(svr *http.Server) {
361383
if !ng.hasTimeout() {
362384
return

rest/engine_test.go

Lines changed: 83 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,17 @@ Verbose: true
7373
Path: "/",
7474
Handler: func(w http.ResponseWriter, r *http.Request) {},
7575
}},
76-
timeout: time.Minute,
76+
timeout: ptrOfDuration(time.Minute),
77+
},
78+
{
79+
jwt: jwtSetting{},
80+
signature: signatureSetting{},
81+
routes: []Route{{
82+
Method: http.MethodGet,
83+
Path: "/",
84+
Handler: func(w http.ResponseWriter, r *http.Request) {},
85+
}},
86+
timeout: ptrOfDuration(0),
7787
},
7888
{
7989
priority: true,
@@ -84,7 +94,7 @@ Verbose: true
8494
Path: "/",
8595
Handler: func(w http.ResponseWriter, r *http.Request) {},
8696
}},
87-
timeout: time.Second,
97+
timeout: ptrOfDuration(time.Second),
8898
},
8999
{
90100
priority: true,
@@ -227,19 +237,82 @@ Verbose: true
227237
}))
228238

229239
timeout := time.Second * 3
230-
if route.timeout > timeout {
231-
timeout = route.timeout
240+
if route.timeout != nil {
241+
if *route.timeout == 0 {
242+
timeout = 0
243+
} else if *route.timeout > timeout {
244+
timeout = *route.timeout
245+
}
232246
}
233247
assert.Equal(t, timeout, ng.timeout)
234248
})
235249
}
236250
}
237251
}
238252

253+
func TestNewEngine_unsignedCallback(t *testing.T) {
254+
priKeyfile, err := fs.TempFilenameWithText(priKey)
255+
assert.Nil(t, err)
256+
defer os.Remove(priKeyfile)
257+
258+
yaml := `Name: foo
259+
Host: localhost
260+
Port: 0
261+
Middlewares:
262+
Log: false
263+
`
264+
route := featuredRoutes{
265+
priority: true,
266+
jwt: jwtSetting{
267+
enabled: true,
268+
},
269+
signature: signatureSetting{
270+
enabled: true,
271+
SignatureConf: SignatureConf{
272+
Strict: true,
273+
PrivateKeys: []PrivateKeyConf{
274+
{
275+
Fingerprint: "a",
276+
KeyFile: priKeyfile,
277+
},
278+
},
279+
},
280+
},
281+
routes: []Route{{
282+
Method: http.MethodGet,
283+
Path: "/",
284+
Handler: func(w http.ResponseWriter, r *http.Request) {},
285+
}},
286+
}
287+
288+
var index int32
289+
t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) {
290+
var cnf RestConf
291+
assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
292+
ng := newEngine(cnf)
293+
if atomic.AddInt32(&index, 1)%2 == 0 {
294+
ng.setUnsignedCallback(func(w http.ResponseWriter, r *http.Request,
295+
next http.Handler, strict bool, code int) {
296+
})
297+
}
298+
ng.addRoutes(route)
299+
ng.use(func(next http.HandlerFunc) http.HandlerFunc {
300+
return func(w http.ResponseWriter, r *http.Request) {
301+
next.ServeHTTP(w, r)
302+
}
303+
})
304+
305+
assert.NotNil(t, ng.start(mockedRouter{}, func(svr *http.Server) {
306+
}))
307+
308+
assert.Equal(t, time.Duration(time.Second*3), ng.timeout)
309+
})
310+
}
311+
239312
func TestEngine_checkedTimeout(t *testing.T) {
240313
tests := []struct {
241314
name string
242-
timeout time.Duration
315+
timeout *time.Duration
243316
expect time.Duration
244317
}{
245318
{
@@ -248,17 +321,17 @@ func TestEngine_checkedTimeout(t *testing.T) {
248321
},
249322
{
250323
name: "less",
251-
timeout: time.Millisecond * 500,
324+
timeout: ptrOfDuration(time.Millisecond * 500),
252325
expect: time.Millisecond * 500,
253326
},
254327
{
255328
name: "equal",
256-
timeout: time.Second,
329+
timeout: ptrOfDuration(time.Second),
257330
expect: time.Second,
258331
},
259332
{
260333
name: "more",
261-
timeout: time.Millisecond * 1500,
334+
timeout: ptrOfDuration(time.Millisecond * 1500),
262335
expect: time.Millisecond * 1500,
263336
},
264337
}
@@ -401,7 +474,7 @@ func TestEngine_withTimeout(t *testing.T) {
401474
},
402475
})
403476
svr := &http.Server{}
404-
ng.withTimeout()(svr)
477+
ng.withNetworkTimeout()(svr)
405478

406479
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
407480
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
@@ -451,7 +524,7 @@ func TestEngine_ReadWriteTimeout(t *testing.T) {
451524
},
452525
})
453526
svr := &http.Server{}
454-
ng.withTimeout()(svr)
527+
ng.withNetworkTimeout()(svr)
455528

456529
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
457530
assert.Equal(t, time.Duration(0), svr.IdleTimeout)

rest/server.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,14 +283,14 @@ func WithSignature(signature SignatureConf) RouteOption {
283283
func WithSSE() RouteOption {
284284
return func(r *featuredRoutes) {
285285
r.sse = true
286-
r.timeout = 0
286+
r.timeout = ptrOfDuration(0)
287287
}
288288
}
289289

290290
// WithTimeout returns a RouteOption to set timeout with given value.
291291
func WithTimeout(timeout time.Duration) RouteOption {
292292
return func(r *featuredRoutes) {
293-
r.timeout = timeout
293+
r.timeout = &timeout
294294
}
295295
}
296296

@@ -325,6 +325,10 @@ func handleError(err error) {
325325
panic(err)
326326
}
327327

328+
func ptrOfDuration(d time.Duration) *time.Duration {
329+
return &d
330+
}
331+
328332
func validateSecret(secret string) {
329333
if len(secret) < 8 {
330334
panic("secret's length can't be less than 8")

rest/server_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ func TestWithPriority(t *testing.T) {
345345
func TestWithTimeout(t *testing.T) {
346346
var fr featuredRoutes
347347
WithTimeout(time.Hour)(&fr)
348-
assert.Equal(t, time.Hour, fr.timeout)
348+
assert.Equal(t, time.Hour, *fr.timeout)
349349
}
350350

351351
func TestWithTLSConfig(t *testing.T) {

rest/types.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ type (
3131
}
3232

3333
featuredRoutes struct {
34-
timeout time.Duration
34+
timeout *time.Duration
3535
priority bool
3636
jwt jwtSetting
3737
signature signatureSetting

0 commit comments

Comments
 (0)