Skip to content

Commit b911aa6

Browse files
committed
fix potential memory leak issue in processing watch request
1 parent beaf3a2 commit b911aa6

File tree

2 files changed

+84
-13
lines changed

2 files changed

+84
-13
lines changed

staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ func (w *realTimeoutFactory) TimeoutCh() (<-chan time.Time, func() bool) {
6464
// serveWatch will serve a watch response.
6565
// TODO: the functionality in this method and in WatchServer.Serve is not cleanly decoupled.
6666
func serveWatch(watcher watch.Interface, scope *RequestScope, mediaTypeOptions negotiation.MediaTypeOptions, req *http.Request, w http.ResponseWriter, timeout time.Duration) {
67+
defer watcher.Stop()
68+
6769
options, err := optionsForTransform(mediaTypeOptions, req)
6870
if err != nil {
6971
scope.err(err, w, req)
@@ -201,7 +203,6 @@ func (s *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
201203
// ensure the connection times out
202204
timeoutCh, cleanup := s.TimeoutFactory.TimeoutCh()
203205
defer cleanup()
204-
defer s.Watching.Stop()
205206

206207
// begin the stream
207208
w.Header().Set("Content-Type", s.MediaType)
@@ -286,8 +287,6 @@ func (s *WatchServer) HandleWS(ws *websocket.Conn) {
286287
streamBuf := &bytes.Buffer{}
287288
ch := s.Watching.ResultChan()
288289

289-
defer s.Watching.Stop()
290-
291290
for {
292291
select {
293292
case <-done:

staging/src/k8s.io/apiserver/pkg/endpoints/watch_test.go

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ import (
4343
"k8s.io/apimachinery/pkg/watch"
4444
example "k8s.io/apiserver/pkg/apis/example"
4545
"k8s.io/apiserver/pkg/endpoints/handlers"
46+
"k8s.io/apiserver/pkg/endpoints/handlers/responsewriters"
4647
apitesting "k8s.io/apiserver/pkg/endpoints/testing"
4748
"k8s.io/apiserver/pkg/registry/rest"
4849
"k8s.io/client-go/dynamic"
@@ -565,6 +566,21 @@ func (t *fakeTimeoutFactory) TimeoutCh() (<-chan time.Time, func() bool) {
565566
}
566567
}
567568

569+
// serveWatch will serve a watch response according to the watcher and watchServer.
570+
// Before watchServer.ServeHTTP, an error may occur like k8s.io/apiserver/pkg/endpoints/handlers/watch.go#serveWatch does.
571+
func serveWatch(watcher watch.Interface, watchServer *handlers.WatchServer, preServeErr error) http.HandlerFunc {
572+
return func(w http.ResponseWriter, req *http.Request) {
573+
defer watcher.Stop()
574+
575+
if preServeErr != nil {
576+
responsewriters.ErrorNegotiated(preServeErr, watchServer.Scope.Serializer, watchServer.Scope.Kind.GroupVersion(), w, req)
577+
return
578+
}
579+
580+
watchServer.ServeHTTP(w, req)
581+
}
582+
}
583+
568584
func TestWatchHTTPErrors(t *testing.T) {
569585
watcher := watch.NewFake()
570586
timeoutCh := make(chan time.Time)
@@ -590,9 +606,7 @@ func TestWatchHTTPErrors(t *testing.T) {
590606
TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done},
591607
}
592608

593-
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
594-
watchServer.ServeHTTP(w, req)
595-
}))
609+
s := httptest.NewServer(serveWatch(watcher, watchServer, nil))
596610
defer s.Close()
597611

598612
// Setup a client
@@ -629,6 +643,68 @@ func TestWatchHTTPErrors(t *testing.T) {
629643
}
630644
}
631645

646+
func TestWatchHTTPErrorsBeforeServe(t *testing.T) {
647+
watcher := watch.NewFake()
648+
timeoutCh := make(chan time.Time)
649+
done := make(chan struct{})
650+
651+
info, ok := runtime.SerializerInfoForMediaType(codecs.SupportedMediaTypes(), runtime.ContentTypeJSON)
652+
if !ok || info.StreamSerializer == nil {
653+
t.Fatal(info)
654+
}
655+
serializer := info.StreamSerializer
656+
657+
// Setup a new watchserver
658+
watchServer := &handlers.WatchServer{
659+
Scope: &handlers.RequestScope{
660+
Serializer: runtime.NewSimpleNegotiatedSerializer(info),
661+
Kind: testGroupVersion.WithKind("test"),
662+
},
663+
Watching: watcher,
664+
665+
MediaType: "testcase/json",
666+
Framer: serializer.Framer,
667+
Encoder: newCodec,
668+
EmbeddedEncoder: newCodec,
669+
670+
Fixup: func(obj runtime.Object) runtime.Object { return obj },
671+
TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done},
672+
}
673+
674+
errStatus := errors.NewInternalError(fmt.Errorf("we got an error"))
675+
676+
s := httptest.NewServer(serveWatch(watcher, watchServer, errStatus))
677+
defer s.Close()
678+
679+
// Setup a client
680+
dest, _ := url.Parse(s.URL)
681+
dest.Path = "/" + prefix + "/" + newGroupVersion.Group + "/" + newGroupVersion.Version + "/simple"
682+
dest.RawQuery = "watch=true"
683+
684+
req, _ := http.NewRequest("GET", dest.String(), nil)
685+
client := http.Client{}
686+
resp, err := client.Do(req)
687+
if err != nil {
688+
t.Fatalf("Unexpected error: %v", err)
689+
}
690+
691+
// We had already got an error before watch serve started
692+
decoder := json.NewDecoder(resp.Body)
693+
var status *metav1.Status
694+
err = decoder.Decode(&status)
695+
if err != nil {
696+
t.Fatalf("Unexpected error: %v", err)
697+
}
698+
if status.Kind != "Status" || status.APIVersion != "v1" || status.Code != 500 || status.Status != "Failure" || !strings.Contains(status.Message, "we got an error") {
699+
t.Fatalf("error: %#v", status)
700+
}
701+
702+
// check for leaks
703+
if !watcher.IsStopped() {
704+
t.Errorf("Leaked watcher goruntine after request done")
705+
}
706+
}
707+
632708
func TestWatchHTTPDynamicClientErrors(t *testing.T) {
633709
watcher := watch.NewFake()
634710
timeoutCh := make(chan time.Time)
@@ -654,9 +730,7 @@ func TestWatchHTTPDynamicClientErrors(t *testing.T) {
654730
TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done},
655731
}
656732

657-
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
658-
watchServer.ServeHTTP(w, req)
659-
}))
733+
s := httptest.NewServer(serveWatch(watcher, watchServer, nil))
660734
defer s.Close()
661735
defer s.CloseClientConnections()
662736

@@ -699,9 +773,7 @@ func TestWatchHTTPTimeout(t *testing.T) {
699773
TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done},
700774
}
701775

702-
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
703-
watchServer.ServeHTTP(w, req)
704-
}))
776+
s := httptest.NewServer(serveWatch(watcher, watchServer, nil))
705777
defer s.Close()
706778

707779
// Setup a client
@@ -729,7 +801,7 @@ func TestWatchHTTPTimeout(t *testing.T) {
729801
close(timeoutCh)
730802
select {
731803
case <-done:
732-
if !watcher.Stopped {
804+
if !watcher.IsStopped() {
733805
t.Errorf("Leaked watch on timeout")
734806
}
735807
case <-time.After(wait.ForeverTestTimeout):

0 commit comments

Comments
 (0)