Skip to content

Commit 6f57139

Browse files
nekohasekaiaboch
authored andcommitted
Fix recvfrom goroutine leak
1 parent 298ff27 commit 6f57139

File tree

2 files changed

+38
-26
lines changed

2 files changed

+38
-26
lines changed

nl/nl_linux.go

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/binary"
77
"fmt"
88
"net"
9+
"os"
910
"runtime"
1011
"sync"
1112
"sync/atomic"
@@ -655,8 +656,9 @@ func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
655656
}
656657

657658
type NetlinkSocket struct {
658-
fd int32
659-
lsa unix.SockaddrNetlink
659+
fd int32
660+
file *os.File
661+
lsa unix.SockaddrNetlink
660662
sync.Mutex
661663
}
662664

@@ -665,8 +667,13 @@ func getNetlinkSocket(protocol int) (*NetlinkSocket, error) {
665667
if err != nil {
666668
return nil, err
667669
}
670+
err = unix.SetNonblock(fd, true)
671+
if err != nil {
672+
return nil, err
673+
}
668674
s := &NetlinkSocket{
669-
fd: int32(fd),
675+
fd: int32(fd),
676+
file: os.NewFile(uintptr(fd), "netlink"),
670677
}
671678
s.lsa.Family = unix.AF_NETLINK
672679
if err := unix.Bind(fd, &s.lsa); err != nil {
@@ -753,8 +760,13 @@ func Subscribe(protocol int, groups ...uint) (*NetlinkSocket, error) {
753760
if err != nil {
754761
return nil, err
755762
}
763+
err = unix.SetNonblock(fd, true)
764+
if err != nil {
765+
return nil, err
766+
}
756767
s := &NetlinkSocket{
757-
fd: int32(fd),
768+
fd: int32(fd),
769+
file: os.NewFile(uintptr(fd), "netlink"),
758770
}
759771
s.lsa.Family = unix.AF_NETLINK
760772

@@ -783,33 +795,36 @@ func SubscribeAt(newNs, curNs netns.NsHandle, protocol int, groups ...uint) (*Ne
783795
}
784796

785797
func (s *NetlinkSocket) Close() {
786-
fd := int(atomic.SwapInt32(&s.fd, -1))
787-
unix.Close(fd)
798+
s.file.Close()
788799
}
789800

790801
func (s *NetlinkSocket) GetFd() int {
791-
return int(atomic.LoadInt32(&s.fd))
802+
return int(s.fd)
792803
}
793804

794805
func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
795-
fd := int(atomic.LoadInt32(&s.fd))
796-
if fd < 0 {
797-
return fmt.Errorf("Send called on a closed socket")
798-
}
799-
if err := unix.Sendto(fd, request.Serialize(), 0, &s.lsa); err != nil {
800-
return err
801-
}
802-
return nil
806+
return unix.Sendto(int(s.fd), request.Serialize(), 0, &s.lsa)
803807
}
804808

805809
func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetlink, error) {
806-
fd := int(atomic.LoadInt32(&s.fd))
807-
if fd < 0 {
808-
return nil, nil, fmt.Errorf("Receive called on a closed socket")
810+
rawConn, err := s.file.SyscallConn()
811+
if err != nil {
812+
return nil, nil, err
813+
}
814+
var (
815+
fromAddr *unix.SockaddrNetlink
816+
rb [RECEIVE_BUFFER_SIZE]byte
817+
nr int
818+
from unix.Sockaddr
819+
innerErr error
820+
)
821+
err = rawConn.Read(func(fd uintptr) (done bool) {
822+
nr, from, innerErr = unix.Recvfrom(int(fd), rb[:], 0)
823+
return innerErr != unix.EWOULDBLOCK
824+
})
825+
if innerErr != nil {
826+
err = innerErr
809827
}
810-
var fromAddr *unix.SockaddrNetlink
811-
var rb [RECEIVE_BUFFER_SIZE]byte
812-
nr, from, err := unix.Recvfrom(fd, rb[:], 0)
813828
if err != nil {
814829
return nil, nil, err
815830
}
@@ -864,8 +879,7 @@ func (s *NetlinkSocket) SetExtAck(enable bool) error {
864879
}
865880

866881
func (s *NetlinkSocket) GetPid() (uint32, error) {
867-
fd := int(atomic.LoadInt32(&s.fd))
868-
lsa, err := unix.Getsockname(fd)
882+
lsa, err := unix.Getsockname(int(s.fd))
869883
if err != nil {
870884
return 0, err
871885
}

nl/nl_linux_test.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,12 @@ func TestIfSocketCloses(t *testing.T) {
6868
if err != nil {
6969
t.Fatalf("Error on creating the socket: %v", err)
7070
}
71-
nlSock.SetReceiveTimeout(&unix.Timeval{Sec: 2, Usec: 0})
7271
endCh := make(chan error)
7372
go func(sk *NetlinkSocket, endCh chan error) {
7473
endCh <- nil
7574
for {
7675
_, _, err := sk.Receive()
77-
// Receive returned because of a timeout and the FD == -1 means that the socket got closed
78-
if err == unix.EAGAIN && nlSock.GetFd() == -1 {
76+
if err == unix.EAGAIN {
7977
endCh <- err
8078
return
8179
}

0 commit comments

Comments
 (0)