@@ -782,65 +782,153 @@ func handleX11Request(msg *startMessage) {
782782 }
783783 }
784784
785- listener , port , err := listenTcpOnFreePort ("localhost" , 6000 + displayOffset , min (6000 + displayOffset + 1000 , 65535 ))
785+ useLocalhost := strings .ToLower (getSshdConfig ("X11UseLocalhost" )) != "no"
786+ listeners , port , err := listenTcpOnFreePort (useLocalhost , 6000 + displayOffset , min (6000 + displayOffset + 1000 , 65535 ))
786787 if err != nil {
787788 warning ("X11 forwarding listen failed: %v" , err )
788789 return
789790 }
790791 onExitFuncs = append (onExitFuncs , func () {
791- _ = listener .Close ()
792+ for _ , listener := range listeners {
793+ _ = listener .Close ()
794+ }
792795 })
796+
797+ hostname := getHostnameForX11 (useLocalhost )
793798 displayNumber := port - 6000
794- if msg .X11 .AuthProtocol != "" && msg .X11 .AuthCookie != "" {
795- authDisplay := fmt .Sprintf ("unix:%d.%d" , displayNumber , msg .X11 .ScreenNumber )
796- input := fmt .Sprintf ("remove %s\n add %s %s %s\n " , authDisplay , authDisplay , msg .X11 .AuthProtocol , msg .X11 .AuthCookie )
797- if err := writeXauthData (input ); err == nil {
798- onExitFuncs = append (onExitFuncs , func () {
799- _ = writeXauthData (fmt .Sprintf ("remove %s\n " , authDisplay ))
800- })
801- }
799+ display := fmt .Sprintf ("%s:%d.%d" , hostname , displayNumber , msg .X11 .ScreenNumber )
800+ authDisplay := display
801+ if useLocalhost {
802+ authDisplay = fmt .Sprintf ("unix:%d.%d" , displayNumber , msg .X11 .ScreenNumber )
802803 }
803- go handleChannelAccept (listener , msg .X11 .ChannelType )
804+
805+ xauthPath := getXauthPath ()
806+ xauthInput := fmt .Sprintf ("remove %s\n add %s %s %s\n " , authDisplay , authDisplay , msg .X11 .AuthProtocol , msg .X11 .AuthCookie )
807+ if err := writeXauthData (xauthPath , xauthInput ); err != nil {
808+ warning ("write xauth data failed: %v" , err )
809+ }
810+ onExitFuncs = append (onExitFuncs , func () {
811+ _ = writeXauthData (xauthPath , fmt .Sprintf ("remove %s\n " , authDisplay ))
812+ })
813+
814+ for _ , listener := range listeners {
815+ go handleChannelAccept (listener , msg .X11 .ChannelType )
816+ }
817+
804818 if msg .Envs == nil {
805819 msg .Envs = make (map [string ]string )
806820 }
807- msg .Envs ["DISPLAY" ] = fmt . Sprintf ( "localhost:%d.%d" , displayNumber , msg . X11 . ScreenNumber )
821+ msg .Envs ["DISPLAY" ] = display
808822}
809823
810- func listenTcpOnFreePort (host string , low , high int ) (net.Listener , int , error ) {
811- var err error
812- var listener net.Listener
824+ func getHostnameForX11 (useLocalhost bool ) string {
825+ if useLocalhost {
826+ return "localhost"
827+ }
828+
829+ hostname , err := os .Hostname ()
830+ if err != nil {
831+ warning ("get hostname for X11 forwarding failed: %v" , err )
832+ return "localhost"
833+ }
834+ return hostname
835+ }
836+
837+ func listenTcpOnFreePort (useLocalhost bool , low , high int ) ([]net.Listener , int , error ) {
838+ var ipv4Host , ipv6Host string
839+ if useLocalhost {
840+ ipv4Host , ipv6Host = "127.0.0.1" , "::1"
841+ } else {
842+ ipv4Host , ipv6Host = "0.0.0.0" , "::"
843+ }
844+
845+ var netList , hostList []string
846+ listener4 , err4 := net .Listen ("tcp4" , net .JoinHostPort (ipv4Host , "0" ))
847+ if err4 == nil {
848+ _ = listener4 .Close ()
849+ netList = append (netList , "tcp4" )
850+ hostList = append (hostList , ipv4Host )
851+ }
852+ listener6 , err6 := net .Listen ("tcp6" , net .JoinHostPort (ipv6Host , "0" ))
853+ if err6 == nil {
854+ _ = listener6 .Close ()
855+ netList = append (netList , "tcp6" )
856+ hostList = append (hostList , ipv6Host )
857+ }
858+
859+ if err4 != nil && err6 != nil {
860+ return nil , 0 , fmt .Errorf ("ipv4 and ipv6 both listen failed: %v, %v" , err4 , err6 )
861+ }
862+
863+ var lastErr error
813864 for port := low ; port <= high ; port ++ {
814- listener , err = net .Listen ("tcp" , fmt .Sprintf ("%s:%d" , host , port ))
815- if err == nil {
816- return listener , port , nil
865+ var listenerList []net.Listener
866+ portStr := strconv .Itoa (port )
867+ for i := range len (netList ) {
868+ listener , err := net .Listen (netList [i ], net .JoinHostPort (hostList [i ], portStr ))
869+ if err != nil {
870+ lastErr = err
871+ continue
872+ }
873+ listenerList = append (listenerList , listener )
874+ }
875+ if len (listenerList ) == len (netList ) {
876+ return listenerList , port , nil
877+ }
878+ for _ , listener := range listenerList {
879+ _ = listener .Close ()
817880 }
818881 }
819- if err != nil {
820- return nil , 0 , fmt .Errorf ("listen tcp on %s:[%d,%d] failed: %v" , host , low , high , err )
882+ if lastErr != nil {
883+ return nil , 0 , fmt .Errorf ("listen tcp on [%s,%s][%d,%d] failed: %v" , ipv4Host , ipv6Host , low , high , lastErr )
884+ }
885+ return nil , 0 , fmt .Errorf ("listen tcp on [%s,%s][%d,%d] failed" , ipv4Host , ipv6Host , low , high )
886+ }
887+
888+ func getXauthPath () string {
889+ xauthPath := getSshdConfig ("XAuthLocation" )
890+ if xauthPath != "" {
891+ if _ , err := os .Stat (xauthPath ); err != nil {
892+ warning ("XAuthLocation [%s] not found: %v" , xauthPath , err )
893+ return "xauth"
894+ }
895+ return xauthPath
821896 }
822- return nil , 0 , fmt .Errorf ("listen tcp on %s:[%d,%d] failed" , host , low , high )
897+
898+ return "xauth"
823899}
824900
825- func writeXauthData (input string ) error {
826- cmd := exec .Command ("xauth" , "-q" , "-" )
901+ func writeXauthData (xauthPath , xauthInput string ) error {
902+ cmd := exec .Command (xauthPath , "-q" , "-" )
827903 stdin , err := cmd .StdinPipe ()
828904 if err != nil {
829- return err
905+ return fmt . Errorf ( "stdin pipe failed: %v" , err )
830906 }
831907 defer func () { _ = stdin .Close () }()
908+
909+ var errBuf bytes.Buffer
910+ cmd .Stderr = & errBuf
911+ cmd .Stdout = io .Discard
912+
832913 if err := cmd .Start (); err != nil {
833- return err
914+ return fmt . Errorf ( "xauth start failed: %v" , err )
834915 }
835- if _ , err := stdin .Write ([]byte (input )); err != nil {
836- return err
916+
917+ if _ , err := stdin .Write ([]byte (xauthInput )); err != nil {
918+ return fmt .Errorf ("stdin write failed: %v" , err )
837919 }
838920 _ = stdin .Close ()
839- _ , _ = doWithTimeout (func () (int , error ) {
840- _ = cmd .Wait ()
921+
922+ _ , err = doWithTimeout (func () (int , error ) {
923+ if err := cmd .Wait (); err != nil {
924+ if errBuf .Len () > 0 {
925+ return 0 , fmt .Errorf ("%s" , strings .TrimSpace (errBuf .String ()))
926+ }
927+ return 0 , fmt .Errorf ("xauth wait failed: %v" , err )
928+ }
841929 return 0 , nil
842- }, 200 * time .Millisecond )
843- return nil
930+ }, 1000 * time .Millisecond )
931+ return err
844932}
845933
846934func handleAgentRequest (msg * startMessage ) {
@@ -864,6 +952,7 @@ func handleAgentRequest(msg *startMessage) {
864952 }
865953
866954 go handleChannelAccept (listener , msg .Agent .ChannelType )
955+
867956 if msg .Envs == nil {
868957 msg .Envs = make (map [string ]string )
869958 }
0 commit comments