@@ -37,6 +37,7 @@ import (
3737 "math/big"
3838 math_rand "math/rand"
3939 "net"
40+ "os"
4041 "strconv"
4142 "strings"
4243 "time"
@@ -74,7 +75,7 @@ func initServer(args *tsshdArgs) (*kcp.Listener, *quic.Listener, error) {
7475 }
7576
7677 if args .Proxy {
77- conn , err = startServerProxy (conn , info )
78+ conn , err = startServerProxy (conn , info , args . ConnectTimeout )
7879 if err != nil {
7980 return nil , nil , err
8081 }
@@ -131,63 +132,92 @@ func getPortRange(args *tsshdArgs) (int, int) {
131132 return kDefaultPortRangeLow , kDefaultPortRangeHigh
132133}
133134
134- func getUdpNetworks (args * tsshdArgs ) []string {
135- if args .IPv4 && ! args .IPv6 {
136- return []string {"udp4" }
137- }
138- if ! args .IPv4 && args .IPv6 {
139- return []string {"udp6" }
140- }
141- ipv4 , ipv6 := false , false
142- addrs , err := net .InterfaceAddrs ()
135+ func canListenOnUDP (udpAddr * net.UDPAddr ) bool {
136+ conn , err := net .ListenUDP ("udp" , udpAddr )
143137 if err != nil {
144- return [] string { "udp" }
138+ return false
145139 }
146- for _ , addr := range addrs {
147- if ipNet , ok := addr .(* net.IPNet ); ok {
148- if ipNet .IP .IsLoopback () {
149- continue
140+ _ = conn .Close ()
141+ return true
142+ }
143+
144+ func getUdpAddrs (args * tsshdArgs ) ([]* net.UDPAddr , error ) {
145+ if sshConnection := os .Getenv ("SSH_CONNECTION" ); sshConnection != "" {
146+ if tokens := strings .Fields (sshConnection ); len (tokens ) >= 3 {
147+ ip := tokens [2 ]
148+ if strings .HasPrefix (strings .ToLower (ip ), "::ffff:" ) {
149+ ip = ip [7 :]
150150 }
151- if ipNet .IP .To4 () != nil {
152- ipv4 = true
153- } else if ipNet .IP .To16 () != nil {
154- ipv6 = true
151+ udpAddr , err := net .ResolveUDPAddr ("udp" , fmt .Sprintf ("%s:0" , ip ))
152+ if err == nil && canListenOnUDP (udpAddr ) {
153+ return []* net.UDPAddr {udpAddr }, nil
155154 }
156155 }
157156 }
158- if ! ipv4 && ! ipv6 {
159- return []string {"udp" }
160- }
161- var networks []string
162- if ipv4 {
163- networks = append (networks , "udp4" )
157+
158+ var udpAddrs []* net.UDPAddr
159+ ifaceAddrs , err := net .InterfaceAddrs ()
160+ if err == nil {
161+ ipv4Only := args .IPv4 && ! args .IPv6
162+ ipv6Only := ! args .IPv4 && args .IPv6
163+ for _ , addr := range ifaceAddrs {
164+ if ipNet , ok := addr .(* net.IPNet ); ok {
165+ if ipNet .IP .IsLoopback () {
166+ continue
167+ }
168+ if ipNet .IP .To4 () != nil && ! ipv6Only {
169+ addr := & net.UDPAddr {IP : ipNet .IP }
170+ if canListenOnUDP (addr ) {
171+ udpAddrs = append (udpAddrs , addr )
172+ }
173+ } else if ipNet .IP .To16 () != nil && ! ipv4Only {
174+ var zone string
175+ if ipAddr , ok := addr .(* net.IPAddr ); ok {
176+ zone = ipAddr .Zone
177+ }
178+ addr := & net.UDPAddr {IP : ipNet .IP , Zone : zone }
179+ if canListenOnUDP (addr ) {
180+ udpAddrs = append (udpAddrs , addr )
181+ }
182+ }
183+ }
184+ }
164185 }
165- if ipv6 {
166- networks = append (networks , "udp6" )
186+
187+ if len (udpAddrs ) == 0 {
188+ udpAddr , err := net .ResolveUDPAddr ("udp" , ":0" )
189+ if err != nil {
190+ return nil , err
191+ }
192+ return []* net.UDPAddr {udpAddr }, nil
167193 }
168- return networks
194+
195+ return udpAddrs , nil
169196}
170197
171198func listenUdpOnFreePort (args * tsshdArgs ) ([]* net.UDPConn , int , error ) {
172199 portRangeLow , portRangeHigh := getPortRange (args )
173200 if portRangeHigh < portRangeLow {
174201 return nil , 0 , fmt .Errorf ("no port in [%d,%d]" , portRangeLow , portRangeHigh )
175202 }
176- networks := getUdpNetworks (args )
203+ addrs , err := getUdpAddrs (args )
204+ if err != nil {
205+ return nil , 0 , fmt .Errorf ("get available udp address failed: %v" , err )
206+ }
177207 var lastErr error
178208 size := portRangeHigh - portRangeLow + 1
179209 port := portRangeLow + math_rand .Intn (size )
180210 for range size {
181211 var connList []* net.UDPConn
182- for _ , network := range networks {
183- conn , err := listenUdpOnPort (network , port )
212+ for _ , addr := range addrs {
213+ conn , err := listenUdpOnPort (addr , port )
184214 if err != nil {
185215 lastErr = err
186216 break
187217 }
188218 connList = append (connList , conn )
189219 }
190- if len (connList ) == len (networks ) {
220+ if len (connList ) == len (addrs ) {
191221 return connList , port , nil
192222 }
193223 for _ , conn := range connList {
@@ -204,15 +234,11 @@ func listenUdpOnFreePort(args *tsshdArgs) ([]*net.UDPConn, int, error) {
204234 return nil , 0 , fmt .Errorf ("listen udp on [%d,%d] failed" , portRangeLow , portRangeHigh )
205235}
206236
207- func listenUdpOnPort (network string , port int ) (* net.UDPConn , error ) {
208- addr := fmt .Sprintf (":%d" , port )
209- udpAddr , err := net .ResolveUDPAddr (network , addr )
210- if err != nil {
211- return nil , fmt .Errorf ("resolve [%s] addr [%s] failed: %v" , network , addr , err )
212- }
213- conn , err := net .ListenUDP (network , udpAddr )
237+ func listenUdpOnPort (udpAddr * net.UDPAddr , port int ) (* net.UDPConn , error ) {
238+ udpAddr .Port = port
239+ conn , err := net .ListenUDP ("udp" , udpAddr )
214240 if err != nil {
215- return nil , fmt .Errorf ("listen [%s] on [%s] failed: %v" , network , addr , err )
241+ return nil , fmt .Errorf ("listen on [%s] failed: %v" , udpAddr . String () , err )
216242 }
217243 return conn , nil
218244}
0 commit comments