Skip to content

Commit 0e225bc

Browse files
committed
listen on appropriate IP address
1 parent fd58fe0 commit 0e225bc

5 files changed

Lines changed: 106 additions & 64 deletions

File tree

tsshd/main.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ import (
3636
"time"
3737
)
3838

39+
var kDefaultConnectTimeout = 10 * time.Second
40+
3941
var exitChan = make(chan int, 1)
4042

4143
type tsshdArgs struct {
@@ -158,6 +160,11 @@ func TsshdMain() int {
158160
// cleanup on exit
159161
defer cleanupOnExit()
160162

163+
// default connect timeout
164+
if args.ConnectTimeout <= 0 {
165+
args.ConnectTimeout = kDefaultConnectTimeout
166+
}
167+
161168
// handle exit signals
162169
handleExitSignals()
163170

@@ -181,11 +188,7 @@ func TsshdMain() int {
181188

182189
go func() {
183190
// should be connected in time
184-
connectTimeout := args.ConnectTimeout
185-
if connectTimeout <= 0 {
186-
connectTimeout = 10 * time.Second
187-
}
188-
time.Sleep(connectTimeout)
191+
time.Sleep(args.ConnectTimeout)
189192
if !serving.Load() {
190193
exitChan <- 1
191194
}

tsshd/proto.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ func newKcpClient(addr string, info *ServerInfo) (udpClient, error) {
327327
if err != nil {
328328
return nil, fmt.Errorf("kcp dial [%s] failed: %v", addr, err)
329329
}
330+
conn.SetWindowSize(1024, 1024)
330331
conn.SetNoDelay(1, 10, 2, 1)
331332
session, err := smux.Client(conn, &smuxConfig)
332333
if err != nil {

tsshd/proxy.go

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,16 @@ type udpBuffer struct {
8181
}
8282

8383
type serverProxy struct {
84-
frontendList []*net.UDPConn
85-
backendConn *net.UDPConn
86-
clientConn atomic.Pointer[udpConn]
87-
authedConn *udpConn
88-
cipherBlock *cipher.Block
89-
clientID uint64
90-
serverID uint64
91-
serialNumber uint64
92-
bufChan chan *udpBuffer
84+
connectTimeout time.Duration
85+
frontendList []*net.UDPConn
86+
backendConn *net.UDPConn
87+
clientConn atomic.Pointer[udpConn]
88+
authedConn *udpConn
89+
cipherBlock *cipher.Block
90+
clientID uint64
91+
serverID uint64
92+
serialNumber uint64
93+
bufChan chan *udpBuffer
9394
}
9495

9596
func (p *serverProxy) isClientConn(conn *udpConn) bool {
@@ -191,13 +192,22 @@ func (p *serverProxy) backendToFrontend() {
191192
}
192193

193194
func (p *serverProxy) serveFrontendConn(conn *net.UDPConn) {
195+
defer func() { _ = conn.Close() }()
196+
beginTime := time.Now()
197+
neverReceived := true
198+
194199
current := 0
195200
buffers := [2][]byte{make([]byte, 0xffff), make([]byte, 0xffff)}
196201
for {
202+
_ = conn.SetReadDeadline(time.Now().Add(p.connectTimeout))
197203
n, addr, err := conn.ReadFromUDP(buffers[current])
198204
if err != nil || n <= 0 {
205+
if neverReceived && time.Since(beginTime) > p.connectTimeout {
206+
return
207+
}
199208
continue
200209
}
210+
neverReceived = false
201211
p.bufChan <- &udpBuffer{
202212
conn: &udpConn{
203213
frontendConn: conn,
@@ -221,7 +231,7 @@ func (p *serverProxy) serveProxy() {
221231
go p.backendToFrontend()
222232
}
223233

224-
func startServerProxy(frontendList []*net.UDPConn, info *ServerInfo) ([]*net.UDPConn, error) {
234+
func startServerProxy(frontendList []*net.UDPConn, info *ServerInfo, connectTimeout time.Duration) ([]*net.UDPConn, error) {
225235
localAddr := "127.0.0.1:0"
226236
udpAddr, err := net.ResolveUDPAddr("udp", localAddr)
227237
if err != nil {
@@ -263,12 +273,13 @@ func startServerProxy(frontendList []*net.UDPConn, info *ServerInfo) ([]*net.UDP
263273
info.ServerID = binary.BigEndian.Uint64(serverID)
264274

265275
proxy := &serverProxy{
266-
frontendList: frontendList,
267-
backendConn: backendConn,
268-
cipherBlock: &cipherBlock,
269-
clientID: info.ClientID,
270-
serverID: info.ServerID,
271-
bufChan: make(chan *udpBuffer), // unbuffered channel to avaid copying buffer
276+
connectTimeout: connectTimeout,
277+
frontendList: frontendList,
278+
backendConn: backendConn,
279+
cipherBlock: &cipherBlock,
280+
clientID: info.ClientID,
281+
serverID: info.ServerID,
282+
bufChan: make(chan *udpBuffer), // unbuffered channel to avaid copying buffer
272283
}
273284
go proxy.serveProxy()
274285

tsshd/server.go

Lines changed: 66 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import (
3737
"math/big"
3838
math_rand "math/rand"
3939
"net"
40+
"os"
4041
"strconv"
4142
"strings"
4243
"time"
@@ -74,7 +75,7 @@ func initServer(args *tsshdArgs) (*kcp.Listener, *quic.Listener, error) {
7475
}
7576

7677
if args.Proxy {
77-
conn, err = startServerProxy(conn, info)
78+
conn, err = startServerProxy(conn, info, args.ConnectTimeout)
7879
if err != nil {
7980
return nil, nil, err
8081
}
@@ -131,63 +132,92 @@ func getPortRange(args *tsshdArgs) (int, int) {
131132
return kDefaultPortRangeLow, kDefaultPortRangeHigh
132133
}
133134

134-
func getUdpNetworks(args *tsshdArgs) []string {
135-
if args.IPv4 && !args.IPv6 {
136-
return []string{"udp4"}
137-
}
138-
if !args.IPv4 && args.IPv6 {
139-
return []string{"udp6"}
140-
}
141-
ipv4, ipv6 := false, false
142-
addrs, err := net.InterfaceAddrs()
135+
func canListenOnUDP(udpAddr *net.UDPAddr) bool {
136+
conn, err := net.ListenUDP("udp", udpAddr)
143137
if err != nil {
144-
return []string{"udp"}
138+
return false
145139
}
146-
for _, addr := range addrs {
147-
if ipNet, ok := addr.(*net.IPNet); ok {
148-
if ipNet.IP.IsLoopback() {
149-
continue
140+
_ = conn.Close()
141+
return true
142+
}
143+
144+
func getUdpAddrs(args *tsshdArgs) ([]*net.UDPAddr, error) {
145+
if sshConnection := os.Getenv("SSH_CONNECTION"); sshConnection != "" {
146+
if tokens := strings.Fields(sshConnection); len(tokens) >= 3 {
147+
ip := tokens[2]
148+
if strings.HasPrefix(strings.ToLower(ip), "::ffff:") {
149+
ip = ip[7:]
150150
}
151-
if ipNet.IP.To4() != nil {
152-
ipv4 = true
153-
} else if ipNet.IP.To16() != nil {
154-
ipv6 = true
151+
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:0", ip))
152+
if err == nil && canListenOnUDP(udpAddr) {
153+
return []*net.UDPAddr{udpAddr}, nil
155154
}
156155
}
157156
}
158-
if !ipv4 && !ipv6 {
159-
return []string{"udp"}
160-
}
161-
var networks []string
162-
if ipv4 {
163-
networks = append(networks, "udp4")
157+
158+
var udpAddrs []*net.UDPAddr
159+
ifaceAddrs, err := net.InterfaceAddrs()
160+
if err == nil {
161+
ipv4Only := args.IPv4 && !args.IPv6
162+
ipv6Only := !args.IPv4 && args.IPv6
163+
for _, addr := range ifaceAddrs {
164+
if ipNet, ok := addr.(*net.IPNet); ok {
165+
if ipNet.IP.IsLoopback() {
166+
continue
167+
}
168+
if ipNet.IP.To4() != nil && !ipv6Only {
169+
addr := &net.UDPAddr{IP: ipNet.IP}
170+
if canListenOnUDP(addr) {
171+
udpAddrs = append(udpAddrs, addr)
172+
}
173+
} else if ipNet.IP.To16() != nil && !ipv4Only {
174+
var zone string
175+
if ipAddr, ok := addr.(*net.IPAddr); ok {
176+
zone = ipAddr.Zone
177+
}
178+
addr := &net.UDPAddr{IP: ipNet.IP, Zone: zone}
179+
if canListenOnUDP(addr) {
180+
udpAddrs = append(udpAddrs, addr)
181+
}
182+
}
183+
}
184+
}
164185
}
165-
if ipv6 {
166-
networks = append(networks, "udp6")
186+
187+
if len(udpAddrs) == 0 {
188+
udpAddr, err := net.ResolveUDPAddr("udp", ":0")
189+
if err != nil {
190+
return nil, err
191+
}
192+
return []*net.UDPAddr{udpAddr}, nil
167193
}
168-
return networks
194+
195+
return udpAddrs, nil
169196
}
170197

171198
func listenUdpOnFreePort(args *tsshdArgs) ([]*net.UDPConn, int, error) {
172199
portRangeLow, portRangeHigh := getPortRange(args)
173200
if portRangeHigh < portRangeLow {
174201
return nil, 0, fmt.Errorf("no port in [%d,%d]", portRangeLow, portRangeHigh)
175202
}
176-
networks := getUdpNetworks(args)
203+
addrs, err := getUdpAddrs(args)
204+
if err != nil {
205+
return nil, 0, fmt.Errorf("get available udp address failed: %v", err)
206+
}
177207
var lastErr error
178208
size := portRangeHigh - portRangeLow + 1
179209
port := portRangeLow + math_rand.Intn(size)
180210
for range size {
181211
var connList []*net.UDPConn
182-
for _, network := range networks {
183-
conn, err := listenUdpOnPort(network, port)
212+
for _, addr := range addrs {
213+
conn, err := listenUdpOnPort(addr, port)
184214
if err != nil {
185215
lastErr = err
186216
break
187217
}
188218
connList = append(connList, conn)
189219
}
190-
if len(connList) == len(networks) {
220+
if len(connList) == len(addrs) {
191221
return connList, port, nil
192222
}
193223
for _, conn := range connList {
@@ -204,15 +234,11 @@ func listenUdpOnFreePort(args *tsshdArgs) ([]*net.UDPConn, int, error) {
204234
return nil, 0, fmt.Errorf("listen udp on [%d,%d] failed", portRangeLow, portRangeHigh)
205235
}
206236

207-
func listenUdpOnPort(network string, port int) (*net.UDPConn, error) {
208-
addr := fmt.Sprintf(":%d", port)
209-
udpAddr, err := net.ResolveUDPAddr(network, addr)
210-
if err != nil {
211-
return nil, fmt.Errorf("resolve [%s] addr [%s] failed: %v", network, addr, err)
212-
}
213-
conn, err := net.ListenUDP(network, udpAddr)
237+
func listenUdpOnPort(udpAddr *net.UDPAddr, port int) (*net.UDPConn, error) {
238+
udpAddr.Port = port
239+
conn, err := net.ListenUDP("udp", udpAddr)
214240
if err != nil {
215-
return nil, fmt.Errorf("listen [%s] on [%s] failed: %v", network, addr, err)
241+
return nil, fmt.Errorf("listen on [%s] failed: %v", udpAddr.String(), err)
216242
}
217243
return conn, nil
218244
}

tsshd/service.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ import (
3737
var smuxConfig = smux.Config{
3838
Version: 2,
3939
KeepAliveDisabled: true,
40-
MaxFrameSize: 32 * 1024,
41-
MaxStreamBuffer: 64 * 1024,
42-
MaxReceiveBuffer: 4 * 1024 * 1024,
40+
MaxFrameSize: 48 * 1024,
41+
MaxStreamBuffer: 10 * 1024 * 1024,
42+
MaxReceiveBuffer: 20 * 1024 * 1024,
4343
}
4444

4545
type quicStream struct {
@@ -73,6 +73,7 @@ func handleKcpConn(conn *kcp.UDPSession) {
7373
return
7474
}
7575

76+
conn.SetWindowSize(1024, 1024)
7677
conn.SetNoDelay(1, 10, 2, 1)
7778

7879
session, err := smux.Server(conn, &smuxConfig)

0 commit comments

Comments
 (0)