Skip to content

Commit a379cc2

Browse files
committed
support X11UseLocalhost and XAuthLocation for X11 forwarding
1 parent 13d0d4f commit a379cc2

File tree

1 file changed

+120
-31
lines changed

1 file changed

+120
-31
lines changed

tsshd/session.go

Lines changed: 120 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -782,65 +782,153 @@ func handleX11Request(msg *startMessage) {
782782
}
783783
}
784784

785-
listener, port, err := listenTcpOnFreePort("localhost", 6000+displayOffset, min(6000+displayOffset+1000, 65535))
785+
useLocalhost := strings.ToLower(getSshdConfig("X11UseLocalhost")) != "no"
786+
listeners, port, err := listenTcpOnFreePort(useLocalhost, 6000+displayOffset, min(6000+displayOffset+1000, 65535))
786787
if err != nil {
787788
warning("X11 forwarding listen failed: %v", err)
788789
return
789790
}
790791
onExitFuncs = append(onExitFuncs, func() {
791-
_ = listener.Close()
792+
for _, listener := range listeners {
793+
_ = listener.Close()
794+
}
792795
})
796+
797+
hostname := getHostnameForX11(useLocalhost)
793798
displayNumber := port - 6000
794-
if msg.X11.AuthProtocol != "" && msg.X11.AuthCookie != "" {
795-
authDisplay := fmt.Sprintf("unix:%d.%d", displayNumber, msg.X11.ScreenNumber)
796-
input := fmt.Sprintf("remove %s\nadd %s %s %s\n", authDisplay, authDisplay, msg.X11.AuthProtocol, msg.X11.AuthCookie)
797-
if err := writeXauthData(input); err == nil {
798-
onExitFuncs = append(onExitFuncs, func() {
799-
_ = writeXauthData(fmt.Sprintf("remove %s\n", authDisplay))
800-
})
801-
}
799+
display := fmt.Sprintf("%s:%d.%d", hostname, displayNumber, msg.X11.ScreenNumber)
800+
authDisplay := display
801+
if useLocalhost {
802+
authDisplay = fmt.Sprintf("unix:%d.%d", displayNumber, msg.X11.ScreenNumber)
802803
}
803-
go handleChannelAccept(listener, msg.X11.ChannelType)
804+
805+
xauthPath := getXauthPath()
806+
xauthInput := fmt.Sprintf("remove %s\nadd %s %s %s\n", authDisplay, authDisplay, msg.X11.AuthProtocol, msg.X11.AuthCookie)
807+
if err := writeXauthData(xauthPath, xauthInput); err != nil {
808+
warning("write xauth data failed: %v", err)
809+
}
810+
onExitFuncs = append(onExitFuncs, func() {
811+
_ = writeXauthData(xauthPath, fmt.Sprintf("remove %s\n", authDisplay))
812+
})
813+
814+
for _, listener := range listeners {
815+
go handleChannelAccept(listener, msg.X11.ChannelType)
816+
}
817+
804818
if msg.Envs == nil {
805819
msg.Envs = make(map[string]string)
806820
}
807-
msg.Envs["DISPLAY"] = fmt.Sprintf("localhost:%d.%d", displayNumber, msg.X11.ScreenNumber)
821+
msg.Envs["DISPLAY"] = display
808822
}
809823

