Skip to content

Commit e194da5

Browse files
robmryaboch
authored andcommitted
Fix SetSendTimeout/SetReceiveTimeout
They were implemented using SO_SNDTIMEO/SO_RCVTIMEO on the socket descriptor - but that doesn't work now the socket is non-blocking. Instead, set deadlines on the file read/write. Signed-off-by: Rob Murray <[email protected]>
1 parent 0cd1f79 commit e194da5

File tree

3 files changed

+137
-40
lines changed

3 files changed

+137
-40
lines changed

handle_test.go

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,9 @@ import (
1212
"sync/atomic"
1313
"testing"
1414
"time"
15-
"unsafe"
1615

1716
"github.com/vishvananda/netlink/nl"
1817
"github.com/vishvananda/netns"
19-
"golang.org/x/sys/unix"
2018
)
2119

2220
func TestHandleCreateClose(t *testing.T) {
@@ -122,13 +120,22 @@ func TestHandleTimeout(t *testing.T) {
122120
defer h.Close()
123121

124122
for _, sh := range h.sockets {
125-
verifySockTimeVal(t, sh.Socket.GetFd(), unix.Timeval{Sec: 0, Usec: 0})
123+
verifySockTimeVal(t, sh.Socket, time.Duration(0))
126124
}
127125

128-
h.SetSocketTimeout(2*time.Second + 8*time.Millisecond)
126+
const timeout = 2*time.Second + 8*time.Millisecond
127+
h.SetSocketTimeout(timeout)
129128

130129
for _, sh := range h.sockets {
131-
verifySockTimeVal(t, sh.Socket.GetFd(), unix.Timeval{Sec: 2, Usec: 8000})
130+
verifySockTimeVal(t, sh.Socket, timeout)
131+
}
132+
}
133+
134+
func verifySockTimeVal(t *testing.T, socket *nl.NetlinkSocket, expTimeout time.Duration) {
135+
t.Helper()
136+
send, receive := socket.GetTimeouts()
137+
if send != expTimeout || receive != expTimeout {
138+
t.Fatalf("Expected timeout: %v, got Send: %v, Receive: %v", expTimeout, send, receive)
132139
}
133140
}
134141

@@ -157,30 +164,6 @@ func TestHandleReceiveBuffer(t *testing.T) {
157164
}
158165
}
159166

160-
func verifySockTimeVal(t *testing.T, fd int, tv unix.Timeval) {
161-
var (
162-
tr unix.Timeval
163-
v = uint32(0x10)
164-
)
165-
_, _, errno := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd), unix.SOL_SOCKET, unix.SO_SNDTIMEO, uintptr(unsafe.Pointer(&tr)), uintptr(unsafe.Pointer(&v)), 0)
166-
if errno != 0 {
167-
t.Fatal(errno)
168-
}
169-
170-
if tr.Sec != tv.Sec || tr.Usec != tv.Usec {
171-
t.Fatalf("Unexpected timeout value read: %v. Expected: %v", tr, tv)
172-
}
173-
174-
_, _, errno = unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, uintptr(unsafe.Pointer(&tr)), uintptr(unsafe.Pointer(&v)), 0)
175-
if errno != 0 {
176-
t.Fatal(errno)
177-
}
178-
179-
if tr.Sec != tv.Sec || tr.Usec != tv.Usec {
180-
t.Fatalf("Unexpected timeout value read: %v. Expected: %v", tr, tv)
181-
}
182-
}
183-
184167
var (
185168
iter = 10
186169
numThread = uint32(4)

nl/nl_linux.go

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ package nl
44
import (
55
"bytes"
66
"encoding/binary"
7+
"errors"
78
"fmt"
89
"net"
910
"os"
1011
"runtime"
1112
"sync"
1213
"sync/atomic"
1314
"syscall"
15+
"time"
1416
"unsafe"
1517

1618
"github.com/vishvananda/netns"
@@ -656,9 +658,11 @@ func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
656658
}
657659

658660
type NetlinkSocket struct {
659-
fd int32
660-
file *os.File
661-
lsa unix.SockaddrNetlink
661+
fd int32
662+
file *os.File
663+
lsa unix.SockaddrNetlink
664+
sendTimeout int64 // Access using atomic.Load/StoreInt64
665+
receiveTimeout int64 // Access using atomic.Load/StoreInt64
662666
sync.Mutex
663667
}
664668

@@ -802,8 +806,44 @@ func (s *NetlinkSocket) GetFd() int {
802806
return int(s.fd)
803807
}
804808

809+
func (s *NetlinkSocket) GetTimeouts() (send, receive time.Duration) {
810+
return time.Duration(atomic.LoadInt64(&s.sendTimeout)),
811+
time.Duration(atomic.LoadInt64(&s.receiveTimeout))
812+
}
813+
805814
func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
806-
return unix.Sendto(int(s.fd), request.Serialize(), 0, &s.lsa)
815+
rawConn, err := s.file.SyscallConn()
816+
if err != nil {
817+
return err
818+
}
819+
var (
820+
deadline time.Time
821+
innerErr error
822+
)
823+
sendTimeout := atomic.LoadInt64(&s.sendTimeout)
824+
if sendTimeout != 0 {
825+
deadline = time.Now().Add(time.Duration(sendTimeout))
826+
}
827+
if err := s.file.SetWriteDeadline(deadline); err != nil {
828+
return err
829+
}
830+
serializedReq := request.Serialize()
831+
err = rawConn.Write(func(fd uintptr) (done bool) {
832+
innerErr = unix.Sendto(int(s.fd), serializedReq, 0, &s.lsa)
833+
return innerErr != unix.EWOULDBLOCK
834+
})
835+
if innerErr != nil {
836+
return innerErr
837+
}
838+
if err != nil {
839+
// The timeout was previously implemented using SO_SNDTIMEO on a blocking
840+
// socket. So, continue to return EAGAIN when the timeout is reached.
841+
if errors.Is(err, os.ErrDeadlineExceeded) {
842+
return unix.EAGAIN
843+
}
844+
return err
845+
}
846+
return nil
807847
}
808848

