Skip to content

Commit 34f3492

Browse files
authored
Merge pull request kubernetes#85410 from answer1991/bugfix/memory-leak-in-watch
fix potential memory leak issue in processing watch request
2 parents 78c56e6 + b911aa6 commit 34f3492

File tree

2 files changed

+84
-13
lines changed

2 files changed

+84
-13
lines changed

staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ func (w *realTimeoutFactory) TimeoutCh() (<-chan time.Time, func() bool) {
6363
// serveWatch will serve a watch response.
6464
// TODO: the functionality in this method and in WatchServer.Serve is not cleanly decoupled.
6565
func serveWatch(watcher watch.Interface, scope *RequestScope, mediaTypeOptions negotiation.MediaTypeOptions, req *http.Request, w http.ResponseWriter, timeout time.Duration) {
66+
defer watcher.Stop()
67+
6668
options, err := optionsForTransform(mediaTypeOptions, req)
6769
if err != nil {
6870
scope.err(err, w, req)
@@ -193,7 +195,6 @@ func (s *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
193195
// ensure the connection times out
194196
timeoutCh, cleanup := s.TimeoutFactory.TimeoutCh()
195197
defer cleanup()
196-
defer s.Watching.Stop()
197198

198199
// begin the stream
199200
w.Header().Set("Content-Type", s.MediaType)
@@ -280,8 +281,6 @@ func (s *WatchServer) HandleWS(ws *websocket.Conn) {
280281
streamBuf := &bytes.Buffer{}
281282
ch := s.Watching.ResultChan()
282283

283-
defer s.Watching.Stop()
284-
285284
for {
286285
select {
287286
case <-done:

staging/src/k8s.io/apiserver/pkg/endpoints/watch_test.go

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ import (
4343
"k8s.io/apimachinery/pkg/watch"
4444
example "k8s.io/apiserver/pkg/apis/example"
4545
"k8s.io/apiserver/pkg/endpoints/handlers"
46+
"k8s.io/apiserver/pkg/endpoints/handlers/responsewriters"
4647
apitesting "k8s.io/apiserver/pkg/endpoints/testing"
4748
"k8s.io/apiserver/pkg/registry/rest"
4849
"k8s.io/client-go/dynamic"
@@ -607,6 +608,21 @@ func (t *fakeTimeoutFactory) TimeoutCh() (<-chan time.Time, func() bool) {
607608
}
608609
}
609610

611+
// serveWatch will serve a watch response according to the watcher and watchServer.
612+
// Before watchServer.ServeHTTP, an error may occur like k8s.io/apiserver/pkg/endpoints/handlers/watch.go#serveWatch does.
613+
func serveWatch(watcher watch.Interface, watchServer *handlers.WatchServer, preServeErr error) http.HandlerFunc {
614+
return func(w http.ResponseWriter, req *http.Request) {
615+
defer watcher.Stop()
616+
617+
if preServeErr != nil {
618+
responsewriters.ErrorNegotiated(preServeErr, watchServer.Scope.Serializer, watchServer.Scope.Kind.GroupVersion(), w, req)
619+
return
620+
}
621+
622+
watchServer.ServeHTTP(w, req)
623+
}
624+
}
625+
610626
func TestWatchHTTPErrors(t *testing.T) {
611627
watcher := watch.NewFake()
612628
timeoutCh := make(chan time.Time)
@@ -632,9 +648,7 @@ func TestWatchHTTPErrors(t *testing.T) {
632648
TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done},
633649
}
634650

635-
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
636-
watchServer.ServeHTTP(w, req)
637-
}))
651+
s := httptest.NewServer(serveWatch(watcher, watchServer, nil))
638652
defer s.Close()
639653

640654
// Setup a client
@@ -671,6 +685,68 @@ func TestWatchHTTPErrors(t *testing.T) {
671685
}
672686
}
673687

688+
func TestWatchHTTPErrorsBeforeServe(t *testing.T) {
689+
watcher := watch.NewFake()
690+
timeoutCh := make(chan time.Time)
691+
done := make(chan struct{})
692+
693+
info, ok := runtime.SerializerInfoForMediaType(codecs.SupportedMediaTypes(), runtime.ContentTypeJSON)
694+
if !ok || info.StreamSerializer == nil {
695+
t.Fatal(info)
696+
}
697+
serializer := info.StreamSerializer
698+
699+
// Setup a new watchserver
700+
watchServer := &handlers.WatchServer{
701+
Scope: &handlers.RequestScope{
702+
Serializer: runtime.NewSimpleNegotiatedSerializer(info),
703+
Kind: testGroupVersion.WithKind("test"),
704+
},
705+
Watching: watcher,
706+
707+
MediaType: "testcase/json",
708+
Framer: serializer.Framer,
709+
Encoder: newCodec,
710+
EmbeddedEncoder: newCodec,
711+
712+
Fixup: func(obj runtime.Object) runtime.Object { return obj },
713+
TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done},
714+
}
715+
716+
errStatus := errors.NewInternalError(fmt.Errorf("we got an error"))
717+
718+
s := httptest.NewServer(serveWatch(watcher, watchServer, errStatus))
719+
defer s.Close()
720+
721+
// Setup a client
722+
dest, _ := url.Parse(s.URL)
723+
dest.Path = "/" + prefix + "/" + newGroupVersion.Group + "/" + newGroupVersion.Version + "/simple"
724+
dest.RawQuery = "watch=true"
725+
726+
req, _ := http.NewRequest("GET", dest.String(), nil)
727+
client := http.Client{}
728+
resp, err := client.Do(req)
729+
if err != nil {
730+
t.Fatalf("Unexpected error: %v", err)
731+
}
732+
733+
// We had already got an error before watch serve started
734+
decoder := json.NewDecoder(resp.Body)
735+
var status *metav1.Status
736+
err = decoder.Decode(&status)
737+
if err != nil {
738+
t.Fatalf("Unexpected error: %v", err)
739+
}
740+
if status.Kind != "Status" || status.APIVersion != "v1" || status.Code != 500 || status.Status != "Failure" || !strings.Contains(status.Message, "we got an error") {
741+
t.Fatalf("error: %#v", status)
742+
}
743+
744+
// check for leaks
745+
if !watcher.IsStopped() {
746+
t.Errorf("Leaked watcher goruntine after request done")
747+
}
748+
}
749+
674750
func TestWatchHTTPDynamicClientErrors(t *testing.T) {
675751
watcher := watch.NewFake()
676752
timeoutCh := make(chan time.Time)
@@ -696,9 +772,7 @@ func TestWatchHTTPDynamicClientErrors(t *testing.T) {
696772
TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done},
697773
}
698774

699-
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
700-
watchServer.ServeHTTP(w, req)
701-
}))
775+
s := httptest.NewServer(serveWatch(watcher, watchServer, nil))
702776
defer s.Close()
703777
defer s.CloseClientConnections()
704778

@@ -741,9 +815,7 @@ func TestWatchHTTPTimeout(t *testing.T) {
741815
TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done},
742816
}
743817

744-
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
745-
watchServer.ServeHTTP(w, req)
746-
}))
818+
s := httptest.NewServer(serveWatch(watcher, watchServer, nil))
747819
defer s.Close()
748820

749821
// Setup a client
@@ -771,7 +843,7 @@ func TestWatchHTTPTimeout(t *testing.T) {
771843
close(timeoutCh)
772844
select {
773845
case <-done:
774-
if !watcher.Stopped {
846+
if !watcher.IsStopped() {
775847
t.Errorf("Leaked watch on timeout")
776848
}
777849
case <-time.After(wait.ForeverTestTimeout):

0 commit comments

Comments
 (0)