Skip to content

Commit 60eeedf

Browse files
committed
tun: export GSOSplit() for external Device implementers
External implementers of tun.Device may support GSO, and may also be platform-agnostic, e.g. gVisor. Signed-off-by: Jordan Whited <[email protected]>
1 parent 2f5d148 commit 60eeedf

File tree

4 files changed

+353
-161
lines changed

4 files changed

+353
-161
lines changed

tun/offload.go

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
package tun
2+
3+
import (
4+
"encoding/binary"
5+
"fmt"
6+
)
7+
8+
// GSOType identifies which kind of segmentation offload, if any, applies to a
// packet.
type GSOType int

// Declared GSOType values, numbered in declaration order starting at zero.
const (
	GSONone GSOType = iota
	GSOTCPv4
	GSOTCPv6
	GSOUDPL4
)

// String returns the identifier of g, or "unknown" for any value that is not
// one of the declared GSOType constants.
func (g GSOType) String() string {
	names := [...]string{
		GSONone:  "GSONone",
		GSOTCPv4: "GSOTCPv4",
		GSOTCPv6: "GSOTCPv6",
		GSOUDPL4: "GSOUDPL4",
	}
	if g < 0 || int(g) >= len(names) {
		return "unknown"
	}
	return names[g]
}
32+
33+
// GSOOptions is loosely modeled after struct virtio_net_hdr from the VIRTIO
34+
// specification. It is a common representation of GSO metadata that can be
35+
// applied to support packet GSO across tun.Device implementations.
36+
type GSOOptions struct {
37+
// GSOType represents the type of segmentation offload.
38+
GSOType GSOType
39+
// HdrLen is the sum of the layer 3 and 4 header lengths. This field may be
40+
// zero when GSOType == GSONone.
41+
HdrLen uint16
42+
// CsumStart is the head byte index of the packet data to be checksummed,
43+
// i.e. the start of the TCP or UDP header.
44+
CsumStart uint16
45+
// CsumOffset is the offset from CsumStart where the 2-byte checksum value
46+
// should be placed.
47+
CsumOffset uint16
48+
// GSOSize is the size of each segment exclusive of HdrLen. The tail segment
49+
// may be smaller than this value.
50+
GSOSize uint16
51+
// NeedsCsum may be set where GSOType == GSONone. When set, the checksum
52+
// at CsumStart + CsumOffset must be a partial checksum, i.e. the
53+
// pseudo-header sum.
54+
NeedsCsum bool
55+
}
56+
57+
// Byte offsets of the source address within the IPv4 and IPv6 headers. The
// destination address immediately follows the source address.
const (
	ipv4SrcAddrOffset = 12
	ipv6SrcAddrOffset = 8
)

// tcpFlagsOffset is the byte offset of the flags byte from the start of the
// TCP header.
const tcpFlagsOffset = 13

// Bit masks for individual flags within the TCP flags byte.
const (
	tcpFlagFIN uint8 = 1 << 0
	tcpFlagPSH uint8 = 1 << 3
	tcpFlagACK uint8 = 1 << 4
)

