Skip to content

Commit 348ee5b

Browse files
committed
support unixgram for UDP local forwarding
1 parent 0201247 commit 348ee5b

File tree

1 file changed

+56
-18
lines changed

1 file changed

+56
-18
lines changed

tsshd/datagram.go

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"fmt"
3131
"io"
3232
"net"
33+
"os"
3334
"strings"
3435
"sync"
3536
"sync/atomic"
@@ -138,7 +139,6 @@ func (f *udpForwarder) startWorker() {
138139

139140
func (f *udpForwarder) sendDatagram(id uint64, buf []byte) bool {
140141
if len(buf) > int(f.conn.GetMaxDatagramSize()) {
141-
debug("datagram buffer size [%d] larger than [%d]", len(buf), f.conn.GetMaxDatagramSize())
142142
return false
143143
}
144144

@@ -314,7 +314,7 @@ func handleDialUdpEvent(stream Stream) {
314314
return
315315
}
316316

317-
if msg.Net == "unix" {
317+
if msg.Net == "unixgram" {
318318
if v := strings.ToLower(getSshdConfig("AllowStreamLocalForwarding")); v == "no" || v == "remote" {
319319
sendProhibited(stream, "AllowStreamLocalForwarding")
320320
return
@@ -326,21 +326,7 @@ func handleDialUdpEvent(stream Stream) {
326326
return
327327
}
328328

329-
var err error
330-
var addr *net.UDPAddr
331-
if msg.Timeout > 0 {
332-
addr, err = doWithTimeout(func() (*net.UDPAddr, error) {
333-
return net.ResolveUDPAddr(msg.Net, msg.Addr)
334-
}, msg.Timeout)
335-
} else {
336-
addr, err = net.ResolveUDPAddr(msg.Net, msg.Addr)
337-
}
338-
if err != nil {
339-
sendError(stream, err)
340-
return
341-
}
342-
343-
conn, err := net.DialUDP(msg.Net, nil, addr)
329+
conn, err := dialUDP(&msg)
344330
if err != nil {
345331
sendError(stream, err)
346332
return
@@ -364,7 +350,59 @@ func handleDialUdpEvent(stream Stream) {
364350
forwardUDP(pconn, conn)
365351
}
366352

367-
func forwardUDP(pconn *packetConn, conn *net.UDPConn) {
353+
type unixgramConn struct {
354+
io.ReadWriteCloser
355+
localAddr string
356+
}
357+
358+
func (c *unixgramConn) Close() error {
359+
err := c.ReadWriteCloser.Close()
360+
_ = os.Remove(c.localAddr)
361+
return err
362+
}
363+
364+
func dialUDP(msg *dialUdpMessage) (io.ReadWriteCloser, error) {
365+
if msg.Net == "unixgram" {
366+
tmpFile, err := os.CreateTemp("", "tsshd_unixgram_*.sock")
367+
if err != nil {
368+
return nil, fmt.Errorf("create temp file failed: %v", err)
369+
}
370+
localAddr := tmpFile.Name()
371+
if err := tmpFile.Close(); err != nil {
372+
return nil, fmt.Errorf("close temp file failed: %v", err)
373+
}
374+
if err := os.Remove(localAddr); err != nil {
375+
return nil, fmt.Errorf("remove temp file failed: %v", err)
376+
}
377+
laddr := &net.UnixAddr{Net: "unixgram", Name: localAddr}
378+
raddr := &net.UnixAddr{Net: "unixgram", Name: msg.Addr}
379+
conn, err := net.DialUnix("unixgram", laddr, raddr)
380+
if err != nil {
381+
if _, err := os.Stat(localAddr); err == nil {
382+
_ = os.Remove(localAddr)
383+
}
384+
return nil, err
385+
}
386+
return &unixgramConn{conn, localAddr}, nil
387+
}
388+
389+
var err error
390+
var addr *net.UDPAddr
391+
if msg.Timeout > 0 {
392+
addr, err = doWithTimeout(func() (*net.UDPAddr, error) {
393+
return net.ResolveUDPAddr(msg.Net, msg.Addr)
394+
}, msg.Timeout)
395+
} else {
396+
addr, err = net.ResolveUDPAddr(msg.Net, msg.Addr)
397+
}
398+
if err != nil {
399+
return nil, err
400+
}
401+
402+
return net.DialUDP(msg.Net, nil, addr)
403+
}
404+
405+
func forwardUDP(pconn *packetConn, conn io.ReadWriteCloser) {
368406
defer func() {
369407
_ = conn.Close()
370408
_ = pconn.Close()

0 commit comments

Comments
 (0)