@@ -30,6 +30,7 @@ import (
3030 "fmt"
3131 "io"
3232 "net"
33+ "os"
3334 "strings"
3435 "sync"
3536 "sync/atomic"
@@ -138,7 +139,6 @@ func (f *udpForwarder) startWorker() {
138139
139140func (f * udpForwarder ) sendDatagram (id uint64 , buf []byte ) bool {
140141 if len (buf ) > int (f .conn .GetMaxDatagramSize ()) {
141- debug ("datagram buffer size [%d] larger than [%d]" , len (buf ), f .conn .GetMaxDatagramSize ())
142142 return false
143143 }
144144
@@ -314,7 +314,7 @@ func handleDialUdpEvent(stream Stream) {
314314 return
315315 }
316316
317- if msg .Net == "unix " {
317+ if msg .Net == "unixgram " {
318318 if v := strings .ToLower (getSshdConfig ("AllowStreamLocalForwarding" )); v == "no" || v == "remote" {
319319 sendProhibited (stream , "AllowStreamLocalForwarding" )
320320 return
@@ -326,21 +326,7 @@ func handleDialUdpEvent(stream Stream) {
326326 return
327327 }
328328
329- var err error
330- var addr * net.UDPAddr
331- if msg .Timeout > 0 {
332- addr , err = doWithTimeout (func () (* net.UDPAddr , error ) {
333- return net .ResolveUDPAddr (msg .Net , msg .Addr )
334- }, msg .Timeout )
335- } else {
336- addr , err = net .ResolveUDPAddr (msg .Net , msg .Addr )
337- }
338- if err != nil {
339- sendError (stream , err )
340- return
341- }
342-
343- conn , err := net .DialUDP (msg .Net , nil , addr )
329+ conn , err := dialUDP (& msg )
344330 if err != nil {
345331 sendError (stream , err )
346332 return
@@ -364,7 +350,59 @@ func handleDialUdpEvent(stream Stream) {
364350 forwardUDP (pconn , conn )
365351}
366352
367- func forwardUDP (pconn * packetConn , conn * net.UDPConn ) {
353+ type unixgramConn struct {
354+ io.ReadWriteCloser
355+ localAddr string
356+ }
357+
358+ func (c * unixgramConn ) Close () error {
359+ err := c .ReadWriteCloser .Close ()
360+ _ = os .Remove (c .localAddr )
361+ return err
362+ }
363+
364+ func dialUDP (msg * dialUdpMessage ) (io.ReadWriteCloser , error ) {
365+ if msg .Net == "unixgram" {
366+ tmpFile , err := os .CreateTemp ("" , "tsshd_unixgram_*.sock" )
367+ if err != nil {
368+ return nil , fmt .Errorf ("create temp file failed: %v" , err )
369+ }
370+ localAddr := tmpFile .Name ()
371+ if err := tmpFile .Close (); err != nil {
372+ return nil , fmt .Errorf ("close temp file failed: %v" , err )
373+ }
374+ if err := os .Remove (localAddr ); err != nil {
375+ return nil , fmt .Errorf ("remove temp file failed: %v" , err )
376+ }
377+ laddr := & net.UnixAddr {Net : "unixgram" , Name : localAddr }
378+ raddr := & net.UnixAddr {Net : "unixgram" , Name : msg .Addr }
379+ conn , err := net .DialUnix ("unixgram" , laddr , raddr )
380+ if err != nil {
381+ if _ , err := os .Stat (localAddr ); err == nil {
382+ _ = os .Remove (localAddr )
383+ }
384+ return nil , err
385+ }
386+ return & unixgramConn {conn , localAddr }, nil
387+ }
388+
389+ var err error
390+ var addr * net.UDPAddr
391+ if msg .Timeout > 0 {
392+ addr , err = doWithTimeout (func () (* net.UDPAddr , error ) {
393+ return net .ResolveUDPAddr (msg .Net , msg .Addr )
394+ }, msg .Timeout )
395+ } else {
396+ addr , err = net .ResolveUDPAddr (msg .Net , msg .Addr )
397+ }
398+ if err != nil {
399+ return nil , err
400+ }
401+
402+ return net .DialUDP (msg .Net , nil , addr )
403+ }
404+
405+ func forwardUDP (pconn * packetConn , conn io.ReadWriteCloser ) {
368406 defer func () {
369407 _ = conn .Close ()
370408 _ = pconn .Close ()
0 commit comments