810-
func listenTcpOnFreePort(host string, low, high int) (net.Listener, int, error) {
811-
var err error
812-
var listener net.Listener
824+
func getHostnameForX11(useLocalhost bool) string {
825+
if useLocalhost {
826+
return "localhost"
827+
}
828+
829+
hostname, err := os.Hostname()
830+
if err != nil {
831+
warning("get hostname for X11 forwarding failed: %v", err)
832+
return "localhost"
833+
}
834+
return hostname
835+
}
836+
837+
func listenTcpOnFreePort(useLocalhost bool, low, high int) ([]net.Listener, int, error) {
838+
var ipv4Host, ipv6Host string
839+
if useLocalhost {
840+
ipv4Host, ipv6Host = "127.0.0.1", "::1"
841+
} else {
842+
ipv4Host, ipv6Host = "0.0.0.0", "::"
843+
}
844+
845+
var netList, hostList []string
846+
listener4, err4 := net.Listen("tcp4", net.JoinHostPort(ipv4Host, "0"))
847+
if err4 == nil {
848+
_ = listener4.Close()
849+
netList = append(netList, "tcp4")
850+
hostList = append(hostList, ipv4Host)
851+
}
852+
listener6, err6 := net.Listen("tcp6", net.JoinHostPort(ipv6Host, "0"))
853+
if err6 == nil {
854+
_ = listener6.Close()
855+
netList = append(netList, "tcp6")
856+
hostList = append(hostList, ipv6Host)
857+
}
858+
859+
if err4 != nil && err6 != nil {
860+
return nil, 0, fmt.Errorf("ipv4 and ipv6 both listen failed: %v, %v", err4, err6)
861+
}
862+
863+
var lastErr error
813864
for port := low; port <= high; port++ {
814-
listener, err = net.Listen("tcp", fmt.Sprintf("%s:%d", host, port))
815-
if err == nil {
816-
return listener, port, nil
865+
var listenerList []net.Listener
866+
portStr := strconv.Itoa(port)
867+
for i := range len(netList) {
868+
listener, err := net.Listen(netList[i], net.JoinHostPort(hostList[i], portStr))
869+
if err != nil {
870+
lastErr = err
871+
continue
872+
}
873+
listenerList = append(listenerList, listener)
874+
}
875+
if len(listenerList) == len(netList) {
876+
return listenerList, port, nil
877+
}
878+
for _, listener := range listenerList {
879+
_ = listener.Close()
817880
}
818881
}
819-
if err != nil {
820-
return nil, 0, fmt.Errorf("listen tcp on %s:[%d,%d] failed: %v", host, low, high, err)
882+
if lastErr != nil {
883+
return nil, 0, fmt.Errorf("listen tcp on [%s,%s][%d,%d] failed: %v", ipv4Host, ipv6Host, low, high, lastErr)
884+
}
885+
return nil, 0, fmt.Errorf("listen tcp on [%s,%s][%d,%d] failed", ipv4Host, ipv6Host, low, high)
886+
}
887+
888+
func getXauthPath() string {
889+
xauthPath := getSshdConfig("XAuthLocation")
890+
if xauthPath != "" {
891+
if _, err := os.Stat(xauthPath); err != nil {
892+
warning("XAuthLocation [%s] not found: %v", xauthPath, err)
893+
return "xauth"
894+
}
895+
return xauthPath
821896
}
822-
return nil, 0, fmt.Errorf("listen tcp on %s:[%d,%d] failed", host, low, high)
897+
898+
return "xauth"
823899
}
824900

825-
func writeXauthData(input string) error {
826-
cmd := exec.Command("xauth", "-q", "-")
901+
func writeXauthData(xauthPath, xauthInput string) error {
902+
cmd := exec.Command(xauthPath, "-q", "-")
827903
stdin, err := cmd.StdinPipe()
828904
if err != nil {
829-
return err
905+
return fmt.Errorf("stdin pipe failed: %v", err)
830906
}
831907
defer func() { _ = stdin.Close() }()
908+
909+
var errBuf bytes.Buffer
910+
cmd.Stderr = &errBuf
911+
cmd.Stdout = io.Discard
912+
832913
if err := cmd.Start(); err != nil {
833-
return err
914+
return fmt.Errorf("xauth start failed: %v", err)
834915
}
835-
if _, err := stdin.Write([]byte(input)); err != nil {
836-
return err
916+
917+
if _, err := stdin.Write([]byte(xauthInput)); err != nil {
918+
return fmt.Errorf("stdin write failed: %v", err)
837919
}
838920
_ = stdin.Close()
839-
_, _ = doWithTimeout(func() (int, error) {
840-
_ = cmd.Wait()
921+
922+
_, err = doWithTimeout(func() (int, error) {
923+
if err := cmd.Wait(); err != nil {
924+
if errBuf.Len() > 0 {
925+
return 0, fmt.Errorf("%s", strings.TrimSpace(errBuf.String()))
926+
}
927+
return 0, fmt.Errorf("xauth wait failed: %v", err)
928+
}
841929
return 0, nil
842-
}, 200*time.Millisecond)
843-
return nil
930+
}, 1000*time.Millisecond)
931+
return err
844932
}
845933

846934
func handleAgentRequest(msg *startMessage) {
@@ -864,6 +952,7 @@ func handleAgentRequest(msg *startMessage) {
864952
}
865953

866954
go handleChannelAccept(listener, msg.Agent.ChannelType)
955+
867956
if msg.Envs == nil {
868957
msg.Envs = make(map[string]string)
869958
}

0 commit comments

Comments
 (0)