// IP protocol numbers, declared locally so that no platform-specific package
// has to be imported.
const (
	ipProtoTCP = 6
	ipProtoUDP = 17
)
75+
76+
// GSOSplit splits packets from 'in' into outBufs[<index>][outOffset:], writing
77+
// the size of each element into sizes. It returns the number of buffers
78+
// populated, and/or an error. Callers may pass an 'in' slice that overlaps with
79+
// the first element of outBuffers, i.e. &in[0] may be equal to
80+
// &outBufs[0][outOffset]. GSONone is a valid options.GSOType regardless of the
81+
// value of options.NeedsCsum. Length of each outBufs element must be greater
82+
// than or equal to the length of 'in', otherwise output may be silently
83+
// truncated.
84+
func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outOffset int) (int, error) {
85+
cSumAt := int(options.CsumStart) + int(options.CsumOffset)
86+
if cSumAt+1 >= len(in) {
87+
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
88+
}
89+
90+
if len(in) < int(options.HdrLen) {
91+
return 0, fmt.Errorf("length of packet (%d) < GSO HdrLen (%d)", len(in), options.HdrLen)
92+
}
93+
94+
// Handle the conditions where we are copying a single element to outBuffs.
95+
payloadLen := len(in) - int(options.HdrLen)
96+
if options.GSOType == GSONone || payloadLen < int(options.GSOSize) {
97+
if len(in) > len(outBufs[0][outOffset:]) {
98+
return 0, fmt.Errorf("length of packet (%d) exceeds output element length (%d)", len(in), len(outBufs[0][outOffset:]))
99+
}
100+
if options.NeedsCsum {
101+
// The initial value at the checksum offset should be summed with
102+
// the checksum we compute. This is typically the pseudo-header sum.
103+
initial := binary.BigEndian.Uint16(in[cSumAt:])
104+
in[cSumAt], in[cSumAt+1] = 0, 0
105+
binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[options.CsumStart:], initial))
106+
}
107+
sizes[0] = copy(outBufs[0][outOffset:], in)
108+
return 1, nil
109+
}
110+
111+
if options.HdrLen < options.CsumStart {
112+
return 0, fmt.Errorf("GSO HdrLen (%d) < GSO CsumStart (%d)", options.HdrLen, options.CsumStart)
113+
}
114+
115+
ipVersion := in[0] >> 4
116+
switch ipVersion {
117+
case 4:
118+
if options.GSOType != GSOTCPv4 && options.GSOType != GSOUDPL4 {
119+
return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType)
120+
}
121+
if len(in) < 20 {
122+
return 0, fmt.Errorf("length of packet (%d) < minimum ipv4 header size (%d)", len(in), 20)
123+
}
124+
case 6:
125+
if options.GSOType != GSOTCPv6 && options.GSOType != GSOUDPL4 {
126+
return 0, fmt.Errorf("ip header version: %d, GSO type: %s", ipVersion, options.GSOType)
127+
}
128+
if len(in) < 40 {
129+
return 0, fmt.Errorf("length of packet (%d) < minimum ipv6 header size (%d)", len(in), 40)
130+
}
131+
default:
132+
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
133+
}
134+
135+
iphLen := int(options.CsumStart)
136+
srcAddrOffset := ipv6SrcAddrOffset
137+
addrLen := 16
138+
if ipVersion == 4 {
139+
srcAddrOffset = ipv4SrcAddrOffset
140+
addrLen = 4
141+
}
142+
transportCsumAt := int(options.CsumStart + options.CsumOffset)
143+
var firstTCPSeqNum uint32
144+
var protocol uint8
145+
if options.GSOType == GSOTCPv4 || options.GSOType == GSOTCPv6 {
146+
protocol = ipProtoTCP
147+
if len(in) < int(options.CsumStart)+20 {
148+
return 0, fmt.Errorf("length of packet (%d) < GSO CsumStart (%d) + minimum TCP header size (%d)",
149+
len(in), options.CsumStart, 20)
150+
}
151+
firstTCPSeqNum = binary.BigEndian.Uint32(in[options.CsumStart+4:])
152+
} else {
153+
protocol = ipProtoUDP
154+
}
155+
nextSegmentDataAt := int(options.HdrLen)
156+
i := 0
157+
for ; nextSegmentDataAt < len(in); i++ {
158+
if i == len(outBufs) {
159+
return i - 1, ErrTooManySegments
160+
}
161+
nextSegmentEnd := nextSegmentDataAt + int(options.GSOSize)
162+
if nextSegmentEnd > len(in) {
163+
nextSegmentEnd = len(in)
164+
}
165+
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
166+
totalLen := int(options.HdrLen) + segmentDataLen
167+
sizes[i] = totalLen
168+
out := outBufs[i][outOffset:]
169+
170+
copy(out, in[:iphLen])
171+
if ipVersion == 4 {
172+
// For IPv4 we are responsible for incrementing the ID field,
173+
// updating the total len field, and recalculating the header
174+
// checksum.
175+
if i > 0 {
176+
id := binary.BigEndian.Uint16(out[4:])
177+
id += uint16(i)
178+
binary.BigEndian.PutUint16(out[4:], id)
179+
}
180+
out[10], out[11] = 0, 0 // clear ipv4 header checksum
181+
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
182+
ipv4CSum := ^checksum(out[:iphLen], 0)
183+
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
184+
} else {
185+
// For IPv6 we are responsible for updating the payload length field.
186+
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
187+
}
188+
189+
// copy transport header
190+
copy(out[options.CsumStart:options.HdrLen], in[options.CsumStart:options.HdrLen])
191+
192+
if protocol == ipProtoTCP {
193+
// set TCP seq and adjust TCP flags
194+
tcpSeq := firstTCPSeqNum + uint32(options.GSOSize*uint16(i))
195+
binary.BigEndian.PutUint32(out[options.CsumStart+4:], tcpSeq)
196+
if nextSegmentEnd != len(in) {
197+
// FIN and PSH should only be set on last segment
198+
clearFlags := tcpFlagFIN | tcpFlagPSH
199+
out[options.CsumStart+tcpFlagsOffset] &^= clearFlags
200+
}
201+
} else {
202+
// set UDP header len
203+
binary.BigEndian.PutUint16(out[options.CsumStart+4:], uint16(segmentDataLen)+(options.HdrLen-options.CsumStart))
204+
}
205+
206+
// payload
207+
copy(out[options.HdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
208+
209+
// transport checksum
210+
out[transportCsumAt], out[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
211+
transportHeaderLen := int(options.HdrLen - options.CsumStart)
212+
lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
213+
transportCSum := pseudoHeaderChecksum(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
214+
transportCSum = ^checksum(out[options.CsumStart:totalLen], transportCSum)
215+
binary.BigEndian.PutUint16(out[options.CsumStart+options.CsumOffset:], transportCSum)
216+
217+
nextSegmentDataAt += int(options.GSOSize)
218+
}
219+
return i, nil
220+
}

tun/offload_linux.go

Lines changed: 26 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,14 @@ import (
99
"bytes"
1010
"encoding/binary"
1111
"errors"
12+
"fmt"
1213
"io"
1314
"unsafe"
1415

1516
"github.com/tailscale/wireguard-go/conn"
1617
"golang.org/x/sys/unix"
1718
)
1819

19-
const tcpFlagsOffset = 13
20-
21-
const (
22-
tcpFlagFIN uint8 = 0x01
23-
tcpFlagPSH uint8 = 0x08
24-
tcpFlagACK uint8 = 0x10
25-
)
26-
2720
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
2821
// kernel symbol is virtio_net_hdr.
2922
type virtioNetHdr struct {
@@ -35,6 +28,30 @@ type virtioNetHdr struct {
3528
csumOffset uint16
3629
}
3730

31+
func (v *virtioNetHdr) toGSOOptions() (GSOOptions, error) {
32+
var gsoType GSOType
33+
switch v.gsoType {
34+
case unix.VIRTIO_NET_HDR_GSO_NONE:
35+
gsoType = GSONone
36+
case unix.VIRTIO_NET_HDR_GSO_TCPV4:
37+
gsoType = GSOTCPv4
38+
case unix.VIRTIO_NET_HDR_GSO_TCPV6:
39+
gsoType = GSOTCPv6
40+
case unix.VIRTIO_NET_HDR_GSO_UDP_L4:
41+
gsoType = GSOUDPL4
42+
default:
43+
return GSOOptions{}, fmt.Errorf("unsupported virtio gsoType: %d", v.gsoType)
44+
}
45+
return GSOOptions{
46+
GSOType: gsoType,
47+
HdrLen: v.hdrLen,
48+
CsumStart: v.csumStart,
49+
CsumOffset: v.csumOffset,
50+
GSOSize: v.gsoSize,
51+
NeedsCsum: v.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0,
52+
}, nil
53+
}
54+
3855
func (v *virtioNetHdr) decode(b []byte) error {
3956
if len(b) < virtioNetHdrLen {
4057
return io.ErrShortBuffer
@@ -510,9 +527,7 @@ const (
510527
)
511528

512529
const (
513-
ipv4SrcAddrOffset = 12
514-
ipv6SrcAddrOffset = 8
515-
maxUint16 = 1<<16 - 1
530+
maxUint16 = 1<<16 - 1
516531
)
517532

518533
type groResult int
@@ -894,100 +909,3 @@ func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGR
894909
errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable)
895910
return errors.Join(errTCP, errUDP)
896911
}
897-
898-
// gsoSplit splits packets from in into outBuffs, writing the size of each
899-
// element into sizes. It returns the number of buffers populated, and/or an
900-
// error.
901-
func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) {
902-
iphLen := int(hdr.csumStart)
903-
srcAddrOffset := ipv6SrcAddrOffset
904-
addrLen := 16
905-
if !isV6 {
906-
in[10], in[11] = 0, 0 // clear ipv4 header checksum
907-
srcAddrOffset = ipv4SrcAddrOffset
908-
addrLen = 4
909-
}
910-
transportCsumAt := int(hdr.csumStart + hdr.csumOffset)
911-
in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
912-
var firstTCPSeqNum uint32
913-
var protocol uint8
914-
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 {
915-
protocol = unix.IPPROTO_TCP
916-
firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:])
917-
} else {
918-
protocol = unix.IPPROTO_UDP
919-
}
920-
nextSegmentDataAt := int(hdr.hdrLen)
921-
i := 0
922-
for ; nextSegmentDataAt < len(in); i++ {
923-
if i == len(outBuffs) {
924-
return i - 1, ErrTooManySegments
925-
}
926-
nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
927-
if nextSegmentEnd > len(in) {
928-
nextSegmentEnd = len(in)
929-
}
930-
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
931-
totalLen := int(hdr.hdrLen) + segmentDataLen
932-
sizes[i] = totalLen
933-
out := outBuffs[i][outOffset:]
934-
935-
copy(out, in[:iphLen])
936-
if !isV6 {
937-
// For IPv4 we are responsible for incrementing the ID field,
938-
// updating the total len field, and recalculating the header
939-
// checksum.
940-
if i > 0 {
941-
id := binary.BigEndian.Uint16(out[4:])
942-
id += uint16(i)
943-
binary.BigEndian.PutUint16(out[4:], id)
944-
}
945-
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
946-
ipv4CSum := ^checksum(out[:iphLen], 0)
947-
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
948-
} else {
949-
// For IPv6 we are responsible for updating the payload length field.
950-
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
951-
}
952-
953-
// copy transport header
954-
copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
955-
956-
if protocol == unix.IPPROTO_TCP {
957-
// set TCP seq and adjust TCP flags
958-
tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
959-
binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
960-
if nextSegmentEnd != len(in) {
961-
// FIN and PSH should only be set on last segment
962-
clearFlags := tcpFlagFIN | tcpFlagPSH
963-
out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
964-
}
965-
} else {
966-
// set UDP header len
967-
binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart))
968-
}
969-
970-
// payload
971-
copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
972-
973-
// transport checksum
974-
transportHeaderLen := int(hdr.hdrLen - hdr.csumStart)
975-
lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
976-
transportCSum := pseudoHeaderChecksum(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
977-
transportCSum = ^checksum(out[hdr.csumStart:totalLen], transportCSum)
978-
binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum)
979-
980-
nextSegmentDataAt += int(hdr.gsoSize)
981-
}
982-
return i, nil
983-
}
984-
985-
func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
986-
cSumAt := cSumStart + cSumOffset
987-
// The initial value at the checksum offset should be summed with the
988-
// checksum we compute. This is typically the pseudo-header checksum.
989-
initial := binary.BigEndian.Uint16(in[cSumAt:])
990-
in[cSumAt], in[cSumAt+1] = 0, 0
991-
binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], initial))
992-
return nil
993-
}

0 commit comments

Comments
 (0)