809849
func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetlink, error) {
@@ -812,20 +852,33 @@ func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetli
812852
return nil, nil, err
813853
}
814854
var (
855+
deadline time.Time
815856
fromAddr *unix.SockaddrNetlink
816857
rb [RECEIVE_BUFFER_SIZE]byte
817858
nr int
818859
from unix.Sockaddr
819860
innerErr error
820861
)
862+
receiveTimeout := atomic.LoadInt64(&s.receiveTimeout)
863+
if receiveTimeout != 0 {
864+
deadline = time.Now().Add(time.Duration(receiveTimeout))
865+
}
866+
if err := s.file.SetReadDeadline(deadline); err != nil {
867+
return nil, nil, err
868+
}
821869
err = rawConn.Read(func(fd uintptr) (done bool) {
822870
nr, from, innerErr = unix.Recvfrom(int(fd), rb[:], 0)
823871
return innerErr != unix.EWOULDBLOCK
824872
})
825873
if innerErr != nil {
826-
err = innerErr
874+
return nil, nil, innerErr
827875
}
828876
if err != nil {
877+
// The timeout was previously implemented using SO_RCVTIMEO on a blocking
878+
// socket. So, continue to return EAGAIN when the timeout is reached.
879+
if errors.Is(err, os.ErrDeadlineExceeded) {
880+
return nil, nil, unix.EAGAIN
881+
}
829882
return nil, nil, err
830883
}
831884
fromAddr, ok := from.(*unix.SockaddrNetlink)
@@ -847,16 +900,14 @@ func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetli
847900

848901
// SetSendTimeout allows to set a send timeout on the socket
849902
func (s *NetlinkSocket) SetSendTimeout(timeout *unix.Timeval) error {
850-
// Set a send timeout of SOCKET_SEND_TIMEOUT, this will allow the Send to periodically unblock and avoid that a routine
851-
// remains stuck on a send on a closed fd
852-
return unix.SetsockoptTimeval(int(s.fd), unix.SOL_SOCKET, unix.SO_SNDTIMEO, timeout)
903+
atomic.StoreInt64(&s.sendTimeout, timeout.Nano())
904+
return nil
853905
}
854906

855907
// SetReceiveTimeout allows to set a receive timeout on the socket
856908
func (s *NetlinkSocket) SetReceiveTimeout(timeout *unix.Timeval) error {
857-
// Set a read timeout of SOCKET_READ_TIMEOUT, this will allow the Read to periodically unblock and avoid that a routine
858-
// remains stuck on a recvmsg on a closed fd
859-
return unix.SetsockoptTimeval(int(s.fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, timeout)
909+
atomic.StoreInt64(&s.receiveTimeout, timeout.Nano())
910+
return nil
860911
}
861912

862913
// SetReceiveBufferSize allows to set a receive buffer size on the socket

nl/nl_linux_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,69 @@ func TestIfSocketCloses(t *testing.T) {
9797
}
9898
}
9999

100+
func TestReceiveTimeout(t *testing.T) {
101+
nlSock, err := getNetlinkSocket(unix.NETLINK_ROUTE)
102+
if err != nil {
103+
t.Fatalf("Error creating the socket: %v", err)
104+
}
105+
// Even if the test fails because the timeout doesn't work, closing the
106+
// socket at the end of the test should result in an EAGAIN (as long as
107+
// TestIfSocketCloses completed, otherwise this test will leak the
108+
// goroutines running the Receive).
109+
defer nlSock.Close()
110+
const failAfter = time.Second
111+
112+
tests := []struct {
113+
name string
114+
timeout time.Duration
115+
}{
116+
{
117+
name: "1us timeout", // The smallest value accepted by Handle.SetSocketTimeout
118+
timeout: time.Microsecond,
119+
},
120+
{
121+
name: "100ms timeout",
122+
timeout: 100 * time.Millisecond,
123+
},
124+
{
125+
name: "500ms timeout",
126+
timeout: 500 * time.Millisecond,
127+
},
128+
}
129+
for _, tc := range tests {
130+
tc := tc
131+
t.Run(tc.name, func(t *testing.T) {
132+
timeout := unix.NsecToTimeval(int64(tc.timeout))
133+
nlSock.SetReceiveTimeout(&timeout)
134+
135+
doneC := make(chan time.Duration)
136+
errC := make(chan error)
137+
go func() {
138+
start := time.Now()
139+
_, _, err := nlSock.Receive()
140+
dur := time.Since(start)
141+
if err != unix.EAGAIN {
142+
errC <- err
143+
return
144+
}
145+
doneC <- dur
146+
}()
147+
148+
failTimerC := time.After(failAfter)
149+
select {
150+
case dur := <-doneC:
151+
if dur < tc.timeout || dur > (tc.timeout+(100*time.Millisecond)) {
152+
t.Fatalf("Expected timeout %v got %v", tc.timeout, dur)
153+
}
154+
case err := <-errC:
155+
t.Fatalf("Expected EAGAIN, but got: %v", err)
156+
case <-failTimerC:
157+
t.Fatalf("No timeout received")
158+
}
159+
})
160+
}
161+
}
162+
100163
func (msg *CnMsgOp) write(b []byte) {
101164
native := NativeEndian()
102165
native.PutUint32(b[0:4], msg.ID.Idx)

0 commit comments

Comments
 (0)