From e7cef934d4bf31780f3f6ee2ac3b88a6b5f94a01 Mon Sep 17 00:00:00 2001 From: xvzc Date: Mon, 22 Dec 2025 00:44:41 +0900 Subject: [PATCH 01/39] refactor(socks5): move rule matching to proxy, extract handlers --- internal/proxy/socks5/handler.go | 13 ++ internal/proxy/socks5/socks5_proxy.go | 219 ++++++++++---------------- internal/proxy/socks5/tcp_handler.go | 116 ++++++++++++++ internal/proxy/socks5/udp_handler.go | 191 ++++++++++++++++++++++ 4 files changed, 403 insertions(+), 136 deletions(-) create mode 100644 internal/proxy/socks5/handler.go create mode 100644 internal/proxy/socks5/tcp_handler.go create mode 100644 internal/proxy/socks5/udp_handler.go diff --git a/internal/proxy/socks5/handler.go b/internal/proxy/socks5/handler.go new file mode 100644 index 00000000..7bcf8f20 --- /dev/null +++ b/internal/proxy/socks5/handler.go @@ -0,0 +1,13 @@ +package socks5 + +import ( + "context" + "net" + + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/proto" +) + +type Handler interface { + Handle(ctx context.Context, conn net.Conn, req *proto.SOCKS5Request, rule *config.Rule, addrs []net.IPAddr) error +} \ No newline at end of file diff --git a/internal/proxy/socks5/socks5_proxy.go b/internal/proxy/socks5/socks5_proxy.go index f9097cae..d2011b70 100644 --- a/internal/proxy/socks5/socks5_proxy.go +++ b/internal/proxy/socks5/socks5_proxy.go @@ -3,7 +3,6 @@ package socks5 import ( "context" "encoding/json" - "errors" "fmt" "io" "net" @@ -25,11 +24,13 @@ import ( type SOCKS5Proxy struct { logger zerolog.Logger - resolver dns.Resolver - httpsHandler *http.HTTPSHandler // SOCKS5 primarily uses CONNECT, so we leverage HTTPSHandler - ruleMatcher matcher.RuleMatcher - serverOpts *config.ServerOptions - policyOpts *config.PolicyOptions + resolver dns.Resolver + ruleMatcher matcher.RuleMatcher + serverOpts *config.ServerOptions + policyOpts *config.PolicyOptions + + tcpHandler Handler + udpHandler Handler } func NewSOCKS5Proxy( @@ -41,12 +42,17 @@ func NewSOCKS5Proxy( policyOpts *config.PolicyOptions, ) proxy.ProxyServer { return &SOCKS5Proxy{ - logger: logger, - resolver: resolver, - httpsHandler: httpsHandler, - ruleMatcher: ruleMatcher, - serverOpts: serverOpts, - policyOpts: policyOpts, + logger: logger, + resolver: resolver, + ruleMatcher: ruleMatcher, + serverOpts: serverOpts, + policyOpts: policyOpts, + tcpHandler: NewTCPHandler( + logger, + httpsHandler, + serverOpts, + ), + udpHandler: NewUDPHandler(logger), } } @@ -101,26 +107,45 @@ func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { return } - // Only support CONNECT for now - if req.Cmd != proto.CmdConnect { - _ = proto.SOCKS5CommandNotSupportedResponse().Write(conn) - logger.Warn().Uint8("cmd", req.Cmd).Msg("unsupported socks5 command") - return - } - // Setup Logging Context remoteInfo := req.Domain if remoteInfo == "" { remoteInfo = req.IP.String() } ctx = session.WithHostInfo(ctx, remoteInfo) - logger = logger.With().Ctx(ctx).Logger() - logger.Debug(). - Str("from", conn.RemoteAddr().String()). - Msg("new socks5 request") + switch req.Cmd { + case proto.CmdConnect: + rule, addrs, err := p.resolveAndMatch(ctx, req) + if err != nil { + return // resolveAndMatch logs error and writes failure response if needed + } + if err := p.tcpHandler.Handle(ctx, conn, req, rule, addrs); err != nil { + return // Handler logs error + } + + // Auto Config Check (Duplicate logic moved here or kept in handler? + // User said: "Handler just takes the rule". + // Auto config logic updates the policy. It feels like "Proxy" responsibility or "Matcher" responsibility. + // If I keep it here, it's cleaner for handler. + p.handleAutoConfig(ctx, req, addrs, rule) + + case proto.CmdUDPAssociate: + // UDP Associate usually doesn't have destination info in the request + p.udpHandler.Handle(ctx, conn, req, nil, nil) + default: + _ = proto.SOCKS5CommandNotSupportedResponse().Write(conn) + logger.Warn().Uint8("cmd", req.Cmd).Msg("unsupported socks5 command") + } +} + +func (p *SOCKS5Proxy) resolveAndMatch( + ctx context.Context, + req *proto.SOCKS5Request, +) (*config.Rule, []net.IPAddr, error) { + logger := zerolog.Ctx(ctx) - // 3. Match Domain Rules (if domain provided) + // 1. Match Domain Rules (if domain provided) var nameMatch *config.Rule if req.Domain != "" { nameMatch = p.ruleMatcher.Search( @@ -135,8 +160,7 @@ func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { } } - // 4. DNS Resolution - // SOCKS5 allows IP or Domain. If Domain, we resolve. If IP, we use it directly. + // 2. DNS Resolution t1 := time.Now() var addrs []net.IPAddr @@ -144,13 +168,28 @@ func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { // Resolve Domain rSet, err := p.resolver.Resolve(ctx, req.Domain, nil, nameMatch) if err != nil { - _ = proto.SOCKS5FailureResponse().Write(conn) - logging.ErrorUnwrapped(&logger, "dns lookup failed", err) - return + // We can't write to conn here easily as it's not passed, + // but we can return error and let caller handle or just fail. + // Ideally we should write failure. + // Let's rely on caller or pass conn? + // The caller `handleConnection` has `conn`. + // I'll make this function just return error, and caller handles the UI part? + // But wait, `SOCKS5FailureResponse` needs to be written. + // I'll pass conn to this function? No, keep it pure logic if possible. + // But standard practice: Write failure on error. + // I will return error and let `handleConnection` assume failure was not written? + // Or just handle it here? + // I'll handle writing failure in `handleConnection` if this returns error? + // But I need to differentiate errors. + // Let's just return error, and the caller (handleConnection) will catch it. + // Wait, caller `handleConnection` has `conn`. + // But `resolveAndMatch` doesn't have `conn`. + // I'll just log here. The caller should write failure. + logging.ErrorUnwrapped(logger, "dns lookup failed", err) + return nil, nil, err } addrs = rSet.Addrs } else { - // IP Request - Just wrap the IP addrs = []net.IPAddr{{IP: req.IP}} } @@ -160,17 +199,7 @@ func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { Str("took", fmt.Sprintf("%dms", dt)). Msg("dns lookup ok") - // Avoid recursively querying self. - ok, err := validateDestination(addrs, req.Port, p.serverOpts.ListenAddr) - if err != nil { - logger.Debug().Err(err).Msg("error determining if valid destination") - if !ok { - _ = proto.SOCKS5FailureResponse().Write(conn) - return - } - } - - // 6. Match IP Rules + // 3. Match IP Rules var selectors []*matcher.Selector for _, v := range addrs { selectors = append(selectors, &matcher.Selector{ @@ -187,62 +216,21 @@ func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { } bestMatch := matcher.GetHigherPriorityRule(addrMatch, nameMatch) - if bestMatch != nil && *bestMatch.Block { - logger.Debug().Msg("request is blocked by policy") - _ = proto.SOCKS5FailureResponse().Write(conn) - // Or specific code for blocked - return - } - - // 7. Handover to Handler - // Important: We must send a success reply to the client BEFORE handing over to the handler, - // because the handler (SpoofDPI) typically expects a raw stream where it can start TLS handshake immediately. - // However, standard SOCKS5 expects the proxy to connect to the target FIRST, then send success. - // Since SpoofDPI handler does the connection, we might need to send success here optimistically. - - // [Optimistic Success] - // We tell the client "OK, we are connected" so it starts sending data (e.g. ClientHello). - // The real connection happens inside p.tcpHandler.HandleRequest. - // Note: BIND addr/port is usually 0.0.0.0:0 if we don't care. - if err := proto.SOCKS5SuccessResponse().Bind(net.IPv4zero).Port(0).Write(conn); err != nil { - logger.Error().Err(err).Msg("failed to write socks5 success reply") - return - } - - dst := &http.Destination{ - Domain: req.Domain, - Addrs: addrs, - Port: req.Port, - Timeout: *p.serverOpts.Timeout, - } - - // Note: 'req' is nil here because it's not an HTTP request yet. - // The handler must be able to handle nil req or we wrap a dummy one. - // Assuming Handler is adapted to deal with raw streams or nil reqs for SOCKS mode. - handleErr := p.httpsHandler.HandleRequest(ctx, conn, dst, bestMatch) - - if handleErr == nil { - return - } - - logger.Warn().Err(handleErr).Msg("error handling request") - if !errors.Is(handleErr, netutil.ErrBlocked) { - return - } + return bestMatch, addrs, nil +} - // 8. Auto Config (Duplicate logic from HTTPProxy) - if nameMatch != nil { - logger.Info(). - Interface("match", nameMatch.Match.Domains). - Str("name", *nameMatch.Name). - Msg("skipping auto-config (duplicate policy)") - return - } +func (p *SOCKS5Proxy) handleAutoConfig( + ctx context.Context, + req *proto.SOCKS5Request, + addrs []net.IPAddr, + matchedRule *config.Rule, +) { + logger := zerolog.Ctx(ctx) - if addrMatch != nil { + if matchedRule != nil { logger.Info(). - Interface("match", addrMatch.Match.Addrs). - Str("name", *addrMatch.Name). + Interface("match", matchedRule.Match). + Str("name", *matchedRule.Name). Msg("skipping auto-config (duplicate policy)") return } @@ -251,8 +239,6 @@ func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { newRule := p.policyOpts.Template.Clone() targetDomain := req.Domain if targetDomain == "" && len(addrs) > 0 { - // If request was by IP, we can't really add a domain rule, - // maybe add IP rule or skip. Use domain if available. targetDomain = addrs[0].IP.String() } @@ -266,45 +252,6 @@ func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { } } -// validateDestination checks if we are recursively querying ourselves. -// This function needs to be duplicated or moved to a shared utility if common logic is identical. -// For now, I'll use a local helper or assume the logic is specific enough. -// Since `isRecursiveDst` was a method on HTTPProxy, I'll assume similar logic is needed here. -// I'll implement a local version for now to keep it self-contained in this package or duplicated from http/proxy.go. -// Or I can just omit it if I don't want to duplicate, but safety is important. -// I'll add a simple check against listening port/loopback. -func validateDestination( - dstAddrs []net.IPAddr, - dstPort int, - listenAddr *net.TCPAddr, -) (bool, error) { - if dstPort != int(listenAddr.Port) { - return true, nil - } - - for _, dstAddr := range dstAddrs { - ip := dstAddr.IP - if ip.IsLoopback() { - return false, nil - } - - ifAddrs, err := net.InterfaceAddrs() - if err != nil { - return false, err - } - - for _, addr := range ifAddrs { - if ipnet, ok := addr.(*net.IPNet); ok { - if ipnet.IP.Equal(ip) { - return false, nil - } - } - } - } - return true, nil -} - -// negotiate performs SOCKS5 auth negotiation (NoAuth only). func (p *SOCKS5Proxy) negotiate(conn net.Conn) error { header := make([]byte, 2) if _, err := io.ReadFull(conn, header); err != nil { @@ -324,4 +271,4 @@ func (p *SOCKS5Proxy) negotiate(conn net.Conn) error { // Respond: Version 5, Method NoAuth(0) _, err := conn.Write([]byte{proto.SOCKSVersion, proto.AuthNone}) return err -} +} \ No newline at end of file diff --git a/internal/proxy/socks5/tcp_handler.go b/internal/proxy/socks5/tcp_handler.go new file mode 100644 index 00000000..5c4d5fc1 --- /dev/null +++ b/internal/proxy/socks5/tcp_handler.go @@ -0,0 +1,116 @@ +package socks5 + +import ( + "context" + "errors" + "net" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/proto" + "github.com/xvzc/SpoofDPI/internal/proxy/http" +) + +type TCPHandler struct { + logger zerolog.Logger + httpsHandler *http.HTTPSHandler + serverOpts *config.ServerOptions +} + +func NewTCPHandler( + logger zerolog.Logger, + httpsHandler *http.HTTPSHandler, + serverOpts *config.ServerOptions, +) *TCPHandler { + return &TCPHandler{ + logger: logger, + httpsHandler: httpsHandler, + serverOpts: serverOpts, + } +} + +func (h *TCPHandler) Handle( + ctx context.Context, + conn net.Conn, + req *proto.SOCKS5Request, + rule *config.Rule, + addrs []net.IPAddr, +) error { + logger := h.logger.With().Ctx(ctx).Logger() + + // 1. Validate Destination (Avoid Recursive Loop) + ok, err := validateDestination(addrs, req.Port, h.serverOpts.ListenAddr) + if err != nil { + logger.Debug().Err(err).Msg("error determining if valid destination") + if !ok { + _ = proto.SOCKS5FailureResponse().Write(conn) + return err + } + } + + // 2. Check if blocked + if rule != nil && *rule.Block { + logger.Debug().Msg("request is blocked by policy") + _ = proto.SOCKS5FailureResponse().Write(conn) + return netutil.ErrBlocked + } + + // 3. Send Success Response Optimistically + if err := proto.SOCKS5SuccessResponse().Bind(net.IPv4zero).Port(0).Write(conn); err != nil { + logger.Error().Err(err).Msg("failed to write socks5 success reply") + return err + } + + dst := &http.Destination{ + Domain: req.Domain, + Addrs: addrs, + Port: req.Port, + Timeout: *h.serverOpts.Timeout, + } + + // 4. Handover to HTTPSHandler + handleErr := h.httpsHandler.HandleRequest(ctx, conn, dst, rule) + if handleErr == nil { + return nil + } + + logger.Warn().Err(handleErr).Msg("error handling request") + if !errors.Is(handleErr, netutil.ErrBlocked) { + return handleErr + } + + return nil +} + +// validateDestination checks if we are recursively querying ourselves. +func validateDestination( + dstAddrs []net.IPAddr, + dstPort int, + listenAddr *net.TCPAddr, +) (bool, error) { + if dstPort != int(listenAddr.Port) { + return true, nil + } + + for _, dstAddr := range dstAddrs { + ip := dstAddr.IP + if ip.IsLoopback() { + return false, nil + } + + ifAddrs, err := net.InterfaceAddrs() + if err != nil { + return false, err + } + + for _, addr := range ifAddrs { + if ipnet, ok := addr.(*net.IPNet); ok { + if ipnet.IP.Equal(ip) { + return false, nil + } + } + } + } + return true, nil +} \ No newline at end of file diff --git a/internal/proxy/socks5/udp_handler.go b/internal/proxy/socks5/udp_handler.go new file mode 100644 index 00000000..dc4faf72 --- /dev/null +++ b/internal/proxy/socks5/udp_handler.go @@ -0,0 +1,191 @@ +package socks5 + +import ( + "context" + "encoding/binary" + "fmt" + "io" + "net" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/proto" +) + +type UDPHandler struct { + logger zerolog.Logger +} + +func NewUDPHandler(logger zerolog.Logger) *UDPHandler { + return &UDPHandler{ + logger: logger, + } +} + +func (h *UDPHandler) Handle( + ctx context.Context, + conn net.Conn, + req *proto.SOCKS5Request, + rule *config.Rule, + addrs []net.IPAddr, +) error { + logger := h.logger.With().Ctx(ctx).Logger() + + // 1. Listen on a random UDP port + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + logger.Error().Err(err).Msg("failed to create udp listener") + _ = proto.SOCKS5FailureResponse().Write(conn) + return err + } + defer udpConn.Close() + + lAddr := udpConn.LocalAddr().(*net.UDPAddr) + + logger.Debug(). + Str("bind_addr", lAddr.String()). + Msg("socks5 udp associate established") + + // 2. Reply with the bound address + if err := proto.SOCKS5SuccessResponse().Bind(lAddr.IP).Port(lAddr.Port).Write(conn); err != nil { + logger.Error().Err(err).Msg("failed to write socks5 success reply") + return err + } + + // 3. Keep TCP Alive & Relay + // We need to monitor TCP for closure. + done := make(chan struct{}) + go func() { + _, _ = io.Copy(io.Discard, conn) // Block until TCP closes + close(done) + }() + + go func() { + <-done + udpConn.Close() + }() + + buf := make([]byte, 65535) + var clientAddr *net.UDPAddr + + for { + n, addr, err := udpConn.ReadFromUDP(buf) + if err != nil { + // Normal closure check + select { + case <-done: + return nil + default: + logger.Debug().Err(err).Msg("error reading from udp") + return err + } + } + + // Initial Client Identification + if clientAddr == nil { + clientAddr = addr + } + + if addr.IP.Equal(clientAddr.IP) && addr.Port == clientAddr.Port { + // Outbound: Client -> Proxy -> Target + targetAddr, payload, err := parseUDPHeader(buf[:n]) + if err != nil { + logger.Warn().Err(err).Msg("failed to parse socks5 udp header") + continue + } + + // We use the same UDP socket to send to target. + // The Target will reply to this socket. + resolvedAddr, err := net.ResolveUDPAddr("udp", targetAddr) + if err != nil { + logger.Warn().Err(err).Str("addr", targetAddr).Msg("failed to resolve udp target") + continue + } + + if _, err := udpConn.WriteTo(payload, resolvedAddr); err != nil { + logger.Warn().Err(err).Msg("failed to write udp to target") + } + } else { + // Inbound: Target -> Proxy -> Client + // Wrap with SOCKS5 Header + header := createUDPHeaderFromAddr(addr) + response := append(header, buf[:n]...) + + if _, err := udpConn.WriteToUDP(response, clientAddr); err != nil { + logger.Warn().Err(err).Msg("failed to write udp to client") + } + } + } +} + +func parseUDPHeader(b []byte) (string, []byte, error) { + if len(b) < 4 { + return "", nil, fmt.Errorf("header too short") + } + // RSV(2) FRAG(1) ATYP(1) + if b[0] != 0 || b[1] != 0 { + return "", nil, fmt.Errorf("invalid rsv") + } + frag := b[2] + if frag != 0 { + return "", nil, fmt.Errorf("fragmentation not supported") + } + + atyp := b[3] + var host string + var pos int + + switch atyp { + case proto.ATYPIPv4: + if len(b) < 10 { + return "", nil, fmt.Errorf("header too short for ipv4") + } + host = net.IP(b[4:8]).String() + pos = 8 + case proto.ATYPIPv6: + if len(b) < 22 { + return "", nil, fmt.Errorf("header too short for ipv6") + } + host = net.IP(b[4:20]).String() + pos = 20 + case proto.ATYPFQDN: + if len(b) < 5 { + return "", nil, fmt.Errorf("header too short for fqdn") + } + l := int(b[4]) + if len(b) < 5+l+2 { + return "", nil, fmt.Errorf("header too short for fqdn data") + } + host = string(b[5 : 5+l]) + pos = 5 + l + default: + return "", nil, fmt.Errorf("unsupported atyp: %d", atyp) + } + + port := binary.BigEndian.Uint16(b[pos : pos+2]) + pos += 2 + + addr := net.JoinHostPort(host, fmt.Sprintf("%d", port)) + return addr, b[pos:], nil +} + +func createUDPHeaderFromAddr(addr *net.UDPAddr) []byte { + // RSV(2) FRAG(1) ATYP(1) ... + buf := make([]byte, 0, 24) + buf = append(buf, 0, 0, 0) // RSV, FRAG + + ip4 := addr.IP.To4() + if ip4 != nil { + buf = append(buf, proto.ATYPIPv4) + buf = append(buf, ip4...) + } else { + buf = append(buf, proto.ATYPIPv6) + buf = append(buf, addr.IP.To16()...) + } + + portBuf := make([]byte, 2) + binary.BigEndian.PutUint16(portBuf, uint16(addr.Port)) + buf = append(buf, portBuf...) + + return buf +} \ No newline at end of file From f211a1085c318365751a8b07dc515868f40383c0 Mon Sep 17 00:00:00 2001 From: xvzc Date: Mon, 22 Dec 2025 01:15:10 +0900 Subject: [PATCH 02/39] refactor: extract Destination to netutil, update socks5 to use it, restore http pipe --- internal/netutil/addr.go | 10 ++++++- internal/proxy/http/http_handler.go | 11 ++++--- internal/proxy/http/http_proxy.go | 9 +----- internal/proxy/http/https_handler.go | 2 +- internal/proxy/socks5/handler.go | 13 -------- internal/proxy/socks5/socks5_proxy.go | 33 ++++++-------------- internal/proxy/socks5/tcp_handler.go | 43 ++------------------------- internal/proxy/socks5/udp_handler.go | 3 +- 8 files changed, 31 insertions(+), 93 deletions(-) delete mode 100644 internal/proxy/socks5/handler.go diff --git a/internal/netutil/addr.go b/internal/netutil/addr.go index 2f856eca..f7fae42c 100644 --- a/internal/netutil/addr.go +++ b/internal/netutil/addr.go @@ -3,8 +3,16 @@ package netutil import ( "fmt" "net" + "time" ) +type Destination struct { + Domain string + Addrs []net.IPAddr + Port int + Timeout time.Duration +} + func ValidateDestination( dstAddrs []net.IPAddr, dstPort int, @@ -34,4 +42,4 @@ func ValidateDestination( } return true, err -} +} \ No newline at end of file diff --git a/internal/proxy/http/http_handler.go b/internal/proxy/http/http_handler.go index 3cc415f6..d847610d 100644 --- a/internal/proxy/http/http_handler.go +++ b/internal/proxy/http/http_handler.go @@ -16,9 +16,7 @@ type HTTPHandler struct { logger zerolog.Logger } -func NewHTTPHandler( - logger zerolog.Logger, -) *HTTPHandler { +func NewHTTPHandler(logger zerolog.Logger) *HTTPHandler { return &HTTPHandler{ logger: logger, } @@ -28,13 +26,14 @@ func (h *HTTPHandler) HandleRequest( ctx context.Context, lConn net.Conn, // Use the net.Conn interface, not a concrete *net.TCPConn. req *proto.HTTPRequest, // Assumes HttpRequest is a custom type for request parsing. - dst *Destination, + dst *netutil.Destination, rule *config.Rule, ) error { logger := logging.WithLocalScope(ctx, h.logger, "http") rConn, err := netutil.DialFastest(ctx, "tcp", dst.Addrs, dst.Port, dst.Timeout) if err != nil { + _ = proto.HTTPBadGatewayResponse().Write(lConn) return err } @@ -64,6 +63,10 @@ func (h *HTTPHandler) HandleRequest( continue } + if netutil.IsConnectionResetByPeer(e) { + return netutil.ErrBlocked + } + return fmt.Errorf( "unsuccessful tunnel %s -> %s: %w", lConn.RemoteAddr(), diff --git a/internal/proxy/http/http_proxy.go b/internal/proxy/http/http_proxy.go index 16c8a7ff..5bb1b88b 100644 --- a/internal/proxy/http/http_proxy.go +++ b/internal/proxy/http/http_proxy.go @@ -21,13 +21,6 @@ import ( "github.com/xvzc/SpoofDPI/internal/session" ) -type Destination struct { - Domain string - Addrs []net.IPAddr - Port int - Timeout time.Duration -} - type HTTPProxy struct { logger zerolog.Logger @@ -186,7 +179,7 @@ func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { return } - dst := &Destination{ + dst := &netutil.Destination{ Domain: domain, Addrs: rSet.Addrs, Port: dstPort, diff --git a/internal/proxy/http/https_handler.go b/internal/proxy/http/https_handler.go index 2a9f8a1e..0c9d6f45 100644 --- a/internal/proxy/http/https_handler.go +++ b/internal/proxy/http/https_handler.go @@ -44,7 +44,7 @@ func NewHTTPSHandler( func (h *HTTPSHandler) HandleRequest( ctx context.Context, lConn net.Conn, - dst *Destination, + dst *netutil.Destination, rule *config.Rule, ) error { httpsOpts := h.httpsOpts diff --git a/internal/proxy/socks5/handler.go b/internal/proxy/socks5/handler.go deleted file mode 100644 index 7bcf8f20..00000000 --- a/internal/proxy/socks5/handler.go +++ /dev/null @@ -1,13 +0,0 @@ -package socks5 - -import ( - "context" - "net" - - "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/proto" -) - -type Handler interface { - Handle(ctx context.Context, conn net.Conn, req *proto.SOCKS5Request, rule *config.Rule, addrs []net.IPAddr) error -} \ No newline at end of file diff --git a/internal/proxy/socks5/socks5_proxy.go b/internal/proxy/socks5/socks5_proxy.go index d2011b70..44499f63 100644 --- a/internal/proxy/socks5/socks5_proxy.go +++ b/internal/proxy/socks5/socks5_proxy.go @@ -29,8 +29,8 @@ type SOCKS5Proxy struct { serverOpts *config.ServerOptions policyOpts *config.PolicyOptions - tcpHandler Handler - udpHandler Handler + tcpHandler *TCPHandler + udpHandler *UDPHandler } func NewSOCKS5Proxy( @@ -120,14 +120,16 @@ func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { if err != nil { return // resolveAndMatch logs error and writes failure response if needed } - if err := p.tcpHandler.Handle(ctx, conn, req, rule, addrs); err != nil { + dst := &netutil.Destination{ + Domain: req.Domain, + Addrs: addrs, + Port: req.Port, + Timeout: *p.serverOpts.Timeout, + } + if err := p.tcpHandler.Handle(ctx, conn, req, dst, rule); err != nil { return // Handler logs error } - // Auto Config Check (Duplicate logic moved here or kept in handler? - // User said: "Handler just takes the rule". - // Auto config logic updates the policy. It feels like "Proxy" responsibility or "Matcher" responsibility. - // If I keep it here, it's cleaner for handler. p.handleAutoConfig(ctx, req, addrs, rule) case proto.CmdUDPAssociate: @@ -168,23 +170,6 @@ func (p *SOCKS5Proxy) resolveAndMatch( // Resolve Domain rSet, err := p.resolver.Resolve(ctx, req.Domain, nil, nameMatch) if err != nil { - // We can't write to conn here easily as it's not passed, - // but we can return error and let caller handle or just fail. - // Ideally we should write failure. - // Let's rely on caller or pass conn? - // The caller `handleConnection` has `conn`. - // I'll make this function just return error, and caller handles the UI part? - // But wait, `SOCKS5FailureResponse` needs to be written. - // I'll pass conn to this function? No, keep it pure logic if possible. - // But standard practice: Write failure on error. - // I will return error and let `handleConnection` assume failure was not written? - // Or just handle it here? - // I'll handle writing failure in `handleConnection` if this returns error? - // But I need to differentiate errors. - // Let's just return error, and the caller (handleConnection) will catch it. - // Wait, caller `handleConnection` has `conn`. - // But `resolveAndMatch` doesn't have `conn`. - // I'll just log here. The caller should write failure. logging.ErrorUnwrapped(logger, "dns lookup failed", err) return nil, nil, err } diff --git a/internal/proxy/socks5/tcp_handler.go b/internal/proxy/socks5/tcp_handler.go index 5c4d5fc1..4a1878e2 100644 --- a/internal/proxy/socks5/tcp_handler.go +++ b/internal/proxy/socks5/tcp_handler.go @@ -34,13 +34,13 @@ func (h *TCPHandler) Handle( ctx context.Context, conn net.Conn, req *proto.SOCKS5Request, + dst *netutil.Destination, rule *config.Rule, - addrs []net.IPAddr, ) error { logger := h.logger.With().Ctx(ctx).Logger() // 1. Validate Destination (Avoid Recursive Loop) - ok, err := validateDestination(addrs, req.Port, h.serverOpts.ListenAddr) + ok, err := netutil.ValidateDestination(dst.Addrs, dst.Port, h.serverOpts.ListenAddr) if err != nil { logger.Debug().Err(err).Msg("error determining if valid destination") if !ok { @@ -62,13 +62,6 @@ func (h *TCPHandler) Handle( return err } - dst := &http.Destination{ - Domain: req.Domain, - Addrs: addrs, - Port: req.Port, - Timeout: *h.serverOpts.Timeout, - } - // 4. Handover to HTTPSHandler handleErr := h.httpsHandler.HandleRequest(ctx, conn, dst, rule) if handleErr == nil { @@ -81,36 +74,4 @@ func (h *TCPHandler) Handle( } return nil -} - -// validateDestination checks if we are recursively querying ourselves. -func validateDestination( - dstAddrs []net.IPAddr, - dstPort int, - listenAddr *net.TCPAddr, -) (bool, error) { - if dstPort != int(listenAddr.Port) { - return true, nil - } - - for _, dstAddr := range dstAddrs { - ip := dstAddr.IP - if ip.IsLoopback() { - return false, nil - } - - ifAddrs, err := net.InterfaceAddrs() - if err != nil { - return false, err - } - - for _, addr := range ifAddrs { - if ipnet, ok := addr.(*net.IPNet); ok { - if ipnet.IP.Equal(ip) { - return false, nil - } - } - } - } - return true, nil } \ No newline at end of file diff --git a/internal/proxy/socks5/udp_handler.go b/internal/proxy/socks5/udp_handler.go index dc4faf72..6c6a7db6 100644 --- a/internal/proxy/socks5/udp_handler.go +++ b/internal/proxy/socks5/udp_handler.go @@ -9,6 +9,7 @@ import ( "github.com/rs/zerolog" "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/proto" ) @@ -26,8 +27,8 @@ func (h *UDPHandler) Handle( ctx context.Context, conn net.Conn, req *proto.SOCKS5Request, + dst *netutil.Destination, rule *config.Rule, - addrs []net.IPAddr, ) error { logger := h.logger.With().Ctx(ctx).Logger() From d5363c9225b8381c906896c4cde558d4e4c183c2 Mon Sep 17 00:00:00 2001 From: xvzc Date: Mon, 22 Dec 2025 01:30:46 +0900 Subject: [PATCH 03/39] refactor: extract TLS tunneling logic to handler.Bridge --- cmd/spoofdpi/main.go | 10 +- internal/proxy/handler/bridge.go | 130 ++++++++++++++++++++++++++ internal/proxy/http/https_handler.go | 127 +++---------------------- internal/proxy/socks5/socks5_proxy.go | 10 +- internal/proxy/socks5/tcp_handler.go | 24 ++--- 5 files changed, 167 insertions(+), 134 deletions(-) create mode 100644 internal/proxy/handler/bridge.go diff --git a/cmd/spoofdpi/main.go b/cmd/spoofdpi/main.go index e26874a7..6929be91 100644 --- a/cmd/spoofdpi/main.go +++ b/cmd/spoofdpi/main.go @@ -18,6 +18,7 @@ import ( "github.com/xvzc/SpoofDPI/internal/matcher" "github.com/xvzc/SpoofDPI/internal/packet" "github.com/xvzc/SpoofDPI/internal/proxy" + "github.com/xvzc/SpoofDPI/internal/proxy/handler" "github.com/xvzc/SpoofDPI/internal/proxy/http" "github.com/xvzc/SpoofDPI/internal/session" "github.com/xvzc/SpoofDPI/internal/system" @@ -258,8 +259,8 @@ func createProxy( } } - httpsHandler := http.NewHTTPSHandler( - logging.WithScope(logger, "hnd"), + bridge := handler.NewBridge( + logging.WithScope(logger, "brg"), desync.NewTLSDesyncer( writer, sniffer, @@ -269,6 +270,11 @@ func createProxy( cfg.HTTPS.Clone(), ) + httpsHandler := http.NewHTTPSHandler( + logging.WithScope(logger, "hnd"), + bridge, + ) + // if cfg.Server.EnableSocks5 != nil && *cfg.Server.EnableSocks5 { // return socks5.NewSocks5Proxy( // logging.WithScope(logger, "pxy"), diff --git a/internal/proxy/handler/bridge.go b/internal/proxy/handler/bridge.go new file mode 100644 index 00000000..c82443c3 --- /dev/null +++ b/internal/proxy/handler/bridge.go @@ -0,0 +1,130 @@ +package handler + +import ( + "context" + "fmt" + "net" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/desync" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/packet" + "github.com/xvzc/SpoofDPI/internal/proto" + "github.com/xvzc/SpoofDPI/internal/ptr" +) + +type Bridge struct { + logger zerolog.Logger + desyncer *desync.TLSDesyncer + sniffer packet.Sniffer + httpsOpts *config.HTTPSOptions +} + +func NewBridge( + logger zerolog.Logger, + desyncer *desync.TLSDesyncer, + sniffer packet.Sniffer, + httpsOpts *config.HTTPSOptions, +) *Bridge { + return &Bridge{ + logger: logger, + desyncer: desyncer, + sniffer: sniffer, + httpsOpts: httpsOpts, + } +} + +// Tunnel creates a bi-directional tunnel between lConn and dst. +// It detects the first packet from lConn. If it's a ClientHello, it applies the desync strategy. +func (b *Bridge) Tunnel( + ctx context.Context, + lConn net.Conn, + dst *netutil.Destination, + rule *config.Rule, +) error { + httpsOpts := b.httpsOpts + if rule != nil { + httpsOpts = httpsOpts.Merge(rule.HTTPS) + } + + if b.sniffer != nil && ptr.FromPtr(httpsOpts.FakeCount) > 0 { + b.sniffer.RegisterUntracked(dst.Addrs, dst.Port) + } + + logger := logging.WithLocalScope(ctx, b.logger, "https") + + rConn, err := netutil.DialFastest(ctx, "tcp", dst.Addrs, dst.Port, dst.Timeout) + if err != nil { + return err + } + defer netutil.CloseConns(rConn) + + logger.Debug().Msgf("new remote conn -> %s", rConn.RemoteAddr()) + + // Read the first message from the client (expected to be ClientHello) + tlsMsg, err := proto.ReadTLSMessage(lConn) + if err != nil { + logger.Trace().Err(err).Msgf("failed to read first message from client") + return nil // Client might have closed connection or sent garbage + } + + logger.Debug(). + Int("len", tlsMsg.Len()). + Msgf("client hello received <- %s", lConn.RemoteAddr()) + + if !tlsMsg.IsClientHello() { + logger.Trace().Int("len", tlsMsg.Len()).Msg("not a client hello. aborting") + return nil + } + + // Send ClientHello to the remote server (with desync if configured) + n, err := b.sendClientHello(ctx, rConn, tlsMsg, httpsOpts) + if err != nil { + return fmt.Errorf("failed to send client hello: %w", err) + } + + logger.Debug(). + Int("len", n). + Msgf("sent client hello -> %s", rConn.RemoteAddr()) + + // Start bi-directional tunneling + errCh := make(chan error, 2) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go netutil.TunnelConns(ctx, logger, errCh, rConn, lConn) + go netutil.TunnelConns(ctx, logger, errCh, lConn, rConn) + + for range 2 { + e := <-errCh + if e == nil { + continue + } + + if netutil.IsConnectionResetByPeer(e) { + return netutil.ErrBlocked + } + + return fmt.Errorf( + "unsuccessful tunnel %s -> %s: %w", + lConn.RemoteAddr(), + rConn.RemoteAddr(), + e, + ) + } + + return nil +} + +func (b *Bridge) sendClientHello( + ctx context.Context, + conn net.Conn, + msg *proto.TLSMessage, + httpsOpts *config.HTTPSOptions, +) (int, error) { + logger := logging.WithLocalScope(ctx, b.logger, "client_hello") + return b.desyncer.Send(ctx, logger, conn, msg, httpsOpts) +} diff --git a/internal/proxy/http/https_handler.go b/internal/proxy/http/https_handler.go index 0c9d6f45..7d5873c4 100644 --- a/internal/proxy/http/https_handler.go +++ b/internal/proxy/http/https_handler.go @@ -9,148 +9,45 @@ import ( "github.com/rs/zerolog" "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/desync" "github.com/xvzc/SpoofDPI/internal/logging" "github.com/xvzc/SpoofDPI/internal/netutil" - "github.com/xvzc/SpoofDPI/internal/packet" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" + "github.com/xvzc/SpoofDPI/internal/proxy/handler" ) type HTTPSHandler struct { - logger zerolog.Logger - desyncer *desync.TLSDesyncer - sniffer packet.Sniffer - httpsOpts *config.HTTPSOptions + logger zerolog.Logger + bridge *handler.Bridge } func NewHTTPSHandler( logger zerolog.Logger, - desyncer *desync.TLSDesyncer, - sniffer packet.Sniffer, - httpsOpts *config.HTTPSOptions, + bridge *handler.Bridge, ) *HTTPSHandler { return &HTTPSHandler{ - logger: logger, - desyncer: desyncer, - sniffer: sniffer, - httpsOpts: httpsOpts, + logger: logger, + bridge: bridge, } } -// func (h *HTTPSHandler) DefaultRule() *policy.Rule { -// return h.defaultAttrs.Clone() -// } func (h *HTTPSHandler) HandleRequest( ctx context.Context, lConn net.Conn, dst *netutil.Destination, rule *config.Rule, ) error { - httpsOpts := h.httpsOpts - if rule != nil { - httpsOpts = httpsOpts.Merge(rule.HTTPS) - } - - if h.sniffer != nil && ptr.FromPtr(httpsOpts.FakeCount) > 0 { - h.sniffer.RegisterUntracked(dst.Addrs, dst.Port) - } - - logger := logging.WithLocalScope(ctx, h.logger, "https") - - rConn, err := netutil.DialFastest(ctx, "tcp", dst.Addrs, dst.Port, dst.Timeout) - if err != nil { - return err - } - defer netutil.CloseConns(rConn) - - logger.Debug().Msgf("new remote conn -> %s", rConn.RemoteAddr()) + logger := logging.WithLocalScope(ctx, h.logger, "handshake") - tlsMsg, err := h.handleProxyHandshake(ctx, lConn) - if err != nil { + // 1. Send 200 Connection Established + if err := proto.HTTPConnectionEstablishedResponse().Write(lConn); err != nil { if !netutil.IsConnectionResetByPeer(err) && !errors.Is(err, io.EOF) { logger.Trace().Err(err).Msgf("proxy handshake error") return fmt.Errorf("failed to handle proxy handshake: %w", err) } - return nil } - - if !tlsMsg.IsClientHello() { - logger.Trace().Int("len", tlsMsg.Len()).Msg("not a client hello. aborting") - return nil - } - - n, err := h.sendClientHello(ctx, rConn, tlsMsg, httpsOpts) - if err != nil { - return fmt.Errorf("failed to send client hello: %w", err) - } - - logger.Debug(). - Int("len", n). - Msgf("sent client hello -> %s", rConn.RemoteAddr()) - - errCh := make(chan error, 2) - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - go netutil.TunnelConns(ctx, logger, errCh, rConn, lConn) - go netutil.TunnelConns(ctx, logger, errCh, lConn, rConn) - - for range 2 { - e := <-errCh - if e == nil { - continue - } - - if netutil.IsConnectionResetByPeer(e) { - return netutil.ErrBlocked - } - - return fmt.Errorf( - "unsuccessful tunnel %s -> %s: %w", - lConn.RemoteAddr(), - rConn.RemoteAddr(), - e, - ) - } - - return nil -} - -// handleProxyHandshake sends "200 Connection Established" -// and reads the subsequent Client Hello. -func (h *HTTPSHandler) handleProxyHandshake( - ctx context.Context, - lConn net.Conn, -) (*proto.TLSMessage, error) { - logger := logging.WithLocalScope(ctx, h.logger, "handshake") - - if err := proto.HTTPConnectionEstablishedResponse().Write(lConn); err != nil { - return nil, err - } logger.Trace().Msgf("sent 200 connection established -> %s", lConn.RemoteAddr()) - tlsMsg, err := proto.ReadTLSMessage(lConn) - if err != nil { - return nil, err - } - - logger.Debug(). - Int("len", tlsMsg.Len()). - Msgf("client hello received <- %s", lConn.RemoteAddr()) - - return tlsMsg, nil -} - -// sendClientHello decides whether to spoof and sends the Client Hello accordingly. -func (h *HTTPSHandler) sendClientHello( - ctx context.Context, - conn net.Conn, - msg *proto.TLSMessage, - httpsOpts *config.HTTPSOptions, -) (int, error) { - logger := logging.WithLocalScope(ctx, h.logger, "client_hello") - return h.desyncer.Send(ctx, logger, conn, msg, httpsOpts) -} + // 2. Delegate to Bridge + return h.bridge.Tunnel(ctx, lConn, dst, rule) +} \ No newline at end of file diff --git a/internal/proxy/socks5/socks5_proxy.go b/internal/proxy/socks5/socks5_proxy.go index 44499f63..878b7d44 100644 --- a/internal/proxy/socks5/socks5_proxy.go +++ b/internal/proxy/socks5/socks5_proxy.go @@ -16,7 +16,7 @@ import ( "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/proto" "github.com/xvzc/SpoofDPI/internal/proxy" - "github.com/xvzc/SpoofDPI/internal/proxy/http" + "github.com/xvzc/SpoofDPI/internal/proxy/handler" "github.com/xvzc/SpoofDPI/internal/ptr" "github.com/xvzc/SpoofDPI/internal/session" ) @@ -36,7 +36,7 @@ type SOCKS5Proxy struct { func NewSOCKS5Proxy( logger zerolog.Logger, resolver dns.Resolver, - httpsHandler *http.HTTPSHandler, + bridge *handler.Bridge, ruleMatcher matcher.RuleMatcher, serverOpts *config.ServerOptions, policyOpts *config.PolicyOptions, @@ -49,7 +49,7 @@ func NewSOCKS5Proxy( policyOpts: policyOpts, tcpHandler: NewTCPHandler( logger, - httpsHandler, + bridge, serverOpts, ), udpHandler: NewUDPHandler(logger), @@ -134,7 +134,7 @@ func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { case proto.CmdUDPAssociate: // UDP Associate usually doesn't have destination info in the request - p.udpHandler.Handle(ctx, conn, req, nil, nil) + _ = p.udpHandler.Handle(ctx, conn, req, nil, nil) default: _ = proto.SOCKS5CommandNotSupportedResponse().Write(conn) logger.Warn().Uint8("cmd", req.Cmd).Msg("unsupported socks5 command") @@ -256,4 +256,4 @@ func (p *SOCKS5Proxy) negotiate(conn net.Conn) error { // Respond: Version 5, Method NoAuth(0) _, err := conn.Write([]byte{proto.SOCKSVersion, proto.AuthNone}) return err -} \ No newline at end of file +} diff --git a/internal/proxy/socks5/tcp_handler.go b/internal/proxy/socks5/tcp_handler.go index 4a1878e2..ba74ecaa 100644 --- a/internal/proxy/socks5/tcp_handler.go +++ b/internal/proxy/socks5/tcp_handler.go @@ -9,24 +9,24 @@ import ( "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/proxy/http" + "github.com/xvzc/SpoofDPI/internal/proxy/handler" ) type TCPHandler struct { - logger zerolog.Logger - httpsHandler *http.HTTPSHandler - serverOpts *config.ServerOptions + logger zerolog.Logger + bridge *handler.Bridge + serverOpts *config.ServerOptions } func NewTCPHandler( logger zerolog.Logger, - httpsHandler *http.HTTPSHandler, + bridge *handler.Bridge, serverOpts *config.ServerOptions, ) *TCPHandler { return &TCPHandler{ - logger: logger, - httpsHandler: httpsHandler, - serverOpts: serverOpts, + logger: logger, + bridge: bridge, + serverOpts: serverOpts, } } @@ -39,7 +39,7 @@ func (h *TCPHandler) Handle( ) error { logger := h.logger.With().Ctx(ctx).Logger() - // 1. Validate Destination (Avoid Recursive Loop) + // 1. Validate Destination ok, err := netutil.ValidateDestination(dst.Addrs, dst.Port, h.serverOpts.ListenAddr) if err != nil { logger.Debug().Err(err).Msg("error determining if valid destination") @@ -56,14 +56,14 @@ func (h *TCPHandler) Handle( return netutil.ErrBlocked } - // 3. Send Success Response Optimistically + // 3. Send Success Response if err := proto.SOCKS5SuccessResponse().Bind(net.IPv4zero).Port(0).Write(conn); err != nil { logger.Error().Err(err).Msg("failed to write socks5 success reply") return err } - // 4. Handover to HTTPSHandler - handleErr := h.httpsHandler.HandleRequest(ctx, conn, dst, rule) + // 4. Delegate to Bridge + handleErr := h.bridge.Tunnel(ctx, conn, dst, rule) if handleErr == nil { return nil } From 87f4155bb5c84a24c85161c3c5bf08351b851940 Mon Sep 17 00:00:00 2001 From: xvzc Date: Mon, 22 Dec 2025 01:36:19 +0900 Subject: [PATCH 04/39] refactor: restructure proxy packages --- cmd/spoofdpi/main.go | 7 +++---- .../{http/http_handler.go => handler/http.go} | 2 +- .../{http/https_handler.go => handler/https.go} | 7 +++---- .../tcp_handler.go => handler/socks5_tcp.go} | 7 +++---- .../udp_handler.go => handler/socks5_udp.go} | 13 ++++++++----- internal/proxy/{http => }/http_proxy.go | 14 +++++++------- internal/proxy/{socks5 => }/socks5_proxy.go | 13 ++++++------- 7 files changed, 31 insertions(+), 32 deletions(-) rename internal/proxy/{http/http_handler.go => handler/http.go} (99%) rename internal/proxy/{http/https_handler.go => handler/https.go} (90%) rename internal/proxy/{socks5/tcp_handler.go => handler/socks5_tcp.go} (93%) rename internal/proxy/{socks5/udp_handler.go => handler/socks5_udp.go} (95%) rename internal/proxy/{http => }/http_proxy.go (96%) rename internal/proxy/{socks5 => }/socks5_proxy.go (97%) diff --git a/cmd/spoofdpi/main.go b/cmd/spoofdpi/main.go index 6929be91..0bd6adcb 100644 --- a/cmd/spoofdpi/main.go +++ b/cmd/spoofdpi/main.go @@ -19,7 +19,6 @@ import ( "github.com/xvzc/SpoofDPI/internal/packet" "github.com/xvzc/SpoofDPI/internal/proxy" "github.com/xvzc/SpoofDPI/internal/proxy/handler" - "github.com/xvzc/SpoofDPI/internal/proxy/http" "github.com/xvzc/SpoofDPI/internal/session" "github.com/xvzc/SpoofDPI/internal/system" ) @@ -246,7 +245,7 @@ func createProxy( } // create an HTTP handler. - httpHandler := http.NewHTTPHandler(logging.WithScope(logger, "hnd")) + httpHandler := handler.NewHTTPHandler(logging.WithScope(logger, "hnd")) var sniffer packet.Sniffer var writer packet.Writer @@ -270,7 +269,7 @@ func createProxy( cfg.HTTPS.Clone(), ) - httpsHandler := http.NewHTTPSHandler( + httpsHandler := handler.NewHTTPSHandler( logging.WithScope(logger, "hnd"), bridge, ) @@ -286,7 +285,7 @@ func createProxy( // ), nil // } - return http.NewHTTPProxy( + return proxy.NewHTTPProxy( logging.WithScope(logger, "pxy"), resolver, httpHandler, diff --git a/internal/proxy/http/http_handler.go b/internal/proxy/handler/http.go similarity index 99% rename from internal/proxy/http/http_handler.go rename to internal/proxy/handler/http.go index d847610d..a21dc9e5 100644 --- a/internal/proxy/http/http_handler.go +++ b/internal/proxy/handler/http.go @@ -1,4 +1,4 @@ -package http +package handler import ( "context" diff --git a/internal/proxy/http/https_handler.go b/internal/proxy/handler/https.go similarity index 90% rename from internal/proxy/http/https_handler.go rename to internal/proxy/handler/https.go index 7d5873c4..139d38dc 100644 --- a/internal/proxy/http/https_handler.go +++ b/internal/proxy/handler/https.go @@ -1,4 +1,4 @@ -package http +package handler import ( "context" @@ -12,17 +12,16 @@ import ( "github.com/xvzc/SpoofDPI/internal/logging" "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/proxy/handler" ) type HTTPSHandler struct { logger zerolog.Logger - bridge *handler.Bridge + bridge *Bridge } func NewHTTPSHandler( logger zerolog.Logger, - bridge *handler.Bridge, + bridge *Bridge, ) *HTTPSHandler { return &HTTPSHandler{ logger: logger, diff --git a/internal/proxy/socks5/tcp_handler.go b/internal/proxy/handler/socks5_tcp.go similarity index 93% rename from internal/proxy/socks5/tcp_handler.go rename to internal/proxy/handler/socks5_tcp.go index ba74ecaa..425aee02 100644 --- a/internal/proxy/socks5/tcp_handler.go +++ b/internal/proxy/handler/socks5_tcp.go @@ -1,4 +1,4 @@ -package socks5 +package handler import ( "context" @@ -9,18 +9,17 @@ import ( "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/proxy/handler" ) type TCPHandler struct { logger zerolog.Logger - bridge *handler.Bridge + bridge *Bridge serverOpts *config.ServerOptions } func NewTCPHandler( logger zerolog.Logger, - bridge *handler.Bridge, + bridge *Bridge, serverOpts *config.ServerOptions, ) *TCPHandler { return &TCPHandler{ diff --git a/internal/proxy/socks5/udp_handler.go b/internal/proxy/handler/socks5_udp.go similarity index 95% rename from internal/proxy/socks5/udp_handler.go rename to internal/proxy/handler/socks5_udp.go index 6c6a7db6..8be1ee4a 100644 --- a/internal/proxy/socks5/udp_handler.go +++ b/internal/proxy/handler/socks5_udp.go @@ -1,4 +1,4 @@ -package socks5 +package handler import ( "context" @@ -39,7 +39,7 @@ func (h *UDPHandler) Handle( _ = proto.SOCKS5FailureResponse().Write(conn) return err } - defer udpConn.Close() + netutil.CloseConns(udpConn) lAddr := udpConn.LocalAddr().(*net.UDPAddr) @@ -63,7 +63,7 @@ func (h *UDPHandler) Handle( go func() { <-done - udpConn.Close() + netutil.CloseConns(udpConn) }() buf := make([]byte, 65535) @@ -99,7 +99,10 @@ func (h *UDPHandler) Handle( // The Target will reply to this socket. resolvedAddr, err := net.ResolveUDPAddr("udp", targetAddr) if err != nil { - logger.Warn().Err(err).Str("addr", targetAddr).Msg("failed to resolve udp target") + logger.Warn(). + Err(err). + Str("addr", targetAddr). + Msg("failed to resolve udp target") continue } @@ -189,4 +192,4 @@ func createUDPHeaderFromAddr(addr *net.UDPAddr) []byte { buf = append(buf, portBuf...) return buf -} \ No newline at end of file +} diff --git a/internal/proxy/http/http_proxy.go b/internal/proxy/http_proxy.go similarity index 96% rename from internal/proxy/http/http_proxy.go rename to internal/proxy/http_proxy.go index 5bb1b88b..9e714d93 100644 --- a/internal/proxy/http/http_proxy.go +++ b/internal/proxy/http_proxy.go @@ -1,4 +1,4 @@ -package http +package proxy import ( "context" @@ -16,7 +16,7 @@ import ( "github.com/xvzc/SpoofDPI/internal/matcher" "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/proxy" + "github.com/xvzc/SpoofDPI/internal/proxy/handler" "github.com/xvzc/SpoofDPI/internal/ptr" "github.com/xvzc/SpoofDPI/internal/session" ) @@ -25,8 +25,8 @@ type HTTPProxy struct { logger zerolog.Logger resolver dns.Resolver - httpHandler *HTTPHandler - httpsHandler *HTTPSHandler + httpHandler *handler.HTTPHandler + httpsHandler *handler.HTTPSHandler ruleMatcher matcher.RuleMatcher serverOpts *config.ServerOptions policyOpts *config.PolicyOptions @@ -35,12 +35,12 @@ type HTTPProxy struct { func NewHTTPProxy( logger zerolog.Logger, resolver dns.Resolver, - httpHandler *HTTPHandler, - httpsHandler *HTTPSHandler, + httpHandler *handler.HTTPHandler, + httpsHandler *handler.HTTPSHandler, ruleMatcher matcher.RuleMatcher, serverOpts *config.ServerOptions, policyOpts *config.PolicyOptions, -) proxy.ProxyServer { +) ProxyServer { return &HTTPProxy{ logger: logger, resolver: resolver, diff --git a/internal/proxy/socks5/socks5_proxy.go b/internal/proxy/socks5_proxy.go similarity index 97% rename from internal/proxy/socks5/socks5_proxy.go rename to internal/proxy/socks5_proxy.go index 878b7d44..2e14b422 100644 --- a/internal/proxy/socks5/socks5_proxy.go +++ b/internal/proxy/socks5_proxy.go @@ -1,4 +1,4 @@ -package socks5 +package proxy import ( "context" @@ -15,7 +15,6 @@ import ( "github.com/xvzc/SpoofDPI/internal/matcher" "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/proxy" "github.com/xvzc/SpoofDPI/internal/proxy/handler" "github.com/xvzc/SpoofDPI/internal/ptr" "github.com/xvzc/SpoofDPI/internal/session" @@ -29,8 +28,8 @@ type SOCKS5Proxy struct { serverOpts *config.ServerOptions policyOpts *config.PolicyOptions - tcpHandler *TCPHandler - udpHandler *UDPHandler + tcpHandler *handler.TCPHandler + udpHandler *handler.UDPHandler } func NewSOCKS5Proxy( @@ -40,19 +39,19 @@ func NewSOCKS5Proxy( ruleMatcher matcher.RuleMatcher, serverOpts *config.ServerOptions, policyOpts *config.PolicyOptions, -) proxy.ProxyServer { +) ProxyServer { return &SOCKS5Proxy{ logger: logger, resolver: resolver, ruleMatcher: ruleMatcher, serverOpts: serverOpts, policyOpts: policyOpts, - tcpHandler: NewTCPHandler( + tcpHandler: handler.NewTCPHandler( logger, bridge, serverOpts, ), - udpHandler: NewUDPHandler(logger), + udpHandler: handler.NewUDPHandler(logger), } } From 919fbf63c1592d8ac712731195bad6fd9e7c38ac Mon Sep 17 00:00:00 2001 From: xvzc Date: Mon, 22 Dec 2025 01:41:52 +0900 Subject: [PATCH 05/39] chore: rename handler files for clarity --- internal/proxy/handler/{http.go => http_http.go} | 0 internal/proxy/handler/{https.go => http_https.go} | 0 internal/proxy/handler/{bridge.go => tls.go} | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename internal/proxy/handler/{http.go => http_http.go} (100%) rename internal/proxy/handler/{https.go => http_https.go} (100%) rename internal/proxy/handler/{bridge.go => tls.go} (100%) diff --git a/internal/proxy/handler/http.go b/internal/proxy/handler/http_http.go similarity index 100% rename from internal/proxy/handler/http.go rename to internal/proxy/handler/http_http.go diff --git a/internal/proxy/handler/https.go b/internal/proxy/handler/http_https.go similarity index 100% rename from internal/proxy/handler/https.go rename to internal/proxy/handler/http_https.go diff --git a/internal/proxy/handler/bridge.go b/internal/proxy/handler/tls.go similarity index 100% rename from internal/proxy/handler/bridge.go rename to internal/proxy/handler/tls.go From 0f6b76f82e75507af54e15988ecc6c8ed65d93a3 Mon Sep 17 00:00:00 2001 From: xvzc Date: Mon, 22 Dec 2025 01:55:42 +0900 Subject: [PATCH 06/39] refactor: restructure proxy packages to http, socks5, and tlsutil --- cmd/spoofdpi/main.go | 11 +- .../http_https.go => http/connect.go} | 10 +- .../{handler/http_http.go => http/handler.go} | 2 +- internal/proxy/http/proxy.go | 224 +++++++++++++++ internal/proxy/socks5/proxy.go | 259 ++++++++++++++++++ .../{handler/socks5_tcp.go => socks5/tcp.go} | 7 +- .../{handler/socks5_udp.go => socks5/udp.go} | 2 +- .../{handler/tls.go => tlsutil/bridge.go} | 14 +- 8 files changed, 508 insertions(+), 21 deletions(-) rename internal/proxy/{handler/http_https.go => http/connect.go} (90%) rename internal/proxy/{handler/http_http.go => http/handler.go} (99%) create mode 100644 internal/proxy/http/proxy.go create mode 100644 internal/proxy/socks5/proxy.go rename internal/proxy/{handler/socks5_tcp.go => socks5/tcp.go} (92%) rename internal/proxy/{handler/socks5_udp.go => socks5/udp.go} (99%) rename internal/proxy/{handler/tls.go => tlsutil/bridge.go} (94%) diff --git a/cmd/spoofdpi/main.go b/cmd/spoofdpi/main.go index 0bd6adcb..e68ef286 100644 --- a/cmd/spoofdpi/main.go +++ b/cmd/spoofdpi/main.go @@ -18,7 +18,8 @@ import ( "github.com/xvzc/SpoofDPI/internal/matcher" "github.com/xvzc/SpoofDPI/internal/packet" "github.com/xvzc/SpoofDPI/internal/proxy" - "github.com/xvzc/SpoofDPI/internal/proxy/handler" + "github.com/xvzc/SpoofDPI/internal/proxy/http" + "github.com/xvzc/SpoofDPI/internal/proxy/tlsutil" "github.com/xvzc/SpoofDPI/internal/session" "github.com/xvzc/SpoofDPI/internal/system" ) @@ -245,7 +246,7 @@ func createProxy( } // create an HTTP handler. - httpHandler := handler.NewHTTPHandler(logging.WithScope(logger, "hnd")) + httpHandler := http.NewHTTPHandler(logging.WithScope(logger, "hnd")) var sniffer packet.Sniffer var writer packet.Writer @@ -258,7 +259,7 @@ func createProxy( } } - bridge := handler.NewBridge( + bridge := tlsutil.NewTLSBridge( logging.WithScope(logger, "brg"), desync.NewTLSDesyncer( writer, @@ -269,7 +270,7 @@ func createProxy( cfg.HTTPS.Clone(), ) - httpsHandler := handler.NewHTTPSHandler( + httpsHandler := http.NewHTTPSHandler( logging.WithScope(logger, "hnd"), bridge, ) @@ -285,7 +286,7 @@ func createProxy( // ), nil // } - return proxy.NewHTTPProxy( + return http.NewHTTPProxy( logging.WithScope(logger, "pxy"), resolver, httpHandler, diff --git a/internal/proxy/handler/http_https.go b/internal/proxy/http/connect.go similarity index 90% rename from internal/proxy/handler/http_https.go rename to internal/proxy/http/connect.go index 139d38dc..be7fb7f7 100644 --- a/internal/proxy/handler/http_https.go +++ b/internal/proxy/http/connect.go @@ -1,4 +1,4 @@ -package handler +package http import ( "context" @@ -12,16 +12,17 @@ import ( "github.com/xvzc/SpoofDPI/internal/logging" "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/proto" + "github.com/xvzc/SpoofDPI/internal/proxy/tlsutil" ) type HTTPSHandler struct { logger zerolog.Logger - bridge *Bridge + bridge *tlsutil.TLSBridge } func NewHTTPSHandler( logger zerolog.Logger, - bridge *Bridge, + bridge *tlsutil.TLSBridge, ) *HTTPSHandler { return &HTTPSHandler{ logger: logger, @@ -49,4 +50,5 @@ func (h *HTTPSHandler) HandleRequest( // 2. Delegate to Bridge return h.bridge.Tunnel(ctx, lConn, dst, rule) -} \ No newline at end of file +} + diff --git a/internal/proxy/handler/http_http.go b/internal/proxy/http/handler.go similarity index 99% rename from internal/proxy/handler/http_http.go rename to internal/proxy/http/handler.go index a21dc9e5..d847610d 100644 --- a/internal/proxy/handler/http_http.go +++ b/internal/proxy/http/handler.go @@ -1,4 +1,4 @@ -package handler +package http import ( "context" diff --git a/internal/proxy/http/proxy.go b/internal/proxy/http/proxy.go new file mode 100644 index 00000000..5bb1b88b --- /dev/null +++ b/internal/proxy/http/proxy.go @@ -0,0 +1,224 @@ +package http + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "time" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/dns" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/matcher" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/proto" + "github.com/xvzc/SpoofDPI/internal/proxy" + "github.com/xvzc/SpoofDPI/internal/ptr" + "github.com/xvzc/SpoofDPI/internal/session" +) + +type HTTPProxy struct { + logger zerolog.Logger + + resolver dns.Resolver + httpHandler *HTTPHandler + httpsHandler *HTTPSHandler + ruleMatcher matcher.RuleMatcher + serverOpts *config.ServerOptions + policyOpts *config.PolicyOptions +} + +func NewHTTPProxy( + logger zerolog.Logger, + resolver dns.Resolver, + httpHandler *HTTPHandler, + httpsHandler *HTTPSHandler, + ruleMatcher matcher.RuleMatcher, + serverOpts *config.ServerOptions, + policyOpts *config.PolicyOptions, +) proxy.ProxyServer { + return &HTTPProxy{ + logger: logger, + resolver: resolver, + httpHandler: httpHandler, + httpsHandler: httpsHandler, + ruleMatcher: ruleMatcher, + serverOpts: serverOpts, + policyOpts: policyOpts, + } +} + +func (p *HTTPProxy) ListenAndServe(ctx context.Context, wait chan struct{}) { + <-wait + + logger := p.logger.With().Ctx(ctx).Logger() + + listener, err := net.ListenTCP("tcp", p.serverOpts.ListenAddr) + if err != nil { + p.logger.Fatal(). + Err(err). + Msgf("error creating listener on %s", p.serverOpts.ListenAddr.String()) + } + + logger.Info(). + Msgf("created a listener on %s", p.serverOpts.ListenAddr) + + for { + conn, err := listener.Accept() + if err != nil { + p.logger.Error(). + Err(err). + Msgf("failed to accept new connection") + + continue + } + + go p.handleNewConnection(session.WithNewTraceID(context.Background()), conn) + } +} + +func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { + logger := logging.WithLocalScope(ctx, p.logger, "conn") + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + defer netutil.CloseConns(conn) + + req, err := proto.ReadHttpRequest(conn) + if err != nil { + if err != io.EOF { + logger.Warn().Err(err).Msg("failed to read http request") + } + + return + } + + if !req.IsValidMethod() { + logger.Warn().Str("method", req.Method).Msg("unsupported method. abort") + _ = proto.HTTPNotImplementedResponse().Write(conn) + + return + } + + domain := req.ExtractDomain() + dstPort, err := req.ExtractPort() + if err != nil { + logger.Warn().Str("host", req.Host).Msg("failed to extract port") + _ = proto.HTTPBadRequestResponse().Write(conn) + + return + } + + ctx = session.WithHostInfo(ctx, domain) + logger = logger.With().Ctx(ctx).Logger() + + logger.Debug(). + Str("method", req.Method). + Str("from", conn.RemoteAddr().String()). + Msg("new request") + + nameMatch := p.ruleMatcher.Search( + &matcher.Selector{Kind: matcher.MatchKindDomain, Domain: ptr.FromValue(domain)}, + ) + if nameMatch != nil && logger.GetLevel() == zerolog.TraceLevel { + jsonAttrs, _ := json.Marshal(nameMatch) + logger.Trace().RawJSON("values", jsonAttrs).Msg("name match") + } + + t1 := time.Now() + rSet, err := p.resolver.Resolve(ctx, domain, nil, nameMatch) + dt := time.Since(t1).Milliseconds() + if err != nil { + _ = proto.HTTPBadGatewayResponse().Write(conn) + logging.ErrorUnwrapped(&logger, "dns lookup failed", err) + + return + } + + logger.Debug(). + Int("cnt", len(rSet.Addrs)). + Str("took", fmt.Sprintf("%dms", dt)). + Msgf("dns lookup ok") + + // Avoid recursively querying self. + ok, err := netutil.ValidateDestination(rSet.Addrs, dstPort, p.serverOpts.ListenAddr) + if err != nil { + logger.Debug().Err(err).Msg("error validating dst addrs") + if !ok { + _ = proto.HTTPForbiddenResponse().Write(conn) + } + } + + var selectors []*matcher.Selector + for _, v := range rSet.Addrs { + selectors = append(selectors, &matcher.Selector{ + Kind: matcher.MatchKindAddr, + IP: ptr.FromValue(v.IP), + Port: ptr.FromValue(uint16(dstPort)), + }) + } + + addrMatch := p.ruleMatcher.SearchAll(selectors) + if addrMatch != nil && logger.GetLevel() == zerolog.TraceLevel { + jsonAttrs, _ := json.Marshal(addrMatch) + logger.Trace().RawJSON("values", jsonAttrs).Msg("addr match") + } + + bestMatch := matcher.GetHigherPriorityRule(addrMatch, nameMatch) + if bestMatch != nil && logger.GetLevel() == zerolog.TraceLevel { + jsonAttrs, _ := json.Marshal(bestMatch) + logger.Trace().RawJSON("values", jsonAttrs).Msg("best match") + } + + if bestMatch != nil && *bestMatch.Block { + logger.Debug().Msg("request is blocked by policy") + return + } + + dst := &netutil.Destination{ + Domain: domain, + Addrs: rSet.Addrs, + Port: dstPort, + Timeout: *p.serverOpts.Timeout, + } + + var handleErr error + if req.IsConnectMethod() { + handleErr = p.httpsHandler.HandleRequest(ctx, conn, dst, bestMatch) + } else { + handleErr = p.httpHandler.HandleRequest(ctx, conn, req, dst, bestMatch) + } + + if handleErr == nil { // Early exit if no error found + return + } + + logger.Warn().Err(handleErr).Msg("error handling request") + if !errors.Is(handleErr, netutil.ErrBlocked) { // Early exit if not blocked + return + } + + // ┌─────────────┐ + // │ AUTO config │ + // └─────────────┘ + if bestMatch != nil && logger.GetLevel() == zerolog.TraceLevel { + logger.Info().Msg("skipping auto-config (duplicate policy)") + return + } + + // Perform auto config if enabled and RuleTemplate is not nil + if *p.policyOpts.Auto && p.policyOpts.Template != nil { + newRule := p.policyOpts.Template.Clone() + newRule.Match = &config.MatchAttrs{Domains: []string{domain}} + + if err := p.ruleMatcher.Add(newRule); err != nil { + logger.Info().Err(err).Msg("failed to add config automatically") + } else { + logger.Info().Msg("automatically added to config") + } + } +} diff --git a/internal/proxy/socks5/proxy.go b/internal/proxy/socks5/proxy.go new file mode 100644 index 00000000..e0edb86e --- /dev/null +++ b/internal/proxy/socks5/proxy.go @@ -0,0 +1,259 @@ +package socks5 + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "time" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/dns" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/matcher" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/proto" + "github.com/xvzc/SpoofDPI/internal/proxy" + "github.com/xvzc/SpoofDPI/internal/proxy/tlsutil" + "github.com/xvzc/SpoofDPI/internal/ptr" + "github.com/xvzc/SpoofDPI/internal/session" +) + +type SOCKS5Proxy struct { + logger zerolog.Logger + + resolver dns.Resolver + ruleMatcher matcher.RuleMatcher + serverOpts *config.ServerOptions + policyOpts *config.PolicyOptions + + tcpHandler *TCPHandler + udpHandler *UDPHandler +} + +func NewSOCKS5Proxy( + logger zerolog.Logger, + resolver dns.Resolver, + bridge *tlsutil.TLSBridge, + ruleMatcher matcher.RuleMatcher, + serverOpts *config.ServerOptions, + policyOpts *config.PolicyOptions, +) proxy.ProxyServer { + return &SOCKS5Proxy{ + logger: logger, + resolver: resolver, + ruleMatcher: ruleMatcher, + serverOpts: serverOpts, + policyOpts: policyOpts, + tcpHandler: NewTCPHandler( + logger, + bridge, + serverOpts, + ), + udpHandler: NewUDPHandler(logger), + } +} + +func (p *SOCKS5Proxy) ListenAndServe(ctx context.Context, wait chan struct{}) { + <-wait + + logger := p.logger.With().Ctx(ctx).Logger() + + // Using ListenTCP to match HTTPProxy style, though net.Listen is also fine + listener, err := net.ListenTCP("tcp", p.serverOpts.ListenAddr) + if err != nil { + p.logger.Fatal(). + Err(err). + Msgf("error creating socks5 listener on %s", p.serverOpts.ListenAddr.String()) + } + + logger.Info(). + Msgf("created a socks5 listener on %s", p.serverOpts.ListenAddr) + + for { + conn, err := listener.Accept() + if err != nil { + p.logger.Error(). + Err(err). + Msg("failed to accept new connection") + continue + } + + go p.handleConnection(session.WithNewTraceID(context.Background()), conn) + } +} + +func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { + logger := logging.WithLocalScope(ctx, p.logger, "socks5") + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + defer netutil.CloseConns(conn) + + // 1. Negotiation Phase + if err := p.negotiate(conn); err != nil { + logger.Debug().Err(err).Msg("socks5 negotiation failed") + return + } + + // 2. Request Phase + req, err := proto.ReadSocks5Request(conn) + if err != nil { + if err != io.EOF { + logger.Warn().Err(err).Msg("failed to read socks5 request") + } + return + } + + // Setup Logging Context + remoteInfo := req.Domain + if remoteInfo == "" { + remoteInfo = req.IP.String() + } + ctx = session.WithHostInfo(ctx, remoteInfo) + + switch req.Cmd { + case proto.CmdConnect: + rule, addrs, err := p.resolveAndMatch(ctx, req) + if err != nil { + return // resolveAndMatch logs error and writes failure response if needed + } + dst := &netutil.Destination{ + Domain: req.Domain, + Addrs: addrs, + Port: req.Port, + Timeout: *p.serverOpts.Timeout, + } + if err := p.tcpHandler.Handle(ctx, conn, req, dst, rule); err != nil { + return // Handler logs error + } + + p.handleAutoConfig(ctx, req, addrs, rule) + + case proto.CmdUDPAssociate: + // UDP Associate usually doesn't have destination info in the request + _ = p.udpHandler.Handle(ctx, conn, req, nil, nil) + default: + _ = proto.SOCKS5CommandNotSupportedResponse().Write(conn) + logger.Warn().Uint8("cmd", req.Cmd).Msg("unsupported socks5 command") + } +} + +func (p *SOCKS5Proxy) resolveAndMatch( + ctx context.Context, + req *proto.SOCKS5Request, +) (*config.Rule, []net.IPAddr, error) { + logger := zerolog.Ctx(ctx) + + // 1. Match Domain Rules (if domain provided) + var nameMatch *config.Rule + if req.Domain != "" { + nameMatch = p.ruleMatcher.Search( + &matcher.Selector{ + Kind: matcher.MatchKindDomain, + Domain: ptr.FromValue(req.Domain), + }, + ) + if nameMatch != nil && logger.GetLevel() == zerolog.TraceLevel { + jsonAttrs, _ := json.Marshal(nameMatch) + logger.Trace().RawJSON("values", jsonAttrs).Msg("name match") + } + } + + // 2. DNS Resolution + t1 := time.Now() + var addrs []net.IPAddr + + if req.Domain != "" { + // Resolve Domain + rSet, err := p.resolver.Resolve(ctx, req.Domain, nil, nameMatch) + if err != nil { + logging.ErrorUnwrapped(logger, "dns lookup failed", err) + return nil, nil, err + } + addrs = rSet.Addrs + } else { + addrs = []net.IPAddr{{IP: req.IP}} + } + + dt := time.Since(t1).Milliseconds() + logger.Debug(). + Int("cnt", len(addrs)). + Str("took", fmt.Sprintf("%dms", dt)). + Msg("dns lookup ok") + + // 3. Match IP Rules + var selectors []*matcher.Selector + for _, v := range addrs { + selectors = append(selectors, &matcher.Selector{ + Kind: matcher.MatchKindAddr, + IP: ptr.FromValue(v.IP), + Port: ptr.FromValue(uint16(req.Port)), + }) + } + + addrMatch := p.ruleMatcher.SearchAll(selectors) + if addrMatch != nil && logger.GetLevel() == zerolog.TraceLevel { + jsonAttrs, _ := json.Marshal(addrMatch) + logger.Trace().RawJSON("values", jsonAttrs).Msg("addr match") + } + + bestMatch := matcher.GetHigherPriorityRule(addrMatch, nameMatch) + return bestMatch, addrs, nil +} + +func (p *SOCKS5Proxy) handleAutoConfig( + ctx context.Context, + req *proto.SOCKS5Request, + addrs []net.IPAddr, + matchedRule *config.Rule, +) { + logger := zerolog.Ctx(ctx) + + if matchedRule != nil { + logger.Info(). + Interface("match", matchedRule.Match). + Str("name", *matchedRule.Name). + Msg("skipping auto-config (duplicate policy)") + return + } + + if *p.policyOpts.Auto && p.policyOpts.Template != nil { + newRule := p.policyOpts.Template.Clone() + targetDomain := req.Domain + if targetDomain == "" && len(addrs) > 0 { + targetDomain = addrs[0].IP.String() + } + + newRule.Match = &config.MatchAttrs{Domains: []string{targetDomain}} + + if err := p.ruleMatcher.Add(newRule); err != nil { + logger.Info().Err(err).Msg("failed to add config automatically") + } else { + logger.Info().Msg("automatically added to config") + } + } +} + +func (p *SOCKS5Proxy) negotiate(conn net.Conn) error { + header := make([]byte, 2) + if _, err := io.ReadFull(conn, header); err != nil { + return err + } + + if header[0] != proto.SOCKSVersion { + return fmt.Errorf("unsupported version: %d", header[0]) + } + + nMethods := int(header[1]) + methods := make([]byte, nMethods) + if _, err := io.ReadFull(conn, methods); err != nil { + return err + } + + // Respond: Version 5, Method NoAuth(0) + _, err := conn.Write([]byte{proto.SOCKSVersion, proto.AuthNone}) + return err +} diff --git a/internal/proxy/handler/socks5_tcp.go b/internal/proxy/socks5/tcp.go similarity index 92% rename from internal/proxy/handler/socks5_tcp.go rename to internal/proxy/socks5/tcp.go index 425aee02..bcb53b48 100644 --- a/internal/proxy/handler/socks5_tcp.go +++ b/internal/proxy/socks5/tcp.go @@ -1,4 +1,4 @@ -package handler +package socks5 import ( "context" @@ -9,17 +9,18 @@ import ( "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/proto" + "github.com/xvzc/SpoofDPI/internal/proxy/tlsutil" ) type TCPHandler struct { logger zerolog.Logger - bridge *Bridge + bridge *tlsutil.TLSBridge serverOpts *config.ServerOptions } func NewTCPHandler( logger zerolog.Logger, - bridge *Bridge, + bridge *tlsutil.TLSBridge, serverOpts *config.ServerOptions, ) *TCPHandler { return &TCPHandler{ diff --git a/internal/proxy/handler/socks5_udp.go b/internal/proxy/socks5/udp.go similarity index 99% rename from internal/proxy/handler/socks5_udp.go rename to internal/proxy/socks5/udp.go index 8be1ee4a..9368449a 100644 --- a/internal/proxy/handler/socks5_udp.go +++ b/internal/proxy/socks5/udp.go @@ -1,4 +1,4 @@ -package handler +package socks5 import ( "context" diff --git a/internal/proxy/handler/tls.go b/internal/proxy/tlsutil/bridge.go similarity index 94% rename from internal/proxy/handler/tls.go rename to internal/proxy/tlsutil/bridge.go index c82443c3..2f669734 100644 --- a/internal/proxy/handler/tls.go +++ b/internal/proxy/tlsutil/bridge.go @@ -1,4 +1,4 @@ -package handler +package tlsutil import ( "context" @@ -15,20 +15,20 @@ import ( "github.com/xvzc/SpoofDPI/internal/ptr" ) -type Bridge struct { +type TLSBridge struct { logger zerolog.Logger desyncer *desync.TLSDesyncer sniffer packet.Sniffer httpsOpts *config.HTTPSOptions } -func NewBridge( +func NewTLSBridge( logger zerolog.Logger, desyncer *desync.TLSDesyncer, sniffer packet.Sniffer, httpsOpts *config.HTTPSOptions, -) *Bridge { - return &Bridge{ +) *TLSBridge { + return &TLSBridge{ logger: logger, desyncer: desyncer, sniffer: sniffer, @@ -38,7 +38,7 @@ func NewBridge( // Tunnel creates a bi-directional tunnel between lConn and dst. // It detects the first packet from lConn. If it's a ClientHello, it applies the desync strategy. -func (b *Bridge) Tunnel( +func (b *TLSBridge) Tunnel( ctx context.Context, lConn net.Conn, dst *netutil.Destination, @@ -119,7 +119,7 @@ func (b *Bridge) Tunnel( return nil } -func (b *Bridge) sendClientHello( +func (b *TLSBridge) sendClientHello( ctx context.Context, conn net.Conn, msg *proto.TLSMessage, From 92247c3587ebf2a7539aa9085eddb8921bb5c1b8 Mon Sep 17 00:00:00 2001 From: xvzc Date: Wed, 14 Jan 2026 19:53:07 +0900 Subject: [PATCH 07/39] feat: support socks5 and tun --- cmd/spoofdpi/main.go | 242 ++++++--- cmd/spoofdpi/main_test.go | 62 +-- go.mod | 12 +- go.sum | 28 +- internal/cache/cache.go | 4 + internal/cache/lru_cache.go | 33 ++ internal/cache/ttl_cache.go | 25 + internal/config/cli.go | 137 +++-- internal/config/cli_test.go | 41 +- internal/config/config.go | 67 ++- internal/config/config_test.go | 24 +- internal/config/parse.go | 26 + internal/config/segment_test.go | 1 + internal/config/toml.go | 4 +- internal/config/toml_test.go | 11 +- internal/config/types.go | 400 ++++++++++---- internal/config/types_test.go | 171 ++++-- internal/config/validate.go | 10 +- internal/config/validate_test.go | 30 +- internal/desync/tls.go | 228 +++++--- internal/desync/tls_test.go | 498 ++++++------------ internal/desync/udp.go | 72 +++ internal/dns/addrselect/addrselect.go | 34 +- internal/dns/cache.go | 9 +- internal/dns/https.go | 111 ++-- internal/dns/resolver.go | 17 +- internal/dns/route.go | 29 +- internal/dns/system.go | 38 +- internal/logging/logging.go | 21 +- internal/matcher/addr.go | 6 +- internal/matcher/addr_test.go | 44 +- internal/matcher/domain.go | 6 +- internal/matcher/domain_test.go | 28 +- internal/matcher/matcher.go | 4 +- internal/matcher/matcher_test.go | 6 +- internal/netutil/addr.go | 134 ++++- internal/netutil/conn.go | 169 +++++- internal/netutil/conn_pool.go | 206 ++++++++ internal/netutil/dial.go | 25 +- internal/netutil/dial_darwin.go | 40 ++ internal/netutil/dial_other.go | 11 + internal/netutil/pac.go | 34 ++ internal/packet/LICENSE | 201 ------- internal/packet/handle_linux.go | 21 +- internal/packet/network_detector.go | 100 +++- internal/packet/sniffer.go | 56 +- internal/packet/tcp_sniffer.go | 168 +++--- internal/packet/tcp_writer.go | 12 +- internal/packet/udp_sniffer.go | 191 +++++++ internal/packet/udp_writer.go | 197 +++++++ internal/proto/http.go | 4 +- internal/proto/socks5.go | 71 ++- internal/proxy/http/connect.go | 54 -- internal/proxy/http_proxy.go | 224 -------- internal/proxy/proxy.go | 9 - internal/proxy/socks5/proxy.go | 259 --------- internal/proxy/socks5/tcp.go | 77 --- internal/proxy/socks5/udp.go | 195 ------- internal/proxy/socks5_proxy.go | 258 --------- internal/proxy/tlsutil/bridge.go | 130 ----- .../http/handler.go => server/http/http.go} | 40 +- internal/server/http/https.go | 162 ++++++ internal/server/http/network.go | 13 + internal/server/http/network_darwin.go | 116 ++++ internal/server/http/network_linux.go | 14 + internal/{proxy => server}/http/proxy.go | 120 +++-- internal/server/server.go | 17 + internal/server/socks5/bind.go | 96 ++++ internal/server/socks5/connect.go | 200 +++++++ internal/server/socks5/network.go | 13 + internal/server/socks5/network_darwin.go | 116 ++++ internal/server/socks5/network_linux.go | 14 + internal/server/socks5/server.go | 296 +++++++++++ internal/server/socks5/udp_associate.go | 244 +++++++++ internal/server/tun/network.go | 15 + internal/server/tun/network_darwin.go | 117 ++++ internal/server/tun/server.go | 372 +++++++++++++ internal/server/tun/tcp.go | 225 ++++++++ internal/server/tun/udp.go | 119 +++++ internal/session/session.go | 3 +- internal/system/sysproxy.go | 13 - internal/system/sysproxy_darwin.go | 100 ---- internal/system/sysproxy_linux.go | 14 - 83 files changed, 5070 insertions(+), 2694 deletions(-) create mode 100644 internal/config/segment_test.go create mode 100644 internal/desync/udp.go create mode 100644 internal/netutil/conn_pool.go create mode 100644 internal/netutil/dial_darwin.go create mode 100644 internal/netutil/dial_other.go create mode 100644 internal/netutil/pac.go delete mode 100644 internal/packet/LICENSE create mode 100644 internal/packet/udp_sniffer.go create mode 100644 internal/packet/udp_writer.go delete mode 100644 internal/proxy/http/connect.go delete mode 100644 internal/proxy/http_proxy.go delete mode 100644 internal/proxy/proxy.go delete mode 100644 internal/proxy/socks5/proxy.go delete mode 100644 internal/proxy/socks5/tcp.go delete mode 100644 internal/proxy/socks5/udp.go delete mode 100644 internal/proxy/socks5_proxy.go delete mode 100644 internal/proxy/tlsutil/bridge.go rename internal/{proxy/http/handler.go => server/http/http.go} (73%) create mode 100644 internal/server/http/https.go create mode 100644 internal/server/http/network.go create mode 100644 internal/server/http/network_darwin.go create mode 100644 internal/server/http/network_linux.go rename internal/{proxy => server}/http/proxy.go (66%) create mode 100644 internal/server/server.go create mode 100644 internal/server/socks5/bind.go create mode 100644 internal/server/socks5/connect.go create mode 100644 internal/server/socks5/network.go create mode 100644 internal/server/socks5/network_darwin.go create mode 100644 internal/server/socks5/network_linux.go create mode 100644 internal/server/socks5/server.go create mode 100644 internal/server/socks5/udp_associate.go create mode 100644 internal/server/tun/network.go create mode 100644 internal/server/tun/network_darwin.go create mode 100644 internal/server/tun/server.go create mode 100644 internal/server/tun/tcp.go create mode 100644 internal/server/tun/udp.go delete mode 100644 internal/system/sysproxy.go delete mode 100644 internal/system/sysproxy_darwin.go delete mode 100644 internal/system/sysproxy_linux.go diff --git a/cmd/spoofdpi/main.go b/cmd/spoofdpi/main.go index e68ef286..84bd475d 100644 --- a/cmd/spoofdpi/main.go +++ b/cmd/spoofdpi/main.go @@ -16,12 +16,13 @@ import ( "github.com/xvzc/SpoofDPI/internal/dns" "github.com/xvzc/SpoofDPI/internal/logging" "github.com/xvzc/SpoofDPI/internal/matcher" + "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/packet" - "github.com/xvzc/SpoofDPI/internal/proxy" - "github.com/xvzc/SpoofDPI/internal/proxy/http" - "github.com/xvzc/SpoofDPI/internal/proxy/tlsutil" + "github.com/xvzc/SpoofDPI/internal/server" + "github.com/xvzc/SpoofDPI/internal/server/http" // Add http import + "github.com/xvzc/SpoofDPI/internal/server/socks5" + "github.com/xvzc/SpoofDPI/internal/server/tun" "github.com/xvzc/SpoofDPI/internal/session" - "github.com/xvzc/SpoofDPI/internal/system" ) // Version and commit are set at build time. @@ -55,32 +56,31 @@ func runApp(ctx context.Context, configDir string, cfg *config.Config) { Msgf("config file loaded") } - // set system-wide proxy configuration. - if !*cfg.General.SetSystemProxy { - logger.Info().Msg("use `--system-proxy` to automatically set system proxy") - } - resolver := createResolver(logger, cfg) - p, err := createProxy(logger, cfg, resolver) + + srv, err := createServer(logger, cfg, resolver) if err != nil { - logger.Fatal(). - Err(err). - Msg("failed to create proxy") + logger.Fatal().Err(err).Msg("failed to create server") } - // start app - wait := make(chan struct{}) // wait for setup logs to be printed - go p.ListenAndServe(ctx, wait) + // Start server + ready := make(chan struct{}) + go func() { + if err := srv.Start(ctx, ready); err != nil { + logger.Fatal().Err(err).Msgf("failed to start server: %T", srv) + } + }() + + <-ready - // set system-wide proxy configuration. - if *cfg.General.SetSystemProxy { - port := cfg.Server.ListenAddr.Port - if err := system.SetProxy(logger, uint16(port)); err != nil { - logger.Fatal().Err(err).Msg("failed to enable system proxy") + // System Proxy Config + if *cfg.General.SetNetworkConfig { + if err := srv.SetNetworkConfig(); err != nil { + logger.Fatal().Err(err).Msg("failed to set system network config") } defer func() { - if err := system.UnsetProxy(logger); err != nil { - logger.Fatal().Err(err).Msg("failed to disable system proxy") + if err := srv.UnsetNetworkConfig(); err != nil { + logger.Error().Err(err).Msg("failed to unset system network config") } }() } @@ -114,7 +114,9 @@ func runApp(ctx context.Context, configDir string, cfg *config.Config) { Msgf("connection timeout") } - wait <- struct{}{} + logger.Info().Msgf("server-mode; %s", cfg.Server.Mode.String()) + + logger.Info().Msgf("server started on %s", srv.Addr()) sigs := make(chan os.Signal, 1) done := make(chan bool, 1) @@ -132,6 +134,9 @@ func runApp(ctx context.Context, configDir string, cfg *config.Config) { }() <-done + + // Graceful shutdown + _ = srv.Stop() } func createResolver(logger zerolog.Logger, cfg *config.Config) dns.Resolver { @@ -170,14 +175,14 @@ func createResolver(logger zerolog.Logger, cfg *config.Config) dns.Resolver { func createPacketObjects( logger zerolog.Logger, cfg *config.Config, -) (packet.Sniffer, packet.Writer, error) { +) (packet.Sniffer, packet.Writer, packet.Sniffer, packet.Writer, error) { // create a network detector for passive discovery networkDetector := packet.NewNetworkDetector( logging.WithScope(logger, "pkt"), ) if err := networkDetector.Start(context.Background()); err != nil { - return nil, nil, fmt.Errorf("error starting network detector: %w", err) + return nil, nil, nil, nil, fmt.Errorf("error starting network detector: %w", err) } // Wait for gateway MAC with timeout @@ -186,15 +191,27 @@ func createPacketObjects( gatewayMAC, err := networkDetector.WaitForGatewayMAC(ctx) if err != nil { - return nil, nil, fmt.Errorf("failed to detect gateway (timeout): %w", err) + return nil, nil, nil, nil, fmt.Errorf( + "failed to detect gateway (timeout): %w", + err, + ) } iface := networkDetector.GetInterface() // create a pcap handle for packet capturing. - handle, err := packet.NewHandle(iface) + tcpHandle, err := packet.NewHandle(iface) if err != nil { - return nil, nil, fmt.Errorf( + return nil, nil, nil, nil, fmt.Errorf( + "error opening pcap handle on interface %s: %w", + iface.Name, + err, + ) + } + + udpHandle, err := packet.NewHandle(iface) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf( "error opening pcap handle on interface %s: %w", iface.Name, err, @@ -205,34 +222,57 @@ func createPacketObjects( logger.Info().Str("name", iface.Name). Str("mac", iface.HardwareAddr.String()). Msg(" interface") + + gatewayMACStr := gatewayMAC.String() + if gatewayMACStr == "" { + gatewayMACStr = "none" + } logger.Info(). - Str("mac", gatewayMAC.String()). + Str("mac", gatewayMACStr). Msg(" gateway (passive detection)") hopCache := cache.NewLRUCache(4096) - sniffer := packet.NewTCPSniffer( + + // TCP Objects + tcpSniffer := packet.NewTCPSniffer( logging.WithScope(logger, "pkt"), hopCache, - handle, + tcpHandle, uint8(*cfg.Server.DefaultTTL), ) - sniffer.StartCapturing() + tcpSniffer.StartCapturing() - writer := packet.NewTCPWriter( + tcpWriter := packet.NewTCPWriter( logging.WithScope(logger, "pkt"), - handle, + tcpHandle, iface, gatewayMAC, ) - return sniffer, writer, nil + // UDP Objects + udpSniffer := packet.NewUDPSniffer( + logging.WithScope(logger, "pkt"), + hopCache, + udpHandle, + uint8(*cfg.Server.DefaultTTL), + ) + udpSniffer.StartCapturing() + + udpWriter := packet.NewUDPWriter( + logging.WithScope(logger, "pkt"), + udpHandle, + iface, + gatewayMAC, + ) + + return tcpSniffer, tcpWriter, udpSniffer, udpWriter, nil } -func createProxy( +func createServer( logger zerolog.Logger, cfg *config.Config, resolver dns.Resolver, -) (proxy.ProxyServer, error) { +) (server.Server, error) { ruleMatcher := matcher.NewRuleMatcher( matcher.NewAddrMatcher(), matcher.NewDomainMatcher(), @@ -245,56 +285,104 @@ func createProxy( } } - // create an HTTP handler. - httpHandler := http.NewHTTPHandler(logging.WithScope(logger, "hnd")) - - var sniffer packet.Sniffer - var writer packet.Writer + var tcpSniffer packet.Sniffer + var tcpWriter packet.Writer + var udpSniffer packet.Sniffer + var udpWriter packet.Writer if cfg.ShouldEnablePcap() { var err error - sniffer, writer, err = createPacketObjects(logger, cfg) + tcpSniffer, tcpWriter, udpSniffer, udpWriter, err = createPacketObjects( + logger, + cfg, + ) if err != nil { return nil, err } } - bridge := tlsutil.NewTLSBridge( - logging.WithScope(logger, "brg"), - desync.NewTLSDesyncer( - writer, - sniffer, - &desync.TLSDesyncerAttrs{DefaultTTL: *cfg.Server.DefaultTTL}, - ), - sniffer, - cfg.HTTPS.Clone(), + desyncer := desync.NewTLSDesyncer( + tcpWriter, + tcpSniffer, ) - httpsHandler := http.NewHTTPSHandler( - logging.WithScope(logger, "hnd"), - bridge, - ) + switch *cfg.Server.Mode { + case config.ServerModeHTTP: + httpHandler := http.NewHTTPHandler(logging.WithScope(logger, "hnd")) + httpsHandler := http.NewHTTPSHandler( + logging.WithScope(logger, "hnd"), + desyncer, + tcpSniffer, + cfg.HTTPS.Clone(), + ) + + return http.NewHTTPProxy( + logging.WithScope(logger, "srv"), + resolver, + httpHandler, + httpsHandler, + ruleMatcher, + cfg.Server.Clone(), + cfg.Policy.Clone(), + ), nil + case config.ServerModeSOCKS5: + connectHandler := socks5.NewConnectHandler( + logging.WithScope(logger, "hnd"), + desyncer, + tcpSniffer, + cfg.Server.Clone(), + cfg.HTTPS.Clone(), + ) + udpAssociateHandler := socks5.NewUdpAssociateHandler( + logging.WithScope(logger, "hnd"), + netutil.NewConnPool(4096, 60*time.Second), + ) + bindHandler := socks5.NewBindHandler(logging.WithScope(logger, "hnd")) + + return socks5.NewSOCKS5Proxy( + logging.WithScope(logger, "srv"), + resolver, + ruleMatcher, + connectHandler, + bindHandler, + udpAssociateHandler, + cfg.Server.Clone(), + cfg.Policy.Clone(), + ), nil + case config.ServerModeTUN: + tcpHandler := tun.NewTCPHandler( + logging.WithScope(logger, "hnd"), + ruleMatcher, // For domain-based TLS matching + cfg.HTTPS.Clone(), + desyncer, + tcpSniffer, // For TTL tracking + "", // iface and gateway will be set later + "", + ) - // if cfg.Server.EnableSocks5 != nil && *cfg.Server.EnableSocks5 { - // return socks5.NewSocks5Proxy( - // logging.WithScope(logger, "pxy"), - // resolver, - // httpsHandler, - // ruleMatcher, - // cfg.Server.Clone(), - // cfg.Policy.Clone(), - // ), nil - // } - - return http.NewHTTPProxy( - logging.WithScope(logger, "pxy"), - resolver, - httpHandler, - httpsHandler, - ruleMatcher, - cfg.Server.Clone(), - cfg.Policy.Clone(), - ), nil + udpDesyncer := desync.NewUDPDesyncer( + logging.WithScope(logger, "hnd"), + udpWriter, + udpSniffer, + ) + + udpHandler := tun.NewUDPHandler( + logging.WithScope(logger, "hnd"), + udpDesyncer, + cfg.UDP.Clone(), + netutil.NewConnPool(4096, 60*time.Second), + ) + + return tun.NewTunServer( + logging.WithScope(logger, "srv"), + cfg, + ruleMatcher, // For IP-based matching in server.go + tcpHandler, + udpHandler, + ), nil + default: + return nil, fmt.Errorf("unknown server mode: %s", *cfg.Server.Mode) + } } func printBanner() { diff --git a/cmd/spoofdpi/main_test.go b/cmd/spoofdpi/main_test.go index 7076c373..22101337 100644 --- a/cmd/spoofdpi/main_test.go +++ b/cmd/spoofdpi/main_test.go @@ -6,21 +6,21 @@ import ( "time" "github.com/rs/zerolog" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" ) func TestCreateResolver(t *testing.T) { cfg := config.NewConfig() cfg.DNS = &config.DNSOptions{ - Mode: ptr.FromValue(config.DNSModeUDP), + Mode: lo.ToPtr(config.DNSModeUDP), Addr: &net.TCPAddr{IP: net.ParseIP("8.8.8.8"), Port: 53}, - HTTPSURL: ptr.FromValue("https://dns.google/dns-query"), - QType: ptr.FromValue(config.DNSQueryIPv4), - Cache: ptr.FromValue(true), + HTTPSURL: lo.ToPtr("https://dns.google/dns-query"), + QType: lo.ToPtr(config.DNSQueryIPv4), + Cache: lo.ToPtr(true), } logger := zerolog.Nop() @@ -35,39 +35,40 @@ func TestCreateProxy_NoPcap(t *testing.T) { // Server Config cfg.Server = &config.ServerOptions{ + Mode: lo.ToPtr(config.ServerModeHTTP), ListenAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, - DefaultTTL: ptr.FromValue(uint8(64)), - Timeout: ptr.FromValue(time.Duration(1 * time.Second)), + DefaultTTL: lo.ToPtr(uint8(64)), + Timeout: lo.ToPtr(time.Duration(0)), } // HTTPS Config (Ensure FakeCount is 0 to disable PCAP) cfg.HTTPS = &config.HTTPSOptions{ - Disorder: ptr.FromValue(false), - FakeCount: ptr.FromValue(uint8(0)), + Disorder: lo.ToPtr(false), + FakeCount: lo.ToPtr(uint8(0)), FakePacket: proto.NewFakeTLSMessage([]byte{}), - SplitMode: ptr.FromValue(config.HTTPSSplitModeChunk), - ChunkSize: ptr.FromValue(uint8(10)), - Skip: ptr.FromValue(false), + SplitMode: lo.ToPtr(config.HTTPSSplitModeChunk), + ChunkSize: lo.ToPtr(uint8(10)), + Skip: lo.ToPtr(false), } // Policy Config cfg.Policy = &config.PolicyOptions{ - Auto: ptr.FromValue(false), + Auto: lo.ToPtr(false), } // DNS Config cfg.DNS = &config.DNSOptions{ - Mode: ptr.FromValue(config.DNSModeUDP), + Mode: lo.ToPtr(config.DNSModeUDP), Addr: &net.TCPAddr{IP: net.ParseIP("8.8.8.8"), Port: 53}, - HTTPSURL: ptr.FromValue("https://dns.google/dns-query"), - QType: ptr.FromValue(config.DNSQueryIPv4), - Cache: ptr.FromValue(false), + HTTPSURL: lo.ToPtr("https://dns.google/dns-query"), + QType: lo.ToPtr(config.DNSQueryIPv4), + Cache: lo.ToPtr(false), } logger := zerolog.Nop() resolver := createResolver(logger, cfg) - p, err := createProxy(logger, cfg, resolver) + p, err := createServer(logger, cfg, resolver) require.NoError(t, err) assert.NotNil(t, p) } @@ -77,30 +78,31 @@ func TestCreateProxy_WithPolicy(t *testing.T) { // Server Config cfg.Server = &config.ServerOptions{ + Mode: lo.ToPtr(config.ServerModeHTTP), ListenAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, - DefaultTTL: ptr.FromValue(uint8(64)), - Timeout: ptr.FromValue(time.Duration(0)), + DefaultTTL: lo.ToPtr(uint8(64)), + Timeout: lo.ToPtr(time.Duration(0)), } // HTTPS Config cfg.HTTPS = &config.HTTPSOptions{ - FakeCount: ptr.FromValue(uint8(0)), + FakeCount: lo.ToPtr(uint8(0)), } // Policy Config with one override cfg.Policy = &config.PolicyOptions{ - Auto: ptr.FromValue(false), + Auto: lo.ToPtr(false), Overrides: []config.Rule{ { - Name: ptr.FromValue("test-rule"), + Name: lo.ToPtr("test-rule"), Match: &config.MatchAttrs{ Domains: []string{"example.com"}, }, DNS: &config.DNSOptions{ - Mode: ptr.FromValue(config.DNSModeSystem), + Mode: lo.ToPtr(config.DNSModeSystem), }, HTTPS: &config.HTTPSOptions{ - Skip: ptr.FromValue(true), + Skip: lo.ToPtr(true), }, }, }, @@ -108,17 +110,17 @@ func TestCreateProxy_WithPolicy(t *testing.T) { // DNS Config cfg.DNS = &config.DNSOptions{ - Mode: ptr.FromValue(config.DNSModeUDP), + Mode: lo.ToPtr(config.DNSModeUDP), Addr: &net.TCPAddr{IP: net.ParseIP("8.8.8.8"), Port: 53}, - HTTPSURL: ptr.FromValue("https://dns.google/dns-query"), - QType: ptr.FromValue(config.DNSQueryIPv4), - Cache: ptr.FromValue(false), + HTTPSURL: lo.ToPtr("https://dns.google/dns-query"), + QType: lo.ToPtr(config.DNSQueryIPv4), + Cache: lo.ToPtr(false), } logger := zerolog.Nop() resolver := createResolver(logger, cfg) - p, err := createProxy(logger, cfg, resolver) + p, err := createServer(logger, cfg, resolver) require.NoError(t, err) assert.NotNil(t, p) } diff --git a/go.mod b/go.mod index c6243528..f6cabd6d 100644 --- a/go.mod +++ b/go.mod @@ -9,21 +9,27 @@ require ( github.com/google/gopacket v1.1.19 github.com/miekg/dns v1.1.61 github.com/rs/zerolog v1.33.0 + github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stretchr/testify v1.11.1 github.com/urfave/cli/v3 v3.6.1 + golang.org/x/net v0.44.0 golang.org/x/sys v0.36.0 + gvisor.dev/gvisor v0.0.0-20251220000015-517913d17844 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/btree v1.1.2 // indirect github.com/kr/pretty v0.1.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/mod v0.27.0 // indirect - golang.org/x/net v0.43.0 // indirect + github.com/samber/lo v1.52.0 // indirect + golang.org/x/mod v0.28.0 // indirect golang.org/x/sync v0.17.0 // indirect - golang.org/x/tools v0.36.0 // indirect + golang.org/x/text v0.29.0 // indirect + golang.org/x/time v0.12.0 // indirect + golang.org/x/tools v0.37.0 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 1c71dc55..002d7467 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,10 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= +github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= @@ -26,6 +28,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw= +github.com/samber/lo v1.52.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0= +github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= +github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/urfave/cli/v3 v3.6.1 h1:j8Qq8NyUawj/7rTYdBGrxcH7A/j7/G8Q5LhWEW4G3Mo= @@ -34,12 +40,12 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= -golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= +golang.org/x/mod v0.28.0 h1:gQBtGhjxykdjY9YhZpSlZIsbnaE2+PgjfLWUQTnoZ1U= +golang.org/x/mod v0.28.0/go.mod h1:yfB/L0NOf/kmEbXjzCPOx1iK1fRutOydrCMsqRhEBxI= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= -golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I= +golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= @@ -51,12 +57,18 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= +golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= -golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= +golang.org/x/tools v0.37.0 h1:DVSRzp7FwePZW356yEAChSdNcQo6Nsp+fex1SUW09lE= +golang.org/x/tools v0.37.0/go.mod h1:MBN5QPQtLMHVdvsbtarmTNukZDdgwdwlO5qGacAzF0w= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gvisor.dev/gvisor v0.0.0-20251220000015-517913d17844 h1:7SkRScnij3eBOP12JnH9KnIfAcpPFMWy9KxNHjOXGTM= +gvisor.dev/gvisor v0.0.0-20251220000015-517913d17844/go.mod h1:W1ZgZ/Dh85TgSZWH67l2jKVpDE5bjIaut7rjwwOiHzQ= diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 965e204f..9ae875c3 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -38,4 +38,8 @@ type Cache interface { Get(key string) (any, bool) // Set adds a value to the cache, applying any provided options. Set(key string, value any, opts *options) bool + Delete(key string) + Range(f func(key string, value any) bool) + // Size returns the number of items in the cache. + Size() int } diff --git a/internal/cache/lru_cache.go b/internal/cache/lru_cache.go index b58cfa95..5db51742 100644 --- a/internal/cache/lru_cache.go +++ b/internal/cache/lru_cache.go @@ -124,3 +124,36 @@ func (c *LRUCache) Set(key string, value any, opts *options) bool { return true } + +// Range iterates over the cache items. +// If f returns false, the item is removed from the cache. +func (c *LRUCache) Range(f func(key string, value any) bool) { + c.mu.Lock() + defer c.mu.Unlock() + + var next *list.Element + for e := c.list.Front(); e != nil; e = next { + next = e.Next() + entry := e.Value.(*lruEntry) + if !f(entry.key, entry.value) { + c.removeElement(e) + } + } +} + +// Delete removes an item from the cache. +func (c *LRUCache) Delete(key string) { + c.mu.Lock() + defer c.mu.Unlock() + + if element, ok := c.cache[key]; ok { + c.removeElement(element) + } +} + +// Size returns the number of items in the cache. +func (c *LRUCache) Size() int { + c.mu.RLock() + defer c.mu.RUnlock() + return c.list.Len() +} diff --git a/internal/cache/ttl_cache.go b/internal/cache/ttl_cache.go index 70d7b10c..87eaa6e1 100644 --- a/internal/cache/ttl_cache.go +++ b/internal/cache/ttl_cache.go @@ -179,3 +179,28 @@ func (c *TTLCache) ForceCleanup() { shard.mu.Unlock() } } + +// Range iterates over the cache items. +// If f returns false, the item is removed from the cache. +func (c *TTLCache) Range(f func(key string, value any) bool) { + for _, shard := range c.shards { + shard.mu.Lock() + for key, i := range shard.items { + if !f(key, i.value) { + delete(shard.items, key) + } + } + shard.mu.Unlock() + } +} + +// Size returns the total number of items across all shards. +func (c *TTLCache) Size() int { + total := 0 + for _, shard := range c.shards { + shard.mu.RLock() + total += len(shard.items) + shard.mu.RUnlock() + } + return total +} diff --git a/internal/config/cli.go b/internal/config/cli.go index c776e513..08197aad 100644 --- a/internal/config/cli.go +++ b/internal/config/cli.go @@ -5,14 +5,15 @@ import ( "fmt" "io" "math" + "net" "os" "path" "strings" "time" + "github.com/samber/lo" "github.com/urfave/cli/v3" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" ) func CreateCommand( @@ -38,7 +39,7 @@ func CreateCommand( &cli.BoolFlag{ Name: "clean", Usage: ` - if set, all configuration files will be ignored (default: %v)`, + if set, all configuration files will be ignored (default: false)`, OnlyOnce: true, }, @@ -61,7 +62,21 @@ func CreateCommand( OnlyOnce: true, Validator: checkUint8NonZero, Action: func(ctx context.Context, cmd *cli.Command, v int64) error { - argsCfg.Server.DefaultTTL = ptr.FromValue(uint8(v)) + argsCfg.Server.DefaultTTL = lo.ToPtr(uint8(v)) + return nil + }, + }, + + &cli.StringFlag{ + Name: "server-mode", + Usage: fmt.Sprintf(`<"http"|"socks5"|"tun"> + Specifies the proxy mode. (default: %q)`, + defaultCfg.Server.Mode.String(), + ), + OnlyOnce: true, + Validator: checkServerMode, + Action: func(ctx context.Context, cmd *cli.Command, v string) error { + argsCfg.Server.Mode = lo.ToPtr(MustParseServerModeType(v)) return nil }, }, @@ -75,7 +90,7 @@ func CreateCommand( OnlyOnce: true, Validator: checkHostPort, Action: func(ctx context.Context, cmd *cli.Command, v string) error { - argsCfg.DNS.Addr = ptr.FromValue(MustParseTCPAddr(v)) + argsCfg.DNS.Addr = lo.ToPtr(MustParseTCPAddr(v)) return nil }, }, @@ -84,12 +99,12 @@ func CreateCommand( Name: "dns-cache", Usage: fmt.Sprintf(` If set, DNS records will be cached. (default: %v)`, - defaultCfg.DNS.Cache, + *defaultCfg.DNS.Cache, ), Value: false, OnlyOnce: true, Action: func(ctx context.Context, cmd *cli.Command, v bool) error { - argsCfg.DNS.Cache = ptr.FromValue(v) + argsCfg.DNS.Cache = lo.ToPtr(v) return nil }, }, @@ -105,7 +120,7 @@ func CreateCommand( OnlyOnce: true, Validator: checkDNSMode, Action: func(ctx context.Context, cmd *cli.Command, v string) error { - argsCfg.DNS.Mode = ptr.FromValue(MustParseDNSModeType(v)) + argsCfg.DNS.Mode = lo.ToPtr(MustParseDNSModeType(v)) return nil }, }, @@ -121,7 +136,7 @@ func CreateCommand( OnlyOnce: true, Validator: checkHTTPSEndpoint, Action: func(ctx context.Context, cmd *cli.Command, v string) error { - argsCfg.DNS.HTTPSURL = ptr.FromValue(v) + argsCfg.DNS.HTTPSURL = lo.ToPtr(v) return nil }, }, @@ -137,7 +152,7 @@ func CreateCommand( OnlyOnce: true, Validator: checkDNSQueryType, Action: func(ctx context.Context, cmd *cli.Command, v string) error { - argsCfg.DNS.QType = ptr.FromValue(MustParseDNSQueryType(v)) + argsCfg.DNS.QType = lo.ToPtr(MustParseDNSQueryType(v)) return nil }, }, @@ -147,13 +162,13 @@ func CreateCommand( Usage: fmt.Sprintf(` Number of fake packets to be sent before the Client Hello. Requires 'https-chunk-size' > 0 for fragmentation. (default: %v)`, - defaultCfg.HTTPS.FakeCount, + *defaultCfg.HTTPS.FakeCount, ), Value: 0, OnlyOnce: true, Validator: checkUint8, Action: func(ctx context.Context, cmd *cli.Command, v int64) error { - argsCfg.HTTPS.FakeCount = ptr.FromValue(uint8(v)) + argsCfg.HTTPS.FakeCount = lo.ToPtr(uint8(v)) return nil }, }, @@ -176,18 +191,18 @@ func CreateCommand( Name: "https-disorder", Usage: fmt.Sprintf(` If set, sends fragmented Client Hello packets out-of-order. (default: %v)`, - defaultCfg.HTTPS.Disorder, + *defaultCfg.HTTPS.Disorder, ), OnlyOnce: true, Action: func(ctx context.Context, cmd *cli.Command, v bool) error { - argsCfg.HTTPS.Disorder = ptr.FromValue(v) + argsCfg.HTTPS.Disorder = lo.ToPtr(v) return nil }, }, &cli.StringFlag{ Name: "https-split-mode", - Usage: fmt.Sprintf(`<"sni"|"random"|"chunk"|"sni"|"none"> + Usage: fmt.Sprintf(`<"sni"|"random"|"chunk"|"sni"|"custom"|"none"> Specifies the default packet fragmentation strategy to use. (default: %q)`, defaultCfg.HTTPS.SplitMode.String(), ), @@ -195,7 +210,7 @@ func CreateCommand( OnlyOnce: true, Validator: checkHTTPSSplitMode, Action: func(ctx context.Context, cmd *cli.Command, v string) error { - argsCfg.HTTPS.SplitMode = ptr.FromValue(mustParseHTTPSSplitModeType(v)) + argsCfg.HTTPS.SplitMode = lo.ToPtr(mustParseHTTPSSplitModeType(v)) return nil }, }, @@ -205,11 +220,11 @@ func CreateCommand( Usage: fmt.Sprintf(` If set, HTTPS traffic will be processed without any DPI bypass techniques. (default: %v)`, - defaultCfg.HTTPS.Skip, + *defaultCfg.HTTPS.Skip, ), OnlyOnce: true, Action: func(ctx context.Context, cmd *cli.Command, v bool) error { - argsCfg.HTTPS.Skip = ptr.FromValue(v) + argsCfg.HTTPS.Skip = lo.ToPtr(v) return nil }, }, @@ -222,29 +237,72 @@ func CreateCommand( disables fragmentation (to avoid division-by-zero errors), you should set 'https-split-default' to 'none' to disable the feature cleanly. (default: %v, max: %v)`, - defaultCfg.HTTPS.ChunkSize, + *defaultCfg.HTTPS.ChunkSize, math.MaxUint8, ), - Value: 0, OnlyOnce: true, Validator: checkUint8NonZero, Action: func(ctx context.Context, cmd *cli.Command, v int64) error { - argsCfg.HTTPS.ChunkSize = ptr.FromValue(uint8(v)) + argsCfg.HTTPS.ChunkSize = lo.ToPtr(uint8(v)) + return nil + }, + }, + + &cli.Int64Flag{ + Name: "udp-fake-count", + Usage: fmt.Sprintf(` + Number of fake packets to be sent. (default: %v)`, + *defaultCfg.UDP.FakeCount, + ), + Value: 0, + OnlyOnce: true, + Validator: int64Range(0, math.MaxInt), + Action: func(ctx context.Context, cmd *cli.Command, v int64) error { + argsCfg.UDP.FakeCount = lo.ToPtr(int(v)) return nil }, }, &cli.StringFlag{ - Name: "listen-addr", + Name: "udp-fake-packet", + Usage: ` + Comma-separated hexadecimal byte array used for fake packet. + (default: built-in fake packet)`, + Value: MustParseHexCSV(defaultCfg.UDP.FakePacket), + OnlyOnce: true, + Validator: checkHexBytesStr, + Action: func(ctx context.Context, cmd *cli.Command, v string) error { + argsCfg.UDP.FakePacket = MustParseBytes(v) + return nil + }, + }, + + &cli.Int64Flag{ + Name: "udp-timeout", Usage: fmt.Sprintf(` - IP address to listen on (default: %v)`, - defaultCfg.Server.ListenAddr.String(), + UDP session timeout in milliseconds. (default: %v)`, + *defaultCfg.UDP.Timeout, ), - Value: "127.0.0.1:8080", + Value: 0, + OnlyOnce: true, + Validator: checkUint16, + Action: func(ctx context.Context, cmd *cli.Command, v int64) error { + argsCfg.UDP.Timeout = lo.ToPtr(time.Duration(v) * time.Millisecond) + return nil + }, + }, + + &cli.StringFlag{ + Name: "listen-addr", + Usage: ` + IP address to listen on (default: 127.0.0.1:8080 for http, or 127.0.0.1:1080 for socks5)`, OnlyOnce: true, Validator: checkHostPort, Action: func(ctx context.Context, cmd *cli.Command, v string) error { - argsCfg.Server.ListenAddr = ptr.FromValue(MustParseTCPAddr(v)) + if v == "" { + return nil + } + argsCfg.Server.ListenAddr = lo.ToPtr(MustParseTCPAddr(v)) return nil }, }, @@ -256,7 +314,7 @@ func CreateCommand( OnlyOnce: true, Validator: checkLogLevel, Action: func(ctx context.Context, cmd *cli.Command, v string) error { - argsCfg.General.LogLevel = ptr.FromValue(MustParseLogLevel(v)) + argsCfg.General.LogLevel = lo.ToPtr(MustParseLogLevel(v)) return nil }, }, @@ -269,7 +327,7 @@ func CreateCommand( ), OnlyOnce: true, Action: func(ctx context.Context, cmd *cli.Command, v bool) error { - argsCfg.Policy.Auto = ptr.FromValue(v) + argsCfg.Policy.Auto = lo.ToPtr(v) return nil }, }, @@ -278,24 +336,24 @@ func CreateCommand( Name: "silent", Usage: fmt.Sprintf(` Do not show the banner at start up (default: %v)`, - defaultCfg.General.Silent, + *defaultCfg.General.Silent, ), OnlyOnce: true, Action: func(ctx context.Context, cmd *cli.Command, v bool) error { - argsCfg.General.Silent = ptr.FromValue(v) + argsCfg.General.Silent = lo.ToPtr(v) return nil }, }, &cli.BoolFlag{ - Name: "system-proxy", + Name: "network-config", Usage: fmt.Sprintf(` Automatically set system-wide proxy configuration (default: %v)`, - defaultCfg.General.SetSystemProxy, + *defaultCfg.General.SetNetworkConfig, ), OnlyOnce: true, Action: func(ctx context.Context, cmd *cli.Command, v bool) error { - argsCfg.General.SetSystemProxy = ptr.FromValue(v) + argsCfg.General.SetNetworkConfig = lo.ToPtr(v) return nil }, }, @@ -305,14 +363,14 @@ func CreateCommand( Usage: fmt.Sprintf(` Timeout for tcp connection in milliseconds. No effect when the value is 0 (default: %v, max: %v)`, - defaultCfg.Server.Timeout, + *defaultCfg.Server.Timeout, math.MaxUint16, ), Value: 0, OnlyOnce: true, Validator: checkUint16, Action: func(ctx context.Context, cmd *cli.Command, v int64) error { - argsCfg.Server.Timeout = ptr.FromValue( + argsCfg.Server.Timeout = lo.ToPtr( time.Duration(v * int64(time.Millisecond)), ) return nil @@ -361,6 +419,17 @@ func CreateCommand( finalCfg := defaultCfg.Merge(tomlCfg.Merge(argsCfg)) + if finalCfg.Server.ListenAddr == nil { + port := 8080 + if *finalCfg.Server.Mode == ServerModeSOCKS5 { + port = 1080 + } + finalCfg.Server.ListenAddr = &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: port, + } + } + runFunc(ctx, strings.Replace(configDir, os.Getenv("HOME"), "~", 1), finalCfg) return nil }, diff --git a/internal/config/cli_test.go b/internal/config/cli_test.go index 9f0ee18c..70ced9eb 100644 --- a/internal/config/cli_test.go +++ b/internal/config/cli_test.go @@ -26,7 +26,7 @@ func TestCreateCommand_Flags(t *testing.T) { // Verify defaults are preserved assert.Equal(t, zerolog.InfoLevel, *cfg.General.LogLevel) assert.False(t, *cfg.General.Silent) - assert.False(t, *cfg.General.SetSystemProxy) + assert.False(t, *cfg.General.SetNetworkConfig) assert.Equal(t, "127.0.0.1:8080", cfg.Server.ListenAddr.String()) assert.Equal(t, uint8(64), *cfg.Server.DefaultTTL) assert.Equal(t, time.Duration(0), *cfg.Server.Timeout) @@ -38,8 +38,10 @@ func TestCreateCommand_Flags(t *testing.T) { assert.Equal(t, uint8(0), *cfg.HTTPS.FakeCount) assert.False(t, *cfg.HTTPS.Disorder) assert.Equal(t, HTTPSSplitModeSNI, *cfg.HTTPS.SplitMode) - assert.Equal(t, uint8(0), *cfg.HTTPS.ChunkSize) + assert.Equal(t, uint8(35), *cfg.HTTPS.ChunkSize) assert.False(t, *cfg.HTTPS.Skip) + assert.Equal(t, 0, *cfg.UDP.FakeCount) + assert.Equal(t, 64, len(cfg.UDP.FakePacket)) assert.False(t, *cfg.Policy.Auto) }, }, @@ -50,7 +52,7 @@ func TestCreateCommand_Flags(t *testing.T) { "--clean", // Ensure no config file interferes "--log-level", "debug", "--silent", - "--system-proxy", + "--network-config", "--listen-addr", "127.0.0.1:9090", "--default-ttl", "128", "--timeout", "5000", @@ -65,13 +67,16 @@ func TestCreateCommand_Flags(t *testing.T) { "--https-split-mode", "chunk", "--https-chunk-size", "50", "--https-skip", + "--udp-fake-count", "5", + "--udp-fake-packet", "0x01, 0x02", + "--udp-timeout", "1000", "--policy-auto", }, assert: func(t *testing.T, cfg *Config) { // General assert.Equal(t, zerolog.DebugLevel, *cfg.General.LogLevel) assert.True(t, *cfg.General.Silent) - assert.True(t, *cfg.General.SetSystemProxy) + assert.True(t, *cfg.General.SetNetworkConfig) // Server assert.Equal(t, "127.0.0.1:9090", cfg.Server.ListenAddr.String()) @@ -93,6 +98,11 @@ func TestCreateCommand_Flags(t *testing.T) { assert.Equal(t, uint8(50), *cfg.HTTPS.ChunkSize) assert.True(t, *cfg.HTTPS.Skip) + // UDP + assert.Equal(t, 5, *cfg.UDP.FakeCount) + assert.Equal(t, []byte{0x01, 0x02}, cfg.UDP.FakePacket) + assert.Equal(t, 1000*time.Millisecond, *cfg.UDP.Timeout) + // Policy assert.True(t, *cfg.Policy.Auto) }, @@ -127,6 +137,18 @@ func TestCreateCommand_Flags(t *testing.T) { assert.True(t, cfg.Server.ListenAddr.IP.Equal(ip)) }, }, + { + name: "socks5 default port", + args: []string{ + "spoofdpi", + "--clean", + "--server-mode", "socks5", + }, + assert: func(t *testing.T, cfg *Config) { + assert.Equal(t, "127.0.0.1:1080", cfg.Server.ListenAddr.String()) + assert.Equal(t, ServerModeSOCKS5, *cfg.Server.Mode) + }, + }, } for _, tc := range tcs { @@ -224,7 +246,7 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { "--config", configPath, "--log-level", "error", "--silent=false", - "--system-proxy=false", + "--network-config=false", "--listen-addr", "127.0.0.1:9090", "--timeout", "2000", "--default-ttl", "200", @@ -239,6 +261,8 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { "--https-split-mode", "sni", "--https-chunk-size", "10", "--https-skip=false", + "--udp-fake-count", "20", + "--udp-fake-packet", "0xcc,0xdd", "--policy-auto=false", } @@ -250,7 +274,7 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { // General assert.Equal(t, zerolog.ErrorLevel, *capturedCfg.General.LogLevel) assert.False(t, *capturedCfg.General.Silent) - assert.False(t, *capturedCfg.General.SetSystemProxy) + assert.False(t, *capturedCfg.General.SetNetworkConfig) // Server assert.Equal(t, "127.0.0.1:9090", capturedCfg.Server.ListenAddr.String()) @@ -272,6 +296,11 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { assert.Equal(t, uint8(10), *capturedCfg.HTTPS.ChunkSize) assert.False(t, *capturedCfg.HTTPS.Skip) + // UDP + assert.Equal(t, 20, *capturedCfg.UDP.FakeCount) + assert.Equal(t, []byte{0xcc, 0xdd}, capturedCfg.UDP.FakePacket) + assert.Equal(t, time.Duration(0), *capturedCfg.UDP.Timeout) + // Policy assert.False(t, *capturedCfg.Policy.Auto) diff --git a/internal/config/config.go b/internal/config/config.go index 757a2877..2f58acb0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,8 +6,8 @@ import ( "time" "github.com/rs/zerolog" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" ) type merger[T any] interface { @@ -26,6 +26,7 @@ type Config struct { Server *ServerOptions `toml:"server"` DNS *DNSOptions `toml:"dns"` HTTPS *HTTPSOptions `toml:"https"` + UDP *UDPOptions `toml:"udp"` Policy *PolicyOptions `toml:"policy"` } @@ -39,6 +40,7 @@ func (c *Config) UnmarshalTOML(data any) (err error) { c.Server = findStructFrom[ServerOptions](m, "server", &err) c.DNS = findStructFrom[DNSOptions](m, "dns", &err) c.HTTPS = findStructFrom[HTTPSOptions](m, "https", &err) + c.UDP = findStructFrom[UDPOptions](m, "udp", &err) c.Policy = findStructFrom[PolicyOptions](m, "policy", &err) return @@ -50,6 +52,7 @@ func NewConfig() *Config { Server: &ServerOptions{}, DNS: &DNSOptions{}, HTTPS: &HTTPSOptions{}, + UDP: &UDPOptions{}, Policy: &PolicyOptions{}, } } @@ -64,6 +67,7 @@ func (c *Config) Clone() *Config { Server: c.Server.Clone(), DNS: c.DNS.Clone(), HTTPS: c.HTTPS.Clone(), + UDP: c.UDP.Clone(), Policy: c.Policy.Clone(), } } @@ -82,6 +86,7 @@ func (origin *Config) Merge(overrides *Config) *Config { Server: origin.Server.Merge(overrides.Server), DNS: origin.DNS.Merge(overrides.DNS), HTTPS: origin.HTTPS.Merge(overrides.HTTPS), + UDP: origin.UDP.Merge(overrides.UDP), Policy: origin.Policy.Merge(overrides.Policy), } } @@ -91,13 +96,20 @@ func (c *Config) ShouldEnablePcap() bool { return true } + if c.UDP != nil && c.UDP.FakeCount != nil && *c.UDP.FakeCount > 0 { + return true + } + if c.Policy == nil { return false } if c.Policy.Template != nil { template := c.Policy.Template - if template.HTTPS != nil && ptr.FromPtr(template.HTTPS.FakeCount) > 0 { + if template.HTTPS != nil && lo.FromPtr(template.HTTPS.FakeCount) > 0 { + return true + } + if template.UDP != nil && lo.FromPtr(template.UDP.FakeCount) > 0 { return true } } @@ -105,15 +117,11 @@ func (c *Config) ShouldEnablePcap() bool { if c.Policy.Overrides != nil { rules := c.Policy.Overrides for _, r := range rules { - if r.HTTPS == nil { - continue - } - - if r.HTTPS.FakeCount == nil { - continue + if r.HTTPS != nil && r.HTTPS.FakeCount != nil && *r.HTTPS.FakeCount > 0 { + return true } - if *r.HTTPS.FakeCount > 0 { + if r.UDP != nil && r.UDP.FakeCount != nil && *r.UDP.FakeCount > 0 { return true } } @@ -125,32 +133,39 @@ func (c *Config) ShouldEnablePcap() bool { func getDefault() *Config { //exhaustruct:enforce return &Config{ General: &GeneralOptions{ - LogLevel: ptr.FromValue(zerolog.InfoLevel), - Silent: ptr.FromValue(false), - SetSystemProxy: ptr.FromValue(false), + LogLevel: lo.ToPtr(zerolog.InfoLevel), + Silent: lo.ToPtr(false), + SetNetworkConfig: lo.ToPtr(false), }, Server: &ServerOptions{ - DefaultTTL: ptr.FromValue(uint8(64)), - ListenAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080, Zone: ""}, - Timeout: ptr.FromValue(time.Duration(0)), + Mode: lo.ToPtr(ServerModeHTTP), + DefaultTTL: lo.ToPtr(uint8(64)), + ListenAddr: nil, + Timeout: lo.ToPtr(time.Duration(0)), }, DNS: &DNSOptions{ - Mode: ptr.FromValue(DNSModeUDP), + Mode: lo.ToPtr(DNSModeUDP), Addr: &net.TCPAddr{IP: net.ParseIP("8.8.8.8"), Port: 53, Zone: ""}, - HTTPSURL: ptr.FromValue("https://dns.google/dns-query"), - QType: ptr.FromValue(DNSQueryIPv4), - Cache: ptr.FromValue(false), + HTTPSURL: lo.ToPtr("https://dns.google/dns-query"), + QType: lo.ToPtr(DNSQueryIPv4), + Cache: lo.ToPtr(false), }, HTTPS: &HTTPSOptions{ - Disorder: ptr.FromValue(false), - FakeCount: ptr.FromValue(uint8(0)), - FakePacket: proto.NewFakeTLSMessage([]byte(FakeClientHello)), - SplitMode: ptr.FromValue(HTTPSSplitModeSNI), - ChunkSize: ptr.FromValue(uint8(0)), - Skip: ptr.FromValue(false), + Disorder: lo.ToPtr(false), + FakeCount: lo.ToPtr(uint8(0)), + FakePacket: proto.NewFakeTLSMessage([]byte(FakeClientHello)), + SplitMode: lo.ToPtr(HTTPSSplitModeSNI), + ChunkSize: lo.ToPtr(uint8(35)), + CustomSegmentPlans: []SegmentPlan{}, + Skip: lo.ToPtr(false), + }, + UDP: &UDPOptions{ + FakeCount: lo.ToPtr(0), + FakePacket: make([]byte, 64), + Timeout: lo.ToPtr(time.Duration(0)), }, Policy: &PolicyOptions{ - Auto: ptr.FromValue(false), + Auto: lo.ToPtr(false), Template: &Rule{}, Overrides: []Rule{}, }, diff --git a/internal/config/config_test.go b/internal/config/config_test.go index c148caaf..dcaca3ef 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -4,8 +4,8 @@ import ( "net" "testing" + "github.com/samber/lo" "github.com/stretchr/testify/assert" - "github.com/xvzc/SpoofDPI/internal/ptr" ) func TestConfig_UnmarshalTOML(t *testing.T) { @@ -89,7 +89,7 @@ func TestConfig_ShouldEnablePcap(t *testing.T) { name: "global fake count > 0", config: Config{ HTTPS: &HTTPSOptions{ - FakeCount: ptr.FromValue(uint8(1)), + FakeCount: lo.ToPtr(uint8(1)), }, }, expect: true, @@ -98,13 +98,13 @@ func TestConfig_ShouldEnablePcap(t *testing.T) { name: "rule fake count > 0", config: Config{ HTTPS: &HTTPSOptions{ - FakeCount: ptr.FromValue(uint8(0)), + FakeCount: lo.ToPtr(uint8(0)), }, Policy: &PolicyOptions{ Overrides: []Rule{ { HTTPS: &HTTPSOptions{ - FakeCount: ptr.FromValue(uint8(1)), + FakeCount: lo.ToPtr(uint8(1)), }, }, }, @@ -116,13 +116,13 @@ func TestConfig_ShouldEnablePcap(t *testing.T) { name: "none", config: Config{ HTTPS: &HTTPSOptions{ - FakeCount: ptr.FromValue(uint8(0)), + FakeCount: lo.ToPtr(uint8(0)), }, Policy: &PolicyOptions{ Overrides: []Rule{ { HTTPS: &HTTPSOptions{ - FakeCount: ptr.FromValue(uint8(0)), + FakeCount: lo.ToPtr(uint8(0)), }, }, }, @@ -162,6 +162,18 @@ func TestConfig_Merge(t *testing.T) { assert.Equal(t, "127.0.0.1:8080", merged.Server.ListenAddr.String()) }, }, + { + name: "default udp fake packet", + tomlCfg: &Config{}, + argsCfg: &Config{}, + assert: func(t *testing.T, merged *Config) { + defaultCfg := getDefault() + assert.Equal(t, 64, len(defaultCfg.UDP.FakePacket)) + for _, b := range defaultCfg.UDP.FakePacket { + assert.Equal(t, byte(0), b) + } + }, + }, } for _, tc := range tcs { diff --git a/internal/config/parse.go b/internal/config/parse.go index 7d2da219..5f3ccd79 100644 --- a/internal/config/parse.go +++ b/internal/config/parse.go @@ -108,6 +108,19 @@ func MustParseLogLevel(s string) zerolog.Level { return level } +func MustParseServerModeType(s string) ServerModeType { + switch s { + case "http": + return ServerModeHTTP + case "socks5": + return ServerModeSOCKS5 + case "tun": + return ServerModeTUN + default: + panic(fmt.Sprintf("cannot parse %q to ServerModeType", s)) + } +} + func MustParseDNSModeType(s string) DNSModeType { switch s { case "udp": @@ -146,11 +159,24 @@ func mustParseHTTPSSplitModeType(s string) HTTPSSplitModeType { return HTTPSSplitModeFirstByte case "none": return HTTPSSplitModeNone + case "custom": + return HTTPSSplitModeCustom default: panic(fmt.Sprintf("cannot parse %q to HTTPSSplitModeType", s)) } } +func mustParseSegmentFromType(s string) SegmentFromType { + switch s { + case "sni": + return SegmentFromSNI + case "head": + return SegmentFromHead + default: + panic(fmt.Sprintf("cannot parse %q to SegmentFromType", s)) + } +} + type Integer interface { ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr diff --git a/internal/config/segment_test.go b/internal/config/segment_test.go new file mode 100644 index 00000000..d912156b --- /dev/null +++ b/internal/config/segment_test.go @@ -0,0 +1 @@ +package config diff --git a/internal/config/toml.go b/internal/config/toml.go index d2c5c498..9ada3adb 100644 --- a/internal/config/toml.go +++ b/internal/config/toml.go @@ -5,7 +5,7 @@ import ( "os" "github.com/BurntSushi/toml" - "github.com/xvzc/SpoofDPI/internal/ptr" + "github.com/samber/lo" ) func fromTomlFile(dir string) (*Config, error) { @@ -65,7 +65,7 @@ func findFrom[T any]( return nil } - return ptr.FromValue(val) + return lo.ToPtr(val) } func findStructFrom[T any, PT interface { diff --git a/internal/config/toml_test.go b/internal/config/toml_test.go index e97b1d84..4e47de25 100644 --- a/internal/config/toml_test.go +++ b/internal/config/toml_test.go @@ -414,11 +414,10 @@ func TestFindSliceFrom(t *testing.T) { func TestFromTomlFile(t *testing.T) { t.Run("full valid config", func(t *testing.T) { tomlContent := ` - [general] - log-level = "debug" - silent = true - system-proxy = true - + [general] + log-level = "debug" + silent = true + network-config = true [server] listen-addr = "127.0.0.1:8080" timeout = 1000 @@ -486,7 +485,7 @@ func TestFromTomlFile(t *testing.T) { assert.Equal(t, time.Duration(1000*time.Millisecond), *cfg.Server.Timeout) assert.Equal(t, zerolog.DebugLevel, *cfg.General.LogLevel) assert.True(t, *cfg.General.Silent) - assert.True(t, *cfg.General.SetSystemProxy) + assert.True(t, *cfg.General.SetNetworkConfig) assert.Equal(t, "8.8.8.8:53", cfg.DNS.Addr.String()) assert.True(t, *cfg.DNS.Cache) assert.Equal(t, DNSModeHTTPS, *cfg.DNS.Mode) diff --git a/internal/config/types.go b/internal/config/types.go index 5b613c20..90dd1772 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -1,27 +1,45 @@ package config import ( + "encoding/json" "fmt" + "math" "net" + "slices" "strings" "time" "github.com/rs/zerolog" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" ) +type primitive interface { + ~bool | ~string | + ~int | ~int8 | ~int16 | ~int32 | ~int64 | + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | + ~float32 | ~float64 | + ~complex64 | ~complex128 +} + +func clonePrimitive[T primitive](x *T) *T { + if x == nil { + return nil + } + return lo.ToPtr(lo.FromPtr(x)) +} + // ┌─────────────────┐ // │ GENERAL OPTIONS │ // └─────────────────┘ var _ merger[*GeneralOptions] = (*GeneralOptions)(nil) -var availableLogLevels = []string{"info", "warn", "trace", "error", "debug"} +var availableLogLevelValues = []string{"info", "warn", "trace", "error", "debug"} type GeneralOptions struct { - LogLevel *zerolog.Level `toml:"log-level"` - Silent *bool `toml:"silent"` - SetSystemProxy *bool `toml:"system-proxy"` + LogLevel *zerolog.Level `toml:"log-level"` + Silent *bool `toml:"silent"` + SetNetworkConfig *bool `toml:"network-config"` } func (o *GeneralOptions) UnmarshalTOML(data any) (err error) { @@ -31,9 +49,9 @@ func (o *GeneralOptions) UnmarshalTOML(data any) (err error) { } o.Silent = findFrom(m, "silent", parseBoolFn(), &err) - o.SetSystemProxy = findFrom(m, "system-proxy", parseBoolFn(), &err) + o.SetNetworkConfig = findFrom(m, "network-config", parseBoolFn(), &err) if p := findFrom(m, "log-level", parseStringFn(checkLogLevel), &err); isOk(p, err) { - o.LogLevel = ptr.FromValue(MustParseLogLevel(*p)) + o.LogLevel = lo.ToPtr(MustParseLogLevel(*p)) } return err @@ -46,13 +64,13 @@ func (o *GeneralOptions) Clone() *GeneralOptions { var newLevel *zerolog.Level if o.LogLevel != nil { - newLevel = ptr.FromValue(MustParseLogLevel(strings.ToLower(o.LogLevel.String()))) + newLevel = lo.ToPtr(MustParseLogLevel(strings.ToLower(o.LogLevel.String()))) } return &GeneralOptions{ - LogLevel: newLevel, - Silent: ptr.Clone(o.Silent), - SetSystemProxy: ptr.Clone(o.SetSystemProxy), + LogLevel: newLevel, + Silent: clonePrimitive(o.Silent), + SetNetworkConfig: clonePrimitive(o.SetNetworkConfig), } } @@ -66,9 +84,12 @@ func (origin *GeneralOptions) Merge(overrides *GeneralOptions) *GeneralOptions { } return &GeneralOptions{ - LogLevel: ptr.CloneOr(overrides.LogLevel, origin.LogLevel), - Silent: ptr.CloneOr(overrides.Silent, origin.Silent), - SetSystemProxy: ptr.CloneOr(overrides.SetSystemProxy, origin.SetSystemProxy), + LogLevel: lo.CoalesceOrEmpty(overrides.LogLevel, origin.LogLevel), + Silent: lo.CoalesceOrEmpty(overrides.Silent, origin.Silent), + SetNetworkConfig: lo.CoalesceOrEmpty( + overrides.SetNetworkConfig, + origin.SetNetworkConfig, + ), } } @@ -77,10 +98,25 @@ func (origin *GeneralOptions) Merge(overrides *GeneralOptions) *GeneralOptions { // └────────────────┘ var _ merger[*ServerOptions] = (*ServerOptions)(nil) +type ServerModeType int + +const ( + ServerModeHTTP ServerModeType = iota + ServerModeSOCKS5 + ServerModeTUN +) + +var availableServerModeValues = []string{"http", "socks5", "tun"} + +func (t ServerModeType) String() string { + return availableServerModeValues[t] +} + type ServerOptions struct { - DefaultTTL *uint8 `toml:"default-ttl"` - ListenAddr *net.TCPAddr `toml:"listen-addr"` - Timeout *time.Duration `toml:"timeout"` + Mode *ServerModeType `toml:"mode"` + DefaultTTL *uint8 `toml:"default-ttl"` + ListenAddr *net.TCPAddr `toml:"listen-addr"` + Timeout *time.Duration `toml:"timeout"` } func (o *ServerOptions) UnmarshalTOML(data any) (err error) { @@ -89,14 +125,18 @@ func (o *ServerOptions) UnmarshalTOML(data any) (err error) { return fmt.Errorf("non-table type server config") } + if p := findFrom(v, "mode", parseStringFn(checkServerMode), &err); isOk(p, err) { + o.Mode = lo.ToPtr(MustParseServerModeType(*p)) + } + o.DefaultTTL = findFrom(v, "default-ttl", parseIntFn[uint8](checkUint8NonZero), &err) if p := findFrom(v, "listen-addr", parseStringFn(checkHostPort), &err); isOk(p, err) { - o.ListenAddr = ptr.FromValue(MustParseTCPAddr(*p)) + o.ListenAddr = lo.ToPtr(MustParseTCPAddr(*p)) } if p := findFrom(v, "timeout", parseIntFn[uint16](checkUint16), &err); isOk(p, err) { - o.Timeout = ptr.FromValue(time.Duration(*p) * time.Millisecond) + o.Timeout = lo.ToPtr(time.Duration(*p) * time.Millisecond) } return err @@ -117,9 +157,10 @@ func (o *ServerOptions) Clone() *ServerOptions { } return &ServerOptions{ - DefaultTTL: ptr.Clone(o.DefaultTTL), + Mode: clonePrimitive(o.Mode), + DefaultTTL: clonePrimitive(o.DefaultTTL), ListenAddr: newAddr, - Timeout: ptr.Clone(o.Timeout), + Timeout: clonePrimitive(o.Timeout), } } @@ -133,9 +174,10 @@ func (origin *ServerOptions) Merge(overrides *ServerOptions) *ServerOptions { } return &ServerOptions{ - DefaultTTL: ptr.CloneOr(overrides.DefaultTTL, origin.DefaultTTL), - ListenAddr: ptr.CloneOr(overrides.ListenAddr, origin.ListenAddr), - Timeout: ptr.CloneOr(overrides.Timeout, origin.Timeout), + Mode: lo.CoalesceOrEmpty(overrides.Mode, origin.Mode), + DefaultTTL: lo.CoalesceOrEmpty(overrides.DefaultTTL, origin.DefaultTTL), + ListenAddr: lo.CoalesceOrEmpty(overrides.ListenAddr, origin.ListenAddr), + Timeout: lo.CoalesceOrEmpty(overrides.Timeout, origin.Timeout), } } @@ -150,8 +192,8 @@ type ( ) var ( - availableDNSModes = []string{"udp", "https", "system"} - availableDNSQueries = []string{"ipv4", "ipv6", "all"} + availableDNSModeValues = []string{"udp", "https", "system"} + availableDNSQueryValues = []string{"ipv4", "ipv6", "all"} ) const ( @@ -167,11 +209,11 @@ const ( ) func (t DNSModeType) String() string { - return availableDNSModes[t] + return availableDNSModeValues[t] } func (t DNSQueryType) String() string { - return availableDNSQueries[t] + return availableDNSQueryValues[t] } type DNSOptions struct { @@ -189,17 +231,17 @@ func (o *DNSOptions) UnmarshalTOML(data any) (err error) { } if p := findFrom(m, "mode", parseStringFn(checkDNSMode), &err); isOk(p, err) { - o.Mode = ptr.FromValue(MustParseDNSModeType(*p)) + o.Mode = lo.ToPtr(MustParseDNSModeType(*p)) } if p := findFrom(m, "addr", parseStringFn(checkHostPort), &err); isOk(p, err) { - o.Addr = ptr.FromValue(MustParseTCPAddr(*p)) + o.Addr = lo.ToPtr(MustParseTCPAddr(*p)) } o.HTTPSURL = findFrom(m, "https-url", parseStringFn(checkHTTPSEndpoint), &err) if p := findFrom(m, "qtype", parseStringFn(checkDNSQueryType), &err); isOk(p, err) { - o.QType = ptr.FromValue(MustParseDNSQueryType(*p)) + o.QType = lo.ToPtr(MustParseDNSQueryType(*p)) } o.Cache = findFrom(m, "cache", parseBoolFn(), &err) @@ -222,11 +264,11 @@ func (o *DNSOptions) Clone() *DNSOptions { } return &DNSOptions{ - Mode: ptr.Clone(o.Mode), + Mode: clonePrimitive(o.Mode), Addr: newAddr, - HTTPSURL: ptr.Clone(o.HTTPSURL), - QType: ptr.Clone(o.QType), - Cache: ptr.Clone(o.Cache), + HTTPSURL: clonePrimitive(o.HTTPSURL), + QType: clonePrimitive(o.QType), + Cache: clonePrimitive(o.Cache), } } @@ -240,11 +282,11 @@ func (origin *DNSOptions) Merge(overrides *DNSOptions) *DNSOptions { } return &DNSOptions{ - Mode: ptr.CloneOr(overrides.Mode, origin.Mode), - Addr: ptr.CloneOr(overrides.Addr, origin.Addr), - HTTPSURL: ptr.CloneOr(overrides.HTTPSURL, origin.HTTPSURL), - QType: ptr.CloneOr(overrides.QType, origin.QType), - Cache: ptr.CloneOr(overrides.Cache, origin.Cache), + Mode: lo.CoalesceOrEmpty(overrides.Mode, origin.Mode), + Addr: lo.CoalesceOrEmpty(overrides.Addr, origin.Addr), + HTTPSURL: lo.CoalesceOrEmpty(overrides.HTTPSURL, origin.HTTPSURL), + QType: lo.CoalesceOrEmpty(overrides.QType, origin.QType), + Cache: lo.CoalesceOrEmpty(overrides.Cache, origin.Cache), } } @@ -294,7 +336,14 @@ const FakeClientHello = "" + type HTTPSSplitModeType int -var availableHTTPSModes = []string{"sni", "random", "chunk", "first-byte", "none"} +var availableHTTPSModeValues = []string{ + "sni", + "random", + "chunk", + "first-byte", + "none", + "custom", +} const ( HTTPSSplitModeSNI HTTPSSplitModeType = iota @@ -302,19 +351,85 @@ const ( HTTPSSplitModeChunk HTTPSSplitModeFirstByte HTTPSSplitModeNone + HTTPSSplitModeCustom ) func (k HTTPSSplitModeType) String() string { - return availableHTTPSModes[k] + return availableHTTPSModeValues[k] +} + +type SegmentFromType int + +var availableSegmentFromValues = []string{"head", "sni"} + +const ( + SegmentFromHead SegmentFromType = iota + SegmentFromSNI +) + +func (s SegmentFromType) String() string { + return availableSegmentFromValues[s] +} + +type SegmentPlan struct { + From SegmentFromType `toml:"from"` + At int `toml:"at"` + Lazy bool `toml:"lazy"` + Noise int `toml:"noise"` +} + +func (s *SegmentPlan) UnmarshalTOML(data any) (err error) { + m, ok := data.(map[string]any) + if !ok { + return fmt.Errorf("segment must be table type") + } + + if _, ok := m["from"]; !ok { + return fmt.Errorf("field 'from' is required") + } + if p := findFrom(m, "from", parseStringFn(checkSegmentFrom), &err); isOk(p, err) { + s.From = mustParseSegmentFromType(*p) + } + + if _, ok := m["at"]; !ok { + return fmt.Errorf("field 'at' is required") + } + if p := findFrom(m, "at", parseIntFn[int](nil), &err); isOk(p, err) { + s.At = *p + } + + if p := findFrom(m, "lazy", parseBoolFn(), &err); isOk(p, err) { + s.Lazy = *p + } + + if p := findFrom(m, "noise", parseIntFn[int](nil), &err); isOk(p, err) { + s.Noise = *p + } + + return err +} + +func (s *SegmentPlan) Clone() *SegmentPlan { + if s == nil { + return nil + } + + return &SegmentPlan{ + From: s.From, + At: s.At, + Lazy: s.Lazy, + Noise: s.Noise, + } } type HTTPSOptions struct { - Disorder *bool `toml:"disorder" json:"ds,omitempty"` - FakeCount *uint8 `toml:"fake-count" json:"fc,omitempty"` - FakePacket *proto.TLSMessage `toml:"fake-packet" json:"fp,omitempty"` - SplitMode *HTTPSSplitModeType `toml:"split-mode" json:"sm,omitempty"` - ChunkSize *uint8 `toml:"chunk-size" json:"cs,omitempty"` - Skip *bool `toml:"skip" json:"sk,omitempty"` + Disorder *bool `toml:"disorder" json:"ds,omitempty"` + FakeCount *uint8 `toml:"fake-count" json:"fc,omitempty"` + FakePacket *proto.TLSMessage `toml:"fake-packet" json:"fp,omitempty"` + SplitMode *HTTPSSplitModeType `toml:"split-mode" json:"sm,omitempty"` + ChunkSize *uint8 `toml:"chunk-size" json:"cs,omitempty"` + Skip *bool `toml:"skip" json:"sk,omitempty"` + CustomSegmentPlans []SegmentPlan `toml:"custom-segments" json:"cseg,omitempty"` } func (o *HTTPSOptions) UnmarshalTOML(data any) (err error) { @@ -333,16 +448,22 @@ func (o *HTTPSOptions) UnmarshalTOML(data any) (err error) { splitModeParser := parseStringFn(checkHTTPSSplitMode) if p := findFrom(m, "split-mode", splitModeParser, &err); isOk(p, err) { - o.SplitMode = ptr.FromValue(mustParseHTTPSSplitModeType(*p)) + o.SplitMode = lo.ToPtr(mustParseHTTPSSplitModeType(*p)) } o.ChunkSize = findFrom(m, "chunk-size", parseIntFn[uint8](checkUint8NonZero), &err) o.Skip = findFrom(m, "skip", parseBoolFn(), &err) if o.Skip == nil { - o.Skip = ptr.FromValue(false) + o.Skip = lo.ToPtr(false) + } + + o.CustomSegmentPlans = findStructSliceFrom[SegmentPlan](m, "custom-segments", &err) + if err == nil && o.SplitMode != nil && *o.SplitMode == HTTPSSplitModeCustom && + len(o.CustomSegmentPlans) == 0 { + err = fmt.Errorf("custom-segments must be provided when split-mode is 'custom'") } - return nil + return err } func (o *HTTPSOptions) Clone() *HTTPSOptions { @@ -355,32 +476,107 @@ func (o *HTTPSOptions) Clone() *HTTPSOptions { fakePacket = proto.NewFakeTLSMessage(o.FakePacket.Raw()) } + var customSegmentPlans []SegmentPlan + if o.CustomSegmentPlans != nil { + customSegmentPlans = make([]SegmentPlan, 0, len(o.CustomSegmentPlans)) + for _, s := range o.CustomSegmentPlans { + customSegmentPlans = append(customSegmentPlans, *s.Clone()) + } + } + return &HTTPSOptions{ - Disorder: ptr.Clone(o.Disorder), - FakeCount: ptr.Clone(o.FakeCount), - FakePacket: fakePacket, - SplitMode: ptr.Clone(o.SplitMode), - ChunkSize: ptr.Clone(o.ChunkSize), - Skip: ptr.Clone(o.Skip), + Disorder: clonePrimitive(o.Disorder), + FakeCount: clonePrimitive(o.FakeCount), + FakePacket: fakePacket, + SplitMode: clonePrimitive(o.SplitMode), + ChunkSize: clonePrimitive(o.ChunkSize), + Skip: clonePrimitive(o.Skip), + CustomSegmentPlans: customSegmentPlans, } } func (origin *HTTPSOptions) Merge(overrides *HTTPSOptions) *HTTPSOptions { if overrides == nil { - return origin + return origin.Clone() } if origin == nil { - return overrides + return overrides.Clone() } return &HTTPSOptions{ - Disorder: ptr.CloneOr(overrides.Disorder, origin.Disorder), - FakeCount: ptr.CloneOr(overrides.FakeCount, origin.FakeCount), - FakePacket: ptr.CloneOr(overrides.FakePacket, origin.FakePacket), - SplitMode: ptr.CloneOr(overrides.SplitMode, origin.SplitMode), - ChunkSize: ptr.CloneOr(overrides.ChunkSize, origin.ChunkSize), - Skip: ptr.CloneOr(overrides.Skip, origin.Skip), + Disorder: lo.CoalesceOrEmpty(overrides.Disorder, origin.Disorder), + FakeCount: lo.CoalesceOrEmpty(overrides.FakeCount, origin.FakeCount), + FakePacket: lo.CoalesceOrEmpty(overrides.FakePacket, origin.FakePacket), + SplitMode: lo.CoalesceOrEmpty(overrides.SplitMode, origin.SplitMode), + ChunkSize: lo.CoalesceOrEmpty(overrides.ChunkSize, origin.ChunkSize), + Skip: lo.CoalesceOrEmpty(overrides.Skip, origin.Skip), + CustomSegmentPlans: lo.CoalesceSliceOrEmpty( + origin.CustomSegmentPlans, + overrides.CustomSegmentPlans, + ), + } +} + +// ┌─────────────┐ +// │ UDP OPTIONS │ +// └─────────────┘ +var _ merger[*UDPOptions] = (*UDPOptions)(nil) + +type UDPOptions struct { + FakeCount *int `toml:"fake-count" json:"fc,omitempty"` + FakePacket []byte `toml:"fake-packet" json:"fp,omitempty"` + Timeout *time.Duration `toml:"timeout" json:"to,omitempty"` +} + +func (o *UDPOptions) UnmarshalTOML(data any) (err error) { + m, ok := data.(map[string]any) + if !ok { + return fmt.Errorf("'udp' must be table type") + } + + o.FakeCount = findFrom( + m, "fake-count", parseIntFn[int](int64Range(0, math.MaxInt64)), &err, + ) + o.FakePacket = findSliceFrom(m, "fake-packet", parseByteFn(nil), &err) + + if p := findFrom(m, "timeout", parseIntFn[uint16](checkUint16), &err); isOk(p, err) { + o.Timeout = lo.ToPtr(time.Duration(*p) * time.Millisecond) + } + + return err +} + +func (o *UDPOptions) Clone() *UDPOptions { + if o == nil { + return nil + } + + return &UDPOptions{ + FakeCount: clonePrimitive(o.FakeCount), + FakePacket: append([]byte(nil), o.FakePacket...), + Timeout: clonePrimitive(o.Timeout), + } +} + +func (origin *UDPOptions) Merge(overrides *UDPOptions) *UDPOptions { + if overrides == nil { + return origin.Clone() + } + + if origin == nil { + return overrides.Clone() + } + + fakePacket := origin.FakePacket + if len(overrides.FakePacket) > 0 { + fakePacket = overrides.FakePacket + } + + return &UDPOptions{ + FakeCount: lo.CoalesceOrEmpty(overrides.FakeCount, origin.FakeCount), + FakePacket: fakePacket, + Timeout: lo.CoalesceOrEmpty(overrides.Timeout, origin.Timeout), } } @@ -423,7 +619,7 @@ func (o *PolicyOptions) Clone() *PolicyOptions { } return &PolicyOptions{ - Auto: ptr.Clone(o.Auto), + Auto: clonePrimitive(o.Auto), Template: o.Template.Clone(), Overrides: overrides, } @@ -438,20 +634,11 @@ func (origin *PolicyOptions) Merge(overrides *PolicyOptions) *PolicyOptions { return overrides.Clone() } - overridesCopy := overrides.Clone() - - merged := origin.Clone() - merged.Auto = ptr.CloneOr(overrides.Auto, origin.Auto) - - if overridesCopy.Template != nil { - merged.Template = overridesCopy.Template - } - - if overridesCopy.Overrides != nil { - merged.Overrides = append(merged.Overrides, overridesCopy.Overrides...) + return &PolicyOptions{ + Auto: lo.CoalesceOrEmpty(overrides.Auto, origin.Auto), + Template: lo.CoalesceOrEmpty(overrides.Template.Clone(), origin.Template.Clone()), + Overrides: lo.CoalesceSliceOrEmpty(overrides.Overrides, origin.Overrides), } - - return merged } type AddrMatch struct { @@ -467,12 +654,12 @@ func (a *AddrMatch) UnmarshalTOML(data any) (err error) { } if p := findFrom(v, "cidr", parseStringFn(checkCIDR), &err); isOk(p, err) { - a.CIDR = ptr.FromValue(MustParseCIDR(*p)) + a.CIDR = lo.ToPtr(MustParseCIDR(*p)) } if p := findFrom(v, "port", parseStringFn(checkPortRange), &err); isOk(p, err) { portFrom, portTo := MustParsePortRange(*p) - a.PortFrom, a.PortTo = ptr.FromValue(portFrom), ptr.FromValue(portTo) + a.PortFrom, a.PortTo = lo.ToPtr(portFrom), lo.ToPtr(portTo) } return err @@ -482,10 +669,19 @@ func (a *AddrMatch) Clone() *AddrMatch { if a == nil { return nil } + + var cidr *net.IPNet + if a.CIDR != nil { + cidr = &net.IPNet{ + IP: slices.Clone(a.CIDR.IP), + Mask: slices.Clone(a.CIDR.Mask), + } + } + return &AddrMatch{ - CIDR: ptr.Clone(a.CIDR), - PortFrom: ptr.Clone(a.PortFrom), - PortTo: ptr.Clone(a.PortTo), + CIDR: cidr, + PortFrom: clonePrimitive(a.PortFrom), + PortTo: clonePrimitive(a.PortTo), } } @@ -521,7 +717,7 @@ func (a *MatchAttrs) Clone() *MatchAttrs { } return &MatchAttrs{ - Domains: ptr.CloneSlice(a.Domains), + Domains: lo.CoalesceSliceOrEmpty(a.Domains), Addrs: addrs, } } @@ -533,6 +729,7 @@ type Rule struct { Match *MatchAttrs `toml:"match" json:"mt,omitempty"` DNS *DNSOptions `toml:"dns-override" json:"D,omitempty"` HTTPS *HTTPSOptions `toml:"https-override" json:"H,omitempty"` + UDP *UDPOptions `toml:"udp-override" json:"U,omitempty"` } func (r *Rule) UnmarshalTOML(data any) (err error) { @@ -547,6 +744,7 @@ func (r *Rule) UnmarshalTOML(data any) (err error) { r.Match = findStructFrom[MatchAttrs](m, "match", &err) r.DNS = findStructFrom[DNSOptions](m, "dns", &err) r.HTTPS = findStructFrom[HTTPSOptions](m, "https", &err) + r.UDP = findStructFrom[UDPOptions](m, "udp", &err) // if err == nil { // err = checkRule(*r) @@ -560,11 +758,35 @@ func (r *Rule) Clone() *Rule { return nil } return &Rule{ - Name: ptr.Clone(r.Name), - Priority: ptr.Clone(r.Priority), - Block: ptr.Clone(r.Block), - Match: ptr.Clone(r.Match), + Name: clonePrimitive(r.Name), + Priority: clonePrimitive(r.Priority), + Block: clonePrimitive(r.Block), + Match: r.Match.Clone(), DNS: r.DNS.Clone(), HTTPS: r.HTTPS.Clone(), + UDP: r.UDP.Clone(), } } + +func (r *Rule) JSON() []byte { + data := map[string]any{ + "name": r.Name, + "priority": r.Priority, + } + + if r.Match == nil { + data["match"] = nil + } else { + m := map[string]any{} + if r.Match.Addrs != nil { + m["addr"] = fmt.Sprintf("%v items", len(r.Match.Addrs)) + } + if r.Match.Domains != nil { + m["domain"] = fmt.Sprintf("%v items", len(r.Match.Domains)) + } + data["match"] = m + } + + bytes, _ := json.Marshal(data) + return bytes +} diff --git a/internal/config/types_test.go b/internal/config/types_test.go index 5e2a1498..6848f967 100644 --- a/internal/config/types_test.go +++ b/internal/config/types_test.go @@ -5,10 +5,11 @@ import ( "testing" "time" + "github.com/BurntSushi/toml" "github.com/rs/zerolog" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" ) // ┌─────────────────┐ @@ -24,15 +25,15 @@ func TestGeneralOptions_UnmarshalTOML(t *testing.T) { { name: "valid general options", input: map[string]any{ - "log-level": "debug", - "silent": true, - "system-proxy": true, + "log-level": "debug", + "silent": true, + "network-config": true, }, wantErr: false, assert: func(t *testing.T, o GeneralOptions) { assert.Equal(t, zerolog.DebugLevel, *o.LogLevel) assert.True(t, *o.Silent) - assert.True(t, *o.SetSystemProxy) + assert.True(t, *o.SetNetworkConfig) }, }, { @@ -74,8 +75,8 @@ func TestGeneralOptions_Clone(t *testing.T) { { name: "non-nil receiver", input: &GeneralOptions{ - LogLevel: ptr.FromValue(zerolog.DebugLevel), - Silent: ptr.FromValue(true), + LogLevel: lo.ToPtr(zerolog.DebugLevel), + Silent: lo.ToPtr(true), }, assert: func(t *testing.T, input *GeneralOptions, output *GeneralOptions) { assert.NotNil(t, output) @@ -104,14 +105,14 @@ func TestGeneralOptions_Merge(t *testing.T) { { name: "nil receiver", base: nil, - override: &GeneralOptions{Silent: ptr.FromValue(true)}, + override: &GeneralOptions{Silent: lo.ToPtr(true)}, assert: func(t *testing.T, output *GeneralOptions) { assert.True(t, *output.Silent) }, }, { name: "nil override", - base: &GeneralOptions{Silent: ptr.FromValue(false)}, + base: &GeneralOptions{Silent: lo.ToPtr(false)}, override: nil, assert: func(t *testing.T, output *GeneralOptions) { assert.False(t, *output.Silent) @@ -120,11 +121,11 @@ func TestGeneralOptions_Merge(t *testing.T) { { name: "merge values", base: &GeneralOptions{ - Silent: ptr.FromValue(false), - LogLevel: ptr.FromValue(zerolog.InfoLevel), + Silent: lo.ToPtr(false), + LogLevel: lo.ToPtr(zerolog.InfoLevel), }, override: &GeneralOptions{ - Silent: ptr.FromValue(true), + Silent: lo.ToPtr(true), }, assert: func(t *testing.T, output *GeneralOptions) { assert.True(t, *output.Silent) @@ -204,7 +205,7 @@ func TestServerOptions_Clone(t *testing.T) { { name: "non-nil receiver", input: &ServerOptions{ - DefaultTTL: ptr.FromValue(uint8(64)), + DefaultTTL: lo.ToPtr(uint8(64)), ListenAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}, }, assert: func(t *testing.T, input *ServerOptions, output *ServerOptions) { @@ -237,14 +238,14 @@ func TestServerOptions_Merge(t *testing.T) { { name: "nil receiver", base: nil, - override: &ServerOptions{DefaultTTL: ptr.FromValue(uint8(64))}, + override: &ServerOptions{DefaultTTL: lo.ToPtr(uint8(64))}, assert: func(t *testing.T, output *ServerOptions) { assert.Equal(t, uint8(64), *output.DefaultTTL) }, }, { name: "nil override", - base: &ServerOptions{DefaultTTL: ptr.FromValue(uint8(128))}, + base: &ServerOptions{DefaultTTL: lo.ToPtr(uint8(128))}, override: nil, assert: func(t *testing.T, output *ServerOptions) { assert.Equal(t, uint8(128), *output.DefaultTTL) @@ -253,11 +254,11 @@ func TestServerOptions_Merge(t *testing.T) { { name: "merge values", base: &ServerOptions{ - DefaultTTL: ptr.FromValue(uint8(64)), + DefaultTTL: lo.ToPtr(uint8(64)), ListenAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}, }, override: &ServerOptions{ - DefaultTTL: ptr.FromValue(uint8(128)), + DefaultTTL: lo.ToPtr(uint8(128)), }, assert: func(t *testing.T, output *ServerOptions) { assert.Equal(t, uint8(128), *output.DefaultTTL) @@ -341,7 +342,7 @@ func TestDNSOptions_Clone(t *testing.T) { { name: "non-nil receiver", input: &DNSOptions{ - Mode: ptr.FromValue(DNSModeHTTPS), + Mode: lo.ToPtr(DNSModeHTTPS), Addr: &net.TCPAddr{IP: net.ParseIP("1.1.1.1"), Port: 53}, }, assert: func(t *testing.T, input *DNSOptions, output *DNSOptions) { @@ -373,14 +374,14 @@ func TestDNSOptions_Merge(t *testing.T) { { name: "nil receiver", base: nil, - override: &DNSOptions{Mode: ptr.FromValue(DNSModeHTTPS)}, + override: &DNSOptions{Mode: lo.ToPtr(DNSModeHTTPS)}, assert: func(t *testing.T, output *DNSOptions) { assert.Equal(t, DNSModeHTTPS, *output.Mode) }, }, { name: "nil override", - base: &DNSOptions{Mode: ptr.FromValue(DNSModeUDP)}, + base: &DNSOptions{Mode: lo.ToPtr(DNSModeUDP)}, override: nil, assert: func(t *testing.T, output *DNSOptions) { assert.Equal(t, DNSModeUDP, *output.Mode) @@ -389,12 +390,12 @@ func TestDNSOptions_Merge(t *testing.T) { { name: "merge values", base: &DNSOptions{ - Mode: ptr.FromValue(DNSModeUDP), + Mode: lo.ToPtr(DNSModeUDP), Addr: &net.TCPAddr{IP: net.ParseIP("8.8.8.8"), Port: 53}, }, override: &DNSOptions{ - Mode: ptr.FromValue(DNSModeUDP), - HTTPSURL: ptr.FromValue("https://dns.google/test"), + Mode: lo.ToPtr(DNSModeUDP), + HTTPSURL: lo.ToPtr("https://dns.google/test"), }, assert: func(t *testing.T, output *DNSOptions) { assert.Equal(t, DNSModeUDP, *output.Mode) @@ -481,7 +482,7 @@ func TestHTTPSOptions_Clone(t *testing.T) { { name: "non-nil receiver", input: &HTTPSOptions{ - Disorder: ptr.FromValue(true), + Disorder: lo.ToPtr(true), FakePacket: proto.NewFakeTLSMessage([]byte{0x01}), }, assert: func(t *testing.T, input *HTTPSOptions, output *HTTPSOptions) { @@ -514,14 +515,14 @@ func TestHTTPSOptions_Merge(t *testing.T) { { name: "nil receiver", base: nil, - override: &HTTPSOptions{Disorder: ptr.FromValue(true)}, + override: &HTTPSOptions{Disorder: lo.ToPtr(true)}, assert: func(t *testing.T, output *HTTPSOptions) { assert.True(t, *output.Disorder) }, }, { name: "nil override", - base: &HTTPSOptions{Disorder: ptr.FromValue(false)}, + base: &HTTPSOptions{Disorder: lo.ToPtr(false)}, override: nil, assert: func(t *testing.T, output *HTTPSOptions) { assert.False(t, *output.Disorder) @@ -530,12 +531,12 @@ func TestHTTPSOptions_Merge(t *testing.T) { { name: "merge values", base: &HTTPSOptions{ - Disorder: ptr.FromValue(false), - ChunkSize: ptr.FromValue(uint8(10)), + Disorder: lo.ToPtr(false), + ChunkSize: lo.ToPtr(uint8(10)), FakePacket: proto.NewFakeTLSMessage([]byte{0x01}), }, override: &HTTPSOptions{ - Disorder: ptr.FromValue(true), + Disorder: lo.ToPtr(true), FakePacket: proto.NewFakeTLSMessage([]byte{0x02}), }, assert: func(t *testing.T, output *HTTPSOptions) { @@ -623,10 +624,10 @@ func TestPolicyOptions_Clone(t *testing.T) { { name: "non-nil receiver", input: &PolicyOptions{ - Auto: ptr.FromValue(true), + Auto: lo.ToPtr(true), Overrides: []Rule{ { - Name: ptr.FromValue("rule1"), + Name: lo.ToPtr("rule1"), Match: &MatchAttrs{Domains: []string{"example.com"}}, }, }, @@ -660,14 +661,14 @@ func TestPolicyOptions_Merge(t *testing.T) { { name: "nil receiver", base: nil, - override: &PolicyOptions{Auto: ptr.FromValue(true)}, + override: &PolicyOptions{Auto: lo.ToPtr(true)}, assert: func(t *testing.T, output *PolicyOptions) { assert.True(t, *output.Auto) }, }, { name: "nil override", - base: &PolicyOptions{Auto: ptr.FromValue(false)}, + base: &PolicyOptions{Auto: lo.ToPtr(false)}, override: nil, assert: func(t *testing.T, output *PolicyOptions) { assert.False(t, *output.Auto) @@ -676,18 +677,17 @@ func TestPolicyOptions_Merge(t *testing.T) { { name: "merge values", base: &PolicyOptions{ - Auto: ptr.FromValue(false), - Overrides: []Rule{{Name: ptr.FromValue("rule1")}}, + Auto: lo.ToPtr(false), + Overrides: []Rule{{Name: lo.ToPtr("rule1")}}, }, override: &PolicyOptions{ - Auto: ptr.FromValue(true), - Overrides: []Rule{{Name: ptr.FromValue("rule2")}}, + Auto: lo.ToPtr(true), + Overrides: []Rule{{Name: lo.ToPtr("rule2")}}, }, assert: func(t *testing.T, output *PolicyOptions) { assert.True(t, *output.Auto) - assert.Len(t, output.Overrides, 2) - assert.Equal(t, "rule1", *output.Overrides[0].Name) - assert.Equal(t, "rule2", *output.Overrides[1].Name) + assert.Len(t, output.Overrides, 1) + assert.Equal(t, "rule2", *output.Overrides[0].Name) }, }, } @@ -885,7 +885,7 @@ func TestRule_Clone(t *testing.T) { { name: "non-nil receiver", input: &Rule{ - Name: ptr.FromValue("rule1"), + Name: lo.ToPtr("rule1"), Match: &MatchAttrs{Domains: []string{"example.com"}}, }, assert: func(t *testing.T, input *Rule, output *Rule) { @@ -903,3 +903,92 @@ func TestRule_Clone(t *testing.T) { }) } } + +func TestSegmentPlan_UnmarshalTOML(t *testing.T) { + t.Run("valid segment head", func(t *testing.T) { + input := ` +from = "head" +at = 10 +lazy = true +noise = 1 +` + var s SegmentPlan + err := toml.Unmarshal([]byte(input), &s) + assert.NoError(t, err) + assert.Equal(t, SegmentFromHead, s.From) + assert.Equal(t, 10, s.At) + assert.True(t, s.Lazy) + assert.Equal(t, 1, s.Noise) + }) + + t.Run("valid segment sni", func(t *testing.T) { + input := ` +from = "sni" +at = -5 +` + var s SegmentPlan + err := toml.Unmarshal([]byte(input), &s) + assert.NoError(t, err) + assert.Equal(t, SegmentFromSNI, s.From) + assert.Equal(t, -5, s.At) + }) + + t.Run("missing required field from", func(t *testing.T) { + input := ` +at = 5 +` + var s SegmentPlan + err := toml.Unmarshal([]byte(input), &s) + assert.Error(t, err) + assert.Contains(t, err.Error(), "field 'from' is required") + }) + + t.Run("missing required field at", func(t *testing.T) { + input := ` +from = "head" +` + var s SegmentPlan + err := toml.Unmarshal([]byte(input), &s) + assert.Error(t, err) + assert.Contains(t, err.Error(), "field 'at' is required") + }) + + t.Run("invalid from value", func(t *testing.T) { + input := ` +from = "invalid" +at = 5 +` + var s SegmentPlan + err := toml.Unmarshal([]byte(input), &s) + assert.Error(t, err) + }) +} + +func TestHTTPSOptions_CustomSegmentPlans(t *testing.T) { + t.Run("valid custom config", func(t *testing.T) { + input := ` +split-mode = "custom" +custom-segments = [ + { from = "head", at = 2 }, + { from = "sni", at = 0 } +] +` + var opts HTTPSOptions + err := toml.Unmarshal([]byte(input), &opts) + assert.NoError(t, err) + assert.Equal(t, HTTPSSplitModeCustom, *opts.SplitMode) + assert.Len(t, opts.CustomSegmentPlans, 2) + assert.Equal(t, SegmentFromHead, opts.CustomSegmentPlans[0].From) + assert.Equal(t, 2, opts.CustomSegmentPlans[0].At) + }) + + t.Run("missing custom segments", func(t *testing.T) { + input := ` +split-mode = "custom" +` + var opts HTTPSOptions + err := toml.Unmarshal([]byte(input), &opts) + assert.Error(t, err) + assert.Contains(t, err.Error(), "custom-segments must be provided") + }) +} diff --git a/internal/config/validate.go b/internal/config/validate.go index c45c71a0..ba24e50b 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -51,10 +51,12 @@ var ( checkUint8 = int64Range(0, math.MaxUint8) checkUint16 = int64Range(0, math.MaxUint16) checkUint8NonZero = int64Range(1, math.MaxUint8) - checkDNSMode = checkOneOf(availableDNSModes...) - checkDNSQueryType = checkOneOf(availableDNSQueries...) - checkHTTPSSplitMode = checkOneOf(availableHTTPSModes...) - checkLogLevel = checkOneOf(availableLogLevels...) + checkServerMode = checkOneOf(availableServerModeValues...) + checkDNSMode = checkOneOf(availableDNSModeValues...) + checkDNSQueryType = checkOneOf(availableDNSQueryValues...) + checkHTTPSSplitMode = checkOneOf(availableHTTPSModeValues...) + checkLogLevel = checkOneOf(availableLogLevelValues...) + checkSegmentFrom = checkOneOf(availableSegmentFromValues...) ) func checkDomainPattern(v string) error { diff --git a/internal/config/validate_test.go b/internal/config/validate_test.go index 3c9c39dc..302c82d9 100644 --- a/internal/config/validate_test.go +++ b/internal/config/validate_test.go @@ -3,8 +3,8 @@ package config import ( "testing" + "github.com/samber/lo" "github.com/stretchr/testify/assert" - "github.com/xvzc/SpoofDPI/internal/ptr" ) func TestCheckDomainPattern(t *testing.T) { @@ -179,9 +179,9 @@ func TestCheckMatchAttrs(t *testing.T) { Domains: []string{"www.google.com"}, Addrs: []AddrMatch{ { - CIDR: ptr.FromValue(MustParseCIDR("192.168.0.1/24")), - PortFrom: ptr.FromValue(uint16(80)), - PortTo: ptr.FromValue(uint16(443)), + CIDR: lo.ToPtr(MustParseCIDR("192.168.0.1/24")), + PortFrom: lo.ToPtr(uint16(80)), + PortTo: lo.ToPtr(uint16(443)), }, }, }, @@ -199,9 +199,9 @@ func TestCheckMatchAttrs(t *testing.T) { input: MatchAttrs{ Addrs: []AddrMatch{ { - CIDR: ptr.FromValue(MustParseCIDR("10.0.0.0/8")), - PortFrom: ptr.FromValue(uint16(0)), - PortTo: ptr.FromValue(uint16(65535)), + CIDR: lo.ToPtr(MustParseCIDR("10.0.0.0/8")), + PortFrom: lo.ToPtr(uint16(0)), + PortTo: lo.ToPtr(uint16(65535)), }, }, }, @@ -217,7 +217,7 @@ func TestCheckMatchAttrs(t *testing.T) { input: MatchAttrs{ Addrs: []AddrMatch{ { - CIDR: ptr.FromValue(MustParseCIDR("10.0.0.0/8")), + CIDR: lo.ToPtr(MustParseCIDR("10.0.0.0/8")), }, }, }, @@ -228,8 +228,8 @@ func TestCheckMatchAttrs(t *testing.T) { input: MatchAttrs{ Addrs: []AddrMatch{ { - PortFrom: ptr.FromValue(uint16(80)), - PortTo: ptr.FromValue(uint16(443)), + PortFrom: lo.ToPtr(uint16(80)), + PortTo: lo.ToPtr(uint16(443)), }, }, }, @@ -262,7 +262,7 @@ func TestCheckRule(t *testing.T) { Domains: []string{"example.com"}, }, DNS: &DNSOptions{ - Mode: ptr.FromValue(DNSModeUDP), + Mode: lo.ToPtr(DNSModeUDP), }, }, wantErr: false, @@ -273,14 +273,14 @@ func TestCheckRule(t *testing.T) { Match: &MatchAttrs{ Addrs: []AddrMatch{ { - CIDR: ptr.FromValue(MustParseCIDR("192.168.1.0/24")), - PortFrom: ptr.FromValue(uint16(80)), - PortTo: ptr.FromValue(uint16(80)), + CIDR: lo.ToPtr(MustParseCIDR("192.168.1.0/24")), + PortFrom: lo.ToPtr(uint16(80)), + PortTo: lo.ToPtr(uint16(80)), }, }, }, HTTPS: &HTTPSOptions{ - Disorder: ptr.FromValue(true), + Disorder: lo.ToPtr(true), }, }, wantErr: false, diff --git a/internal/desync/tls.go b/internal/desync/tls.go index b136cd75..7d411d05 100644 --- a/internal/desync/tls.go +++ b/internal/desync/tls.go @@ -8,16 +8,17 @@ import ( "net" "github.com/rs/zerolog" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/logging" "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/packet" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" ) -type TLSDesyncerAttrs struct { - DefaultTTL uint8 +type Segment struct { + Packet []byte + Lazy bool } // TLSDesyncer splits the data into chunks and optionally @@ -25,22 +26,19 @@ type TLSDesyncerAttrs struct { type TLSDesyncer struct { writer packet.Writer sniffer packet.Sniffer - attrs *TLSDesyncerAttrs } func NewTLSDesyncer( writer packet.Writer, sniffer packet.Sniffer, - attrs *TLSDesyncerAttrs, ) *TLSDesyncer { return &TLSDesyncer{ writer: writer, sniffer: sniffer, - attrs: attrs, } } -func (d *TLSDesyncer) Send( +func (d *TLSDesyncer) Desync( ctx context.Context, logger zerolog.Logger, conn net.Conn, @@ -49,13 +47,13 @@ func (d *TLSDesyncer) Send( ) (int, error) { logger = logging.WithLocalScope(ctx, logger, "tls_desync") - if ptr.FromPtr(httpsOpts.Skip) { + if lo.FromPtr(httpsOpts.Skip) { logger.Trace().Msg("skip desync for this request") - return d.sendSegments(conn, [][]byte{msg.Raw()}) + return d.sendSegments(conn, logger, []Segment{{Packet: msg.Raw()}}) } - if d.sniffer != nil && d.writer != nil && ptr.FromPtr(httpsOpts.FakeCount) > 0 { - oTTL := d.sniffer.GetOptimalTTL(conn.RemoteAddr().String()) + if d.sniffer != nil && d.writer != nil && lo.FromPtr(httpsOpts.FakeCount) > 0 { + oTTL := d.sniffer.GetOptimalTTL(conn.RemoteAddr().(*net.TCPAddr).IP.String()) n, err := d.sendFakePackets(ctx, logger, conn, oTTL, httpsOpts) if err != nil { logger.Warn().Err(err).Msg("failed to send fake packets") @@ -64,38 +62,16 @@ func (d *TLSDesyncer) Send( } } - segments := split(logger, httpsOpts, msg) + segments := split(logger, msg, httpsOpts) - if ptr.FromPtr(httpsOpts.Disorder) { - return d.sendSegmentsDisorder(conn, logger, segments, httpsOpts) - } - - return d.sendSegments(conn, segments) + return d.sendSegments(conn, logger, segments) } // sendSegments sends the segmented Client Hello sequentially. -func (d *TLSDesyncer) sendSegments(conn net.Conn, segments [][]byte) (int, error) { - total := 0 - for _, chunk := range segments { - n, err := conn.Write(chunk) - total += n - if err != nil { - return total, err - } - } - - return total, nil -} - -// sendSegmentsDisorder sends the segmented Client Hello out of order. -// Since performance is prioritized over strict randomness, -// a single 64-bit pattern is generated and reused cyclically -// for sequences exceeding 64 chunks. -func (d *TLSDesyncer) sendSegmentsDisorder( +func (d *TLSDesyncer) sendSegments( conn net.Conn, logger zerolog.Logger, - segments [][]byte, - opts *config.HTTPSOptions, + segments []Segment, ) (int, error) { var isIPv4 bool if tcpAddr, ok := conn.LocalAddr().(*net.TCPAddr); ok { @@ -110,30 +86,23 @@ func (d *TLSDesyncer) sendSegmentsDisorder( } } - defer setTTLWrap(d.attrs.DefaultTTL) // Restore the default TTL on return + defaultTTL := getDefaultTTL() - disorderBits := genPatternMask() - logger.Debug(). - Str("bits", fmt.Sprintf("%064b", disorderBits)). - Msgf("disorder ready") - curBit := uint64(1) total := 0 for _, chunk := range segments { - if !ttlErrored && disorderBits&curBit == curBit { - setTTLWrap(1) + if !ttlErrored && chunk.Lazy { + setTTLWrap(0) } - n, err := conn.Write(chunk) + n, err := conn.Write(chunk.Packet) + total += n if err != nil { return total, err } - total += n - if !ttlErrored && disorderBits&curBit == curBit { - setTTLWrap(d.attrs.DefaultTTL) + if !ttlErrored && chunk.Lazy { + setTTLWrap(defaultTTL) } - - curBit = bits.RotateLeft64(curBit, 1) } return total, nil @@ -141,12 +110,12 @@ func (d *TLSDesyncer) sendSegmentsDisorder( func split( logger zerolog.Logger, - attrs *config.HTTPSOptions, msg *proto.TLSMessage, -) [][]byte { - mode := *attrs.SplitMode + opts *config.HTTPSOptions, +) []Segment { + mode := *opts.SplitMode raw := msg.Raw() - var chunks [][]byte + var segments []Segment var err error switch mode { case config.HTTPSSplitModeSNI: @@ -155,38 +124,42 @@ func split( if err != nil { break } - chunks, err = splitSNI(raw, start, end) + segments, err = splitSNI(raw, start, end, *opts.Disorder) logger.Trace().Msgf("extracted SNI is '%s'", raw[start:end]) case config.HTTPSSplitModeRandom: mask := genPatternMask() - chunks, err = splitMask(raw, mask) + segments, err = splitMask(raw, mask, *opts.Disorder) case config.HTTPSSplitModeChunk: - chunks, err = splitChunks(raw, int(*attrs.ChunkSize)) + segments, err = splitChunks(raw, int(*opts.ChunkSize), *opts.Disorder) case config.HTTPSSplitModeFirstByte: - chunks, err = splitFirstByte(raw) + segments, err = splitFirstByte(raw, *opts.Disorder) + case config.HTTPSSplitModeCustom: + segments, err = applySegmentPlans(msg, opts.CustomSegmentPlans) case config.HTTPSSplitModeNone: - chunks = [][]byte{raw} + segments = []Segment{{Packet: raw}} default: logger.Debug().Msgf("unsupprted split mode '%s'. proceed without split", mode) - chunks = [][]byte{raw} + segments = []Segment{{Packet: raw}} } logger.Debug(). - Int("len", len(chunks)). + Int("len", len(segments)). Str("mode", mode.String()). Str("kind", msg.Kind()). + Bool("disorder", *opts.Disorder). Msg("segments ready") if err != nil { logger.Debug().Err(err). + Str("kind", msg.Kind()). Msgf("error processing split mode '%s', fallback to 'none'", mode) - chunks = [][]byte{raw} + segments = []Segment{{Packet: raw}} } - return chunks + return segments } -func splitChunks(raw []byte, size int) ([][]byte, error) { +func splitChunks(raw []byte, size int, disorder bool) ([]Segment, error) { lenRaw := len(raw) if lenRaw == 0 { @@ -198,26 +171,31 @@ func splitChunks(raw []byte, size int) ([][]byte, error) { } capacity := (lenRaw + size - 1) / size - chunks := make([][]byte, 0, capacity) + chunks := make([]Segment, 0, capacity) + curDisorder := true for len(raw) > 0 { n := min(len(raw), size) - chunks = append(chunks, raw[:n]) + chunks = append(chunks, Segment{Packet: raw[:n], Lazy: curDisorder && disorder}) raw = raw[n:] + curDisorder = !curDisorder } return chunks, nil } -func splitFirstByte(raw []byte) ([][]byte, error) { +func splitFirstByte(raw []byte, disorder bool) ([]Segment, error) { if len(raw) < 2 { return nil, fmt.Errorf("len(raw) is less than 2") } - return [][]byte{raw[:1], raw[1:]}, nil + return []Segment{ + {Packet: raw[:1], Lazy: disorder && true}, + {Packet: raw[1:], Lazy: false}, + }, nil } -func splitSNI(raw []byte, start, end int) ([][]byte, error) { +func splitSNI(raw []byte, start, end int, disorder bool) ([]Segment, error) { lenRaw := len(raw) if lenRaw == 0 { @@ -232,43 +210,66 @@ func splitSNI(raw []byte, start, end int) ([][]byte, error) { return nil, fmt.Errorf("invalid start, end pos (out of range)") } - segments := make([][]byte, 0, lenRaw) - segments = append(segments, raw[:start]) + curDisorder := true + segments := make([]Segment, 0, lenRaw) + segments = append(segments, Segment{Packet: raw[:start]}) for i := range end - start { - segments = append(segments, []byte{raw[start+i]}) + segments = append(segments, Segment{ + Packet: []byte{raw[start+i]}, + Lazy: curDisorder && disorder, + }) + curDisorder = !curDisorder } - segments = append(segments, raw[end:]) - return append([][]byte(nil), segments...), nil + segments = append(segments, Segment{ + Packet: raw[end:], + Lazy: curDisorder && disorder, + }) + + return segments, nil } -func splitMask(raw []byte, mask uint64) ([][]byte, error) { +func splitMask(raw []byte, mask uint64, disorder bool) ([]Segment, error) { lenRaw := len(raw) if lenRaw == 0 { return nil, fmt.Errorf("empty data") } - segments := make([][]byte, 0, lenRaw) + curDisorder := true + segments := make([]Segment, 0, lenRaw) start := 0 curBit := uint64(1) for i := range lenRaw { if mask&curBit == curBit { if i > start { - segments = append(segments, raw[start:i]) + segments = append(segments, Segment{ + Packet: raw[start:i], + Lazy: curDisorder && disorder, + }) + curDisorder = !curDisorder } - segments = append(segments, raw[i:i+1]) + + segments = append(segments, Segment{ + Packet: raw[i : i+1], + Lazy: curDisorder && disorder, + }) + start = i + 1 + curDisorder = !curDisorder } curBit = bits.RotateLeft64(curBit, 1) } if lenRaw > start { - segments = append(segments, raw[start:lenRaw]) + segments = append(segments, Segment{ + Packet: raw[start:lenRaw], + Lazy: curDisorder && disorder, + }) } - return append([][]byte(nil), segments...), nil + return segments, nil } func (d *TLSDesyncer) String() string { @@ -283,7 +284,7 @@ func (d *TLSDesyncer) sendFakePackets( opts *config.HTTPSOptions, ) (int, error) { var totalSent int - segments := split(logger, opts, opts.FakePacket) + segments := split(logger, opts.FakePacket, opts) for range *(opts.FakeCount) { for _, v := range segments { @@ -292,7 +293,7 @@ func (d *TLSDesyncer) sendFakePackets( conn.LocalAddr(), conn.RemoteAddr(), oTTL, - v, + v.Packet, ) if err != nil { return totalSent, err @@ -305,6 +306,67 @@ func (d *TLSDesyncer) sendFakePackets( return totalSent, nil } +func applySegmentPlans( + msg *proto.TLSMessage, + plans []config.SegmentPlan, +) ([]Segment, error) { + raw := msg.Raw() + sniStart, _, err := msg.ExtractSNIOffset() + if err != nil { + return nil, err + } + + var segments []Segment + prvAt := 0 + + for _, s := range plans { + base := 0 + switch s.From { + case config.SegmentFromSNI: + base = sniStart + case config.SegmentFromHead: + base = 0 + } + + curAt := base + s.At + + if s.Noise > 0 { + // Random integer in [-noise, noise] + noiseVal := rand.IntN(s.Noise*2+1) - s.Noise + curAt += noiseVal + } + + // Boundary checks + if curAt < 0 { + curAt = 0 + } + if curAt > len(raw) { + curAt = len(raw) + } + + // Handle overlap with previous split point + if curAt < prvAt { + curAt = prvAt + } + + segments = append(segments, Segment{ + Packet: raw[prvAt:curAt], + Lazy: s.Lazy, + }) + prvAt = curAt + } + + if prvAt < len(raw) { + segments = append(segments, Segment{Packet: raw[prvAt:]}) + } + + return segments, nil +} + +func getDefaultTTL() uint8 { + return 64 +} + // --- Helper Functions (Low-level Syscall) --- // genPatternMask generates a pseudo-random 64-bit mask used for determining diff --git a/internal/desync/tls_test.go b/internal/desync/tls_test.go index 0311e79f..83cbb1cb 100644 --- a/internal/desync/tls_test.go +++ b/internal/desync/tls_test.go @@ -1,359 +1,189 @@ package desync import ( - "fmt" - "slices" "testing" - "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" ) -func TestSplit(t *testing.T) { - logger := zerolog.Nop() +func TestApplySegmentPlans(t *testing.T) { + // Using the FakeClientHello from config/types.go - tcs := []struct { - name string - opts *config.HTTPSOptions - msg *proto.TLSMessage - assert func(t *testing.T, chunks [][]byte) - }{ - { - name: "none", - opts: &config.HTTPSOptions{SplitMode: ptr.FromValue(config.HTTPSSplitModeNone)}, - msg: proto.NewFakeTLSMessage([]byte("12345")), - assert: func(t *testing.T, chunks [][]byte) { - assert.Len(t, chunks, 1) - assert.Equal(t, []byte("12345"), chunks[0]) - }, - }, - { - name: "chunk size 3", - msg: proto.NewFakeTLSMessage([]byte("1234567890")), - opts: &config.HTTPSOptions{ - SplitMode: ptr.FromValue(config.HTTPSSplitModeChunk), - ChunkSize: ptr.FromValue(uint8(3)), - }, - assert: func(t *testing.T, chunks [][]byte) { - assert.Len(t, chunks, 4) - assert.Equal(t, []byte("123"), chunks[0]) - assert.Equal(t, []byte("0"), chunks[3]) + fakeRaw := []byte(config.FakeClientHello) + + msg := proto.NewFakeTLSMessage(fakeRaw) + + // We know where SNI is in FakeClientHello. + // Let's verify SNI offset first to write correct assertions. + sniStart, _, _ := msg.ExtractSNIOffset() + t.Run("cut head", func(t *testing.T) { + plans := []config.SegmentPlan{ + { + From: config.SegmentFromHead, + At: 5, }, - }, - { - name: "chunk size 0 (fallback)", - msg: proto.NewFakeTLSMessage([]byte("1234567890")), - opts: &config.HTTPSOptions{ - SplitMode: ptr.FromValue(config.HTTPSSplitModeChunk), - ChunkSize: ptr.FromValue(uint8(0)), + } + + chunks, err := applySegmentPlans(msg, plans) + assert.NoError(t, err) + assert.Len(t, chunks, 2) + assert.Equal(t, fakeRaw[:5], chunks[0].Packet) + assert.Equal(t, fakeRaw[5:], chunks[1].Packet) + assert.False(t, chunks[0].Lazy) + assert.False(t, chunks[1].Lazy) + }) + + t.Run("cut head lazy", func(t *testing.T) { + plans := []config.SegmentPlan{ + { + From: config.SegmentFromHead, + At: 5, + Lazy: true, }, - assert: func(t *testing.T, chunks [][]byte) { - assert.Len(t, chunks, 1) - assert.Equal(t, []byte("1234567890"), chunks[0]) + } + + chunks, err := applySegmentPlans(msg, plans) + assert.NoError(t, err) + assert.Len(t, chunks, 2) + assert.Equal(t, fakeRaw[:5], chunks[0].Packet) + assert.True(t, chunks[0].Lazy) + assert.Equal(t, fakeRaw[5:], chunks[1].Packet) + assert.False(t, chunks[1].Lazy) // Remainder is not lazy by default + }) + + t.Run("cut head multiple", func(t *testing.T) { + plans := []config.SegmentPlan{ + { + From: config.SegmentFromHead, + At: 5, }, - }, - { - name: "first-byte", - msg: proto.NewFakeTLSMessage([]byte("1234567890")), - opts: &config.HTTPSOptions{ - SplitMode: ptr.FromValue(config.HTTPSSplitModeFirstByte), + { + From: config.SegmentFromHead, + At: 10, }, - assert: func(t *testing.T, chunks [][]byte) { - assert.Len(t, chunks, 2) - assert.Equal(t, []byte("1"), chunks[0]) - assert.Equal(t, []byte("234567890"), chunks[1]) + } + + chunks, err := applySegmentPlans(msg, plans) + assert.NoError(t, err) + assert.Len(t, chunks, 3) + assert.Equal(t, fakeRaw[:5], chunks[0].Packet) + assert.Equal(t, fakeRaw[5:10], chunks[1].Packet) + assert.Equal(t, fakeRaw[10:], chunks[2].Packet) + }) + + t.Run("cut sni", func(t *testing.T) { + if sniStart == 0 { + t.Skip("SNI not found in fake packet") + } + + // Split at SNI start + plans := []config.SegmentPlan{ + { + From: config.SegmentFromSNI, + At: 0, }, - }, - { - name: "first-byte (fallback)", - msg: proto.NewFakeTLSMessage([]byte("1")), - opts: &config.HTTPSOptions{ - SplitMode: ptr.FromValue(config.HTTPSSplitModeFirstByte), + } + + chunks, err := applySegmentPlans(msg, plans) + assert.NoError(t, err) + + // Should be [0...sniStart], [sniStart...] + assert.Len(t, chunks, 2) + assert.Equal(t, fakeRaw[:sniStart], chunks[0].Packet) + assert.Equal(t, fakeRaw[sniStart:], chunks[1].Packet) + }) + + t.Run("cut sni offset", func(t *testing.T) { + if sniStart == 0 { + t.Skip("SNI not found in fake packet") + } + + offset := 5 + target := sniStart + offset + plans := []config.SegmentPlan{ + { + From: config.SegmentFromSNI, + At: offset, }, - assert: func(t *testing.T, chunks [][]byte) { - assert.Len(t, chunks, 1) - assert.Equal(t, []byte("1"), chunks[0]) + } + + chunks, err := applySegmentPlans(msg, plans) + assert.NoError(t, err) + assert.Len(t, chunks, 2) + assert.Equal(t, fakeRaw[:target], chunks[0].Packet) + assert.Equal(t, fakeRaw[target:], chunks[1].Packet) + }) + + t.Run("cut mixed head and sni", func(t *testing.T) { + if sniStart == 0 { + t.Skip("SNI not found in fake packet") + } + + // Split at 5 (head) and then at SNI start + plans := []config.SegmentPlan{ + { + From: config.SegmentFromHead, + At: 5, }, - }, - { - name: "valid sni", - msg: proto.NewFakeTLSMessage([]byte(config.FakeClientHello)), - opts: &config.HTTPSOptions{SplitMode: ptr.FromValue(config.HTTPSSplitModeSNI)}, - assert: func(t *testing.T, chunks [][]byte) { - assert.GreaterOrEqual(t, len(chunks), 1) - assert.Equal(t, string("www.w3.org"), string(slices.Concat(chunks[1:11]...))) + { + From: config.SegmentFromSNI, + At: 0, }, - }, - { - name: "sni (fallback)", - msg: proto.NewFakeTLSMessage([]byte("1234567890")), - opts: &config.HTTPSOptions{SplitMode: ptr.FromValue(config.HTTPSSplitModeSNI)}, - assert: func(t *testing.T, chunks [][]byte) { - // Fallback to no split on error - assert.Len(t, chunks, 1) - assert.Equal(t, []byte("1234567890"), chunks[0]) + } + + chunks, err := applySegmentPlans(msg, plans) + assert.NoError(t, err) + assert.Len(t, chunks, 3) + assert.Equal(t, fakeRaw[:5], chunks[0].Packet) + assert.Equal(t, fakeRaw[5:sniStart], chunks[1].Packet) + assert.Equal(t, fakeRaw[sniStart:], chunks[2].Packet) + }) + + t.Run("overlap ignored", func(t *testing.T) { + // Split at 10, then try to split at 5 (should be ignored/empty for that segment) + plans := []config.SegmentPlan{ + { + From: config.SegmentFromHead, + At: 10, }, - }, - { - name: "random", - msg: proto.NewFakeTLSMessage([]byte(config.FakeClientHello)), - opts: &config.HTTPSOptions{ - SplitMode: ptr.FromValue(config.HTTPSSplitModeRandom), + { + From: config.SegmentFromHead, + At: 5, }, - assert: func(t *testing.T, chunks [][]byte) { - assert.GreaterOrEqual(t, len(chunks), 1) - var joined []byte - for _, c := range chunks { - joined = append(joined, c...) - } - assert.Equal(t, []byte(config.FakeClientHello), joined) + } + + chunks, err := applySegmentPlans(msg, plans) + assert.NoError(t, err) + // chunks[0]: 0-10 + // chunks[1]: 10-10 (empty) + // chunks[2]: 10-end + assert.Len(t, chunks, 3) + assert.Equal(t, fakeRaw[:10], chunks[0].Packet) + assert.Empty(t, chunks[1].Packet) + assert.Equal(t, fakeRaw[10:], chunks[2].Packet) + }) + + t.Run("noise", func(t *testing.T) { + // We can't strictly test random values, but we can check if it runs without panic + + plans := []config.SegmentPlan{ + { + From: config.SegmentFromHead, + At: 10, + Noise: 5, }, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - chunks := split(logger, tc.opts, tc.msg) - tc.assert(t, chunks) - }) - } -} - -func TestSplitChunks(t *testing.T) { - tcs := []struct { - name string - raw []byte - size int - wantErr bool - expect [][]byte - }{ - { - name: "size 2", - raw: []byte("12345"), - size: 2, - wantErr: false, - expect: [][]byte{[]byte("12"), []byte("34"), []byte("5")}, - }, - { - name: "size larger than len", - raw: []byte("123"), - size: 5, - wantErr: false, - expect: [][]byte{[]byte("123")}, - }, - { - name: "size 0", - raw: []byte("12345"), - size: 0, - wantErr: true, - }, - { - name: "len(raw) is 0", - raw: []byte(""), - size: 3, - wantErr: true, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - chunks, err := splitChunks(tc.raw, tc.size) - if tc.wantErr { - assert.Error(t, err) - return - } - - assert.NoError(t, err) - assert.Equal(t, tc.expect, chunks) - }) - } -} - -func TestSplitFirstByte(t *testing.T) { - tcs := []struct { - name string - raw []byte - wantErr bool - expect [][]byte - }{ - { - name: "size 2", - raw: []byte("12"), - wantErr: false, - expect: [][]byte{[]byte("1"), []byte("2")}, - }, - { - name: "size 3", - raw: []byte("123"), - wantErr: false, - expect: [][]byte{[]byte("1"), []byte("23")}, - }, - { - name: "size 10", - raw: []byte("1234567890"), - wantErr: false, - expect: [][]byte{[]byte("1"), []byte("234567890")}, - }, - { - name: "len(data) is 0", - raw: []byte(""), - wantErr: true, - }, - { - name: "len(data) is 1", - raw: []byte(""), - wantErr: true, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - chunks, err := splitFirstByte(tc.raw) - if tc.wantErr { - assert.Error(t, err) - return - } - - assert.NoError(t, err) - assert.Equal(t, chunks, tc.expect) - }) - } -} - -func TestSplitSNI(t *testing.T) { - tcs := []struct { - name string - raw []byte - start int - end int - wantErr bool - expect [][]byte - }{ - { - name: "size 0", - raw: []byte("PREFIX_SNI_SUFFIX"), - start: 7, - end: 10, - wantErr: false, - expect: [][]byte{ - []byte("PREFIX_"), - []byte("S"), - []byte("N"), - []byte("I"), - []byte("_SUFFIX"), - }, - }, - { - name: "start out of range (start > len)", - raw: []byte("1"), - start: 3, - end: 3, - wantErr: true, - expect: [][]byte{[]byte("1")}, - }, - { - name: "start out of range (start < 0)", - raw: []byte("1"), - start: -1, - end: 5, - wantErr: true, - expect: [][]byte{[]byte("1")}, - }, - { - name: "end out of range (end > len)", - raw: []byte("1"), - start: 0, - end: 5, - wantErr: true, - expect: [][]byte{[]byte("1")}, - }, - { - name: "end out of range (end < 0)", - raw: []byte("1"), - start: -1, - end: -1, - wantErr: true, - expect: [][]byte{[]byte("1")}, - }, - { - name: "invalid start, end pos (start > end)", - raw: []byte("12345"), - start: 4, - end: 3, - wantErr: true, - expect: [][]byte{[]byte("1")}, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - chunks, err := splitSNI(tc.raw, tc.start, tc.end) - if tc.wantErr { - assert.Error(t, err) - return - } - - assert.NoError(t, err) - assert.Equal(t, chunks, tc.expect) - // Debug print if fail - if !assert.Equal(t, tc.expect, chunks) { - for i, c := range chunks { - fmt.Printf("chunk[%d]: %s\n", i, c) - } - } - }) - } -} - -func TestSplitMask(t *testing.T) { - tcs := []struct { - name string - raw []byte - mask uint64 - wantErr bool - expect [][]byte - }{ - { - name: "mask with some bits set (137 = 10001001)", - raw: []byte("12345678"), - mask: 137, // - wantErr: false, - expect: [][]byte{ - []byte("1"), - []byte("23"), - []byte("4"), - []byte("567"), - []byte("8"), - }, - }, - { - name: "mask 0 (no split)", - raw: []byte("123"), - mask: 0, - wantErr: false, - expect: [][]byte{[]byte("123")}, - }, - { - name: "len(data) is 0", - raw: []byte(""), - mask: 123, - wantErr: true, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - chunks, err := splitMask(tc.raw, tc.mask) - if tc.wantErr { - assert.Error(t, err) - return - } + } + for i := 0; i < 50; i++ { + chunks, err := applySegmentPlans(msg, plans) assert.NoError(t, err) - assert.Equal(t, chunks, tc.expect) - }) - } + assert.Len(t, chunks, 2) + // Split point should be between 10-5=5 and 10+5=15 + splitLen := len(chunks[0].Packet) + assert.GreaterOrEqual(t, splitLen, 5) + assert.LessOrEqual(t, splitLen, 15) + } + }) } diff --git a/internal/desync/udp.go b/internal/desync/udp.go new file mode 100644 index 00000000..228f68fe --- /dev/null +++ b/internal/desync/udp.go @@ -0,0 +1,72 @@ +package desync + +import ( + "context" + "net" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/packet" +) + +type UDPDesyncer struct { + logger zerolog.Logger + writer packet.Writer + sniffer packet.Sniffer +} + +func NewUDPDesyncer( + logger zerolog.Logger, + writer packet.Writer, + sniffer packet.Sniffer, +) *UDPDesyncer { + return &UDPDesyncer{ + logger: logger, + writer: writer, + sniffer: sniffer, + } +} + +func (d *UDPDesyncer) Desync( + ctx context.Context, + lConn net.Conn, + rConn net.Conn, + opts *config.UDPOptions, +) (int, error) { + logger := logging.WithLocalScope(ctx, d.logger, "udp_desync") + + if d.sniffer == nil || d.writer == nil || opts == nil || + opts.FakeCount == nil || *opts.FakeCount <= 0 { + return 0, nil + } + + dstIP := rConn.RemoteAddr().(*net.UDPAddr).IP.String() + oTTL := d.sniffer.GetOptimalTTL(dstIP) + + var totalSent int + for range *opts.FakeCount { + n, err := d.writer.WriteCraftedPacket( + ctx, + lConn.LocalAddr(), // Spoofing source: original local address (TUN) + rConn.RemoteAddr(), + oTTL, + opts.FakePacket, + ) + if err != nil { + logger.Warn().Err(err).Msg("failed to send fake packet") + continue + } + totalSent += n + } + + if totalSent > 0 { + logger.Debug(). + Int("count", *opts.FakeCount). + Int("bytes", totalSent). + Uint8("ttl", oTTL). + Msg("sent fake packets") + } + + return totalSent, nil +} diff --git a/internal/dns/addrselect/addrselect.go b/internal/dns/addrselect/addrselect.go index f81efbc1..702ab13c 100644 --- a/internal/dns/addrselect/addrselect.go +++ b/internal/dns/addrselect/addrselect.go @@ -12,26 +12,26 @@ import ( // Minimal RFC 6724 address selection. -func SortByRFC6724(addrs []net.IPAddr) { - if len(addrs) < 2 { +func SortByRFC6724(ips []net.IP) { + if len(ips) < 2 { return } - sortByRFC6724withSrcs(addrs, srcAddrs(addrs)) + sortByRFC6724withSrcs(ips, srcAddrs(ips)) } -func sortByRFC6724withSrcs(addrs []net.IPAddr, srcs []netip.Addr) { - if len(addrs) != len(srcs) { +func sortByRFC6724withSrcs(ips []net.IP, srcs []netip.Addr) { + if len(ips) != len(srcs) { panic("internal error") } - addrAttr := make([]ipAttr, len(addrs)) + addrAttr := make([]ipAttr, len(ips)) srcAttr := make([]ipAttr, len(srcs)) - for i, v := range addrs { - addrAttrIP, _ := netip.AddrFromSlice(v.IP) + for i, v := range ips { + addrAttrIP, _ := netip.AddrFromSlice(v) addrAttr[i] = ipAttrOf(addrAttrIP) srcAttr[i] = ipAttrOf(srcs[i]) } sort.Stable(&byRFC6724{ - addrs: addrs, + addrs: ips, addrAttr: addrAttr, srcs: srcs, srcAttr: srcAttr, @@ -41,12 +41,12 @@ func sortByRFC6724withSrcs(addrs []net.IPAddr, srcs []netip.Addr) { // srcAddrs tries to UDP-connect to each address to see if it has a // route. This does not send any packets. The destination port number // is irrelevant. -func srcAddrs(addrs []net.IPAddr) []netip.Addr { - srcs := make([]netip.Addr, len(addrs)) +func srcAddrs(ips []net.IP) []netip.Addr { + srcs := make([]netip.Addr, len(ips)) dst := net.UDPAddr{Port: 9} - for i := range addrs { - dst.IP = addrs[i].IP - dst.Zone = addrs[i].Zone + for i := range ips { + dst.IP = ips[i] + // dst.Zone = ips[i].Zone // Zone is not easily available in net.IP, skipping for now c, err := net.DialUDP("udp", nil, &dst) if err == nil { if src, ok := c.LocalAddr().(*net.UDPAddr); ok { @@ -79,7 +79,7 @@ func ipAttrOf(ip netip.Addr) ipAttr { } type byRFC6724 struct { - addrs []net.IPAddr // Addresses to sort. + addrs []net.IP // Addresses to sort. addrAttr []ipAttr srcs []netip.Addr // Or not a valid addr if unreachable. srcAttr []ipAttr @@ -99,8 +99,8 @@ func (s *byRFC6724) Swap(i, j int) { // // The algorithm and variable names are from RFC 6724 section 6. func (s *byRFC6724) Less(i, j int) bool { - DA := s.addrs[i].IP - DB := s.addrs[j].IP + DA := s.addrs[i] + DB := s.addrs[j] SourceDA := s.srcs[i] SourceDB := s.srcs[j] attrDA := &s.addrAttr[i] diff --git a/internal/dns/cache.go b/internal/dns/cache.go index 7185330c..11b10e24 100644 --- a/internal/dns/cache.go +++ b/internal/dns/cache.go @@ -54,7 +54,7 @@ func (cr *CacheResolver) Resolve( // For now, assuming simplistic cache key = domain, but awareness of potential issue. // Ideally: key = domain + qtypes + spec-related-things if item, ok := cr.ttlCache.Get(domain); ok { - logger.Trace().Msgf("hit") + logger.Debug().Str("domain", domain).Msgf("hit") return item.(*RecordSet).Clone(), nil } @@ -64,7 +64,9 @@ func (cr *CacheResolver) Resolve( // 2. [Cache Miss] // Delegate the actual network request to 'r.next' (the worker). - logger.Trace().Str("fallback", fallback.Info()[0].Name).Msgf("miss") + logger.Debug().Str("domain", domain).Str("fallback", fallback.Info()[0].Name). + Msgf("miss") + rSet, err := fallback.Resolve(ctx, domain, nil, rule) if err != nil { return nil, err @@ -73,7 +75,8 @@ func (cr *CacheResolver) Resolve( // 3. [Cache Write] // (Assuming the actual TTL is parsed from the DNS response) // realTTL := 5 * time.Second - logger.Trace(). + logger.Debug(). + Str("domain", domain). Int("len", len(rSet.Addrs)). Uint32("ttl", rSet.TTL). Msg("set") diff --git a/internal/dns/https.go b/internal/dns/https.go index 51fd4a0c..9bbce947 100644 --- a/internal/dns/https.go +++ b/internal/dns/https.go @@ -3,8 +3,9 @@ package dns import ( "bytes" "context" - "encoding/base64" + "crypto/tls" "fmt" + "io" "net" "net/http" "strings" @@ -14,13 +15,13 @@ import ( "github.com/rs/zerolog" "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/logging" + "golang.org/x/net/http2" ) var _ Resolver = (*HTTPSResolver)(nil) type HTTPSResolver struct { - logger zerolog.Logger - + logger zerolog.Logger client *http.Client dnsOpts *config.DNSOptions } @@ -29,19 +30,31 @@ func NewHTTPSResolver( logger zerolog.Logger, dnsOpts *config.DNSOptions, ) *HTTPSResolver { + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + }, + DialContext: (&net.Dialer{ + Timeout: 7 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 9 * time.Second, + MaxIdleConnsPerHost: 100, + MaxIdleConns: 100, + ForceAttemptHTTP2: true, + } + + // Configure HTTP/2 transport explicitly + if err := http2.ConfigureTransport(tr); err != nil { + // Log error instead of panic if strict http2 is not required, otherwise panic + panic(fmt.Sprintf("failed to configure http2: %v", err)) + } + return &HTTPSResolver{ logger: logger, client: &http.Client{ - Timeout: 5 * time.Second, - Transport: &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: 3 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - TLSHandshakeTimeout: 5 * time.Second, - MaxIdleConnsPerHost: 100, - MaxIdleConns: 100, - }, + Transport: tr, + Timeout: 10 * time.Second, }, dnsOpts: dnsOpts, } @@ -94,55 +107,75 @@ func (dr *HTTPSResolver) exchange( return nil, err } - url := fmt.Sprintf( - "%s?dns=%s", - upstream, - base64.RawURLEncoding.EncodeToString(pack), - // base64.RawStdEncoding.EncodeToString(pack), - ) - req, err := http.NewRequest("GET", url, nil) - if err != nil { - return nil, err + const maxRetries = 2 + var resp *http.Response + var reqErr error + + // Retry loop for transient network errors like unexpected EOF + for i := 0; i < maxRetries; i++ { + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + upstream, + bytes.NewReader(pack), + ) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/dns-message") + req.Header.Set("Accept", "application/dns-message") + + resp, reqErr = dr.client.Do(req) + if reqErr == nil { + break + } + + // Check if error is retryable + if i < maxRetries-1 && isRetryableError(reqErr) { + continue + } } - req = req.WithContext(ctx) - req.Header.Set("Accept", "application/dns-message") - - resp, err := dr.client.Do(req) - if err != nil { - return nil, err + if reqErr != nil { + return nil, reqErr } defer func() { _ = resp.Body.Close() }() - buf := bytes.Buffer{} - bodyLen, err := buf.ReadFrom(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { logger.Trace(). - Int64("len", bodyLen). + Int("len", len(body)). Int("status", resp.StatusCode). - Str("body", buf.String()). - Msg("status not ok") - return nil, fmt.Errorf("doh status code(%d)", resp.StatusCode) + Str("body", string(body)). + Msg("doh status not ok") + return nil, fmt.Errorf("status code(%d)", resp.StatusCode) } resultMsg := new(dns.Msg) - err = resultMsg.Unpack(buf.Bytes()) - if err != nil { + if err := resultMsg.Unpack(body); err != nil { return nil, err } - // Ignore Rcode 3 (NameNotFound) as it's not a critical error. if resultMsg.Rcode != dns.RcodeSuccess && resultMsg.Rcode != dns.RcodeNameError { logger.Trace(). Int("rcode", resultMsg.Rcode). Str("msg", resultMsg.String()). - Msg("rcode not ok") - return nil, fmt.Errorf("doh Rcode(%d)", resultMsg.Rcode) + Msg("doh rcode not ok") + return nil, fmt.Errorf("Rcode(%d)", resultMsg.Rcode) } return resultMsg, nil } + +// isRetryableError checks for common transient network errors +func isRetryableError(err error) bool { + msg := err.Error() + return strings.Contains(msg, "unexpected EOF") || + strings.Contains(msg, "connection reset") || + strings.Contains(msg, "broken pipe") +} diff --git a/internal/dns/resolver.go b/internal/dns/resolver.go index 2cc63b71..daab801b 100644 --- a/internal/dns/resolver.go +++ b/internal/dns/resolver.go @@ -2,6 +2,7 @@ package dns import ( "context" + "errors" "fmt" "math" "net" @@ -52,13 +53,13 @@ type MsgChan struct { } type RecordSet struct { - Addrs []net.IPAddr + Addrs []net.IP TTL uint32 } func (rs *RecordSet) Clone() *RecordSet { return &RecordSet{ - Addrs: append([]net.IPAddr(nil), rs.Addrs...), + Addrs: append([]net.IP(nil), rs.Addrs...), TTL: rs.TTL, } } @@ -149,8 +150,8 @@ func lookupAllTypes( return resCh } -func parseMsg(msg *dns.Msg) ([]net.IPAddr, uint32, bool) { - var addrs []net.IPAddr +func parseMsg(msg *dns.Msg) ([]net.IP, uint32, bool) { + var addrs []net.IP minTTL := uint32(math.MaxUint32) ok := false @@ -158,11 +159,11 @@ func parseMsg(msg *dns.Msg) ([]net.IPAddr, uint32, bool) { switch ipRecord := record.(type) { case *dns.A: ok = true - addrs = append(addrs, net.IPAddr{IP: ipRecord.A}) + addrs = append(addrs, ipRecord.A) minTTL = min(minTTL, record.Header().Ttl) case *dns.AAAA: ok = true - addrs = append(addrs, net.IPAddr{IP: ipRecord.AAAA}) + addrs = append(addrs, ipRecord.AAAA) minTTL = min(minTTL, record.Header().Ttl) } } @@ -175,7 +176,7 @@ func processMessages( resCh <-chan *MsgChan, ) (*RecordSet, error) { var errs []error - var addrs []net.IPAddr + var addrs []net.IP minTTL := uint32(math.MaxUint32) found := false @@ -224,7 +225,7 @@ loop: // Loop until the channel is closed or context is canceled // Only return errors if no addresses were found at all if len(errs) > 0 { - return nil, fmt.Errorf("failed to resolve with %d errors", len(errs)) + return nil, errors.Join(errs...) } return nil, fmt.Errorf("record not found") diff --git a/internal/dns/route.go b/internal/dns/route.go index 1166ed3d..dfc97928 100644 --- a/internal/dns/route.go +++ b/internal/dns/route.go @@ -8,9 +8,9 @@ import ( "time" "github.com/rs/zerolog" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/logging" - "github.com/xvzc/SpoofDPI/internal/ptr" ) type RouteResolver struct { @@ -64,7 +64,7 @@ func (rr *RouteResolver) Resolve( // 1. Check for IP address in domain if ip, err := parseIpAddr(domain); err == nil { - return &RecordSet{Addrs: []net.IPAddr{*ip}, TTL: 0}, nil + return &RecordSet{Addrs: []net.IP{ip}, TTL: 0}, nil } // 4. Handle ROUTE rule (or default) @@ -77,11 +77,10 @@ func (rr *RouteResolver) Resolve( } resolverInfo := resolver.Info()[0] - logger.Trace(). - Str("name", resolverInfo.Name). - Bool("cache", ptr.FromPtr(opts.Cache)). + logger.Debug().Str("mode", resolverInfo.Name).Bool("cache", lo.FromPtr(opts.Cache)). Msgf("ready to resolve") + t1 := time.Now() var rSet *RecordSet var err error if *opts.Mode != config.DNSModeSystem && *opts.Cache { @@ -90,7 +89,17 @@ func (rr *RouteResolver) Resolve( rSet, err = resolver.Resolve(ctx, domain, nil, rule) } - return rSet, err + if err != nil { + return nil, err + } + + logger.Debug(). + Str("domain", domain). + Int("len", len(rSet.Addrs)). + Str("took", fmt.Sprintf("%.3fms", float64(time.Since(t1).Microseconds())/1000.0)). + Msgf("dns lookup ok") + + return rSet, nil } func (rr *RouteResolver) route(attrs *config.DNSOptions) Resolver { @@ -106,15 +115,11 @@ func (rr *RouteResolver) route(attrs *config.DNSOptions) Resolver { } } -func parseIpAddr(addr string) (*net.IPAddr, error) { +func parseIpAddr(addr string) (net.IP, error) { ip := net.ParseIP(addr) if ip == nil { return nil, fmt.Errorf("%s is not an ip address", addr) } - ipAddr := &net.IPAddr{ - IP: ip, - } - - return ipAddr, nil + return ip, nil } diff --git a/internal/dns/system.go b/internal/dns/system.go index 63c2e72d..6769fbd1 100644 --- a/internal/dns/system.go +++ b/internal/dns/system.go @@ -9,9 +9,9 @@ import ( "github.com/xvzc/SpoofDPI/internal/config" ) -var _ Resolver = (*SysResolver)(nil) +var _ Resolver = (*SystemResolver)(nil) -type SysResolver struct { +type SystemResolver struct { logger zerolog.Logger *net.Resolver @@ -21,15 +21,15 @@ type SysResolver struct { func NewSystemResolver( logger zerolog.Logger, dnsOps *config.DNSOptions, -) *SysResolver { - return &SysResolver{ +) *SystemResolver { + return &SystemResolver{ logger: logger, Resolver: &net.Resolver{PreferGo: true}, dnsOpts: dnsOps, } } -func (sr *SysResolver) Info() []ResolverInfo { +func (sr *SystemResolver) Info() []ResolverInfo { return []ResolverInfo{ { Name: "system", @@ -38,7 +38,7 @@ func (sr *SysResolver) Info() []ResolverInfo { } } -func (sr *SysResolver) Resolve( +func (sr *SystemResolver) Resolve( ctx context.Context, domain string, fallback Resolver, @@ -49,18 +49,18 @@ func (sr *SysResolver) Resolve( opts = opts.Merge(rule.DNS) } - addrs, err := sr.LookupIPAddr(ctx, domain) + ips, err := sr.LookupIP(ctx, "ip", domain) if err != nil { return nil, err } return &RecordSet{ - Addrs: filtterAddrs(addrs, parseQueryTypes(*opts.QType)), + Addrs: filtterAddrs(ips, parseQueryTypes(*opts.QType)), TTL: 0, }, nil } -func filtterAddrs(addrs []net.IPAddr, qTypes []uint16) []net.IPAddr { +func filtterAddrs(ips []net.IP, qTypes []uint16) []net.IP { wantsA, wantsAAAA := false, false for _, qType := range qTypes { switch qType { @@ -76,31 +76,31 @@ func filtterAddrs(addrs []net.IPAddr, qTypes []uint16) []net.IPAddr { } if !wantsA && !wantsAAAA { - return []net.IPAddr{} + return []net.IP{} } - filteredMap := make(map[string]net.IPAddr) + filteredMap := make(map[string]net.IP) - for _, addr := range addrs { - addrStr := addr.IP.String() + for _, ip := range ips { + addrStr := ip.String() if _, exists := filteredMap[addrStr]; exists { continue } - isIPv4 := addr.IP.To4() != nil + isIPv4 := ip.To4() != nil if wantsA && isIPv4 { - filteredMap[addrStr] = addr + filteredMap[addrStr] = ip } if wantsAAAA && !isIPv4 { - filteredMap[addrStr] = addr + filteredMap[addrStr] = ip } } - filtered := make([]net.IPAddr, 0, len(filteredMap)) - for _, addr := range filteredMap { - filtered = append(filtered, addr) + filtered := make([]net.IP, 0, len(filteredMap)) + for _, ip := range filteredMap { + filtered = append(filtered, ip) } return filtered diff --git a/internal/logging/logging.go b/internal/logging/logging.go index a32a890d..58b9f245 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -127,9 +127,8 @@ func (h ctxHook) Run(e *zerolog.Event, level zerolog.Level, msg string) { // Request-scoped values like trace_id. // Scope is expected to be added at the component's creation time. - if traceID, ok := session.TraceIDFrom(ctx); ok { - e.Str(traceIDFieldName, traceID) - } + traceID, _ := session.TraceIDFrom(ctx) + e.Str(traceIDFieldName, traceID) if domain, ok := session.HostInfoFrom(ctx); ok { e.Str(hostInfoFieldName, domain) @@ -140,21 +139,21 @@ type joinableError interface { Unwrap() []error } -// ErrorUnwrapped tries to unwrap an error and prints each error separately. +// ErrorErrs tries to unwrap an error and prints each error separately. // If the error is not joined, it logs the single error normally. -func ErrorUnwrapped(logger *zerolog.Logger, msg string, err error) { - logUnwrapped(logger, zerolog.ErrorLevel, msg, err) +func ErrorErrs(logger *zerolog.Logger, err error, msg string) { + logErrs(logger, zerolog.ErrorLevel, err, msg) } -func WarnUnwrapped(logger *zerolog.Logger, msg string, err error) { - logUnwrapped(logger, zerolog.WarnLevel, msg, err) +func WarnErrs(logger *zerolog.Logger, err error, msg string) { + logErrs(logger, zerolog.WarnLevel, err, msg) } -func TraceUnwrapped(logger *zerolog.Logger, msg string, err error) { - logUnwrapped(logger, zerolog.TraceLevel, msg, err) +func TraceErrs(logger *zerolog.Logger, err error, msg string) { + logErrs(logger, zerolog.TraceLevel, err, msg) } -func logUnwrapped(logger *zerolog.Logger, level zerolog.Level, msg string, err error) { +func logErrs(logger *zerolog.Logger, level zerolog.Level, err error, msg string) { var joinedErrs joinableError if errors.As(err, &joinedErrs) { diff --git a/internal/matcher/addr.go b/internal/matcher/addr.go index 89b1fab2..6a358c48 100644 --- a/internal/matcher/addr.go +++ b/internal/matcher/addr.go @@ -6,8 +6,8 @@ import ( "sort" "sync" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/ptr" ) // AddrMatcher implements Matcher for IP/CIDR rules. @@ -43,11 +43,11 @@ func (m *AddrMatcher) Add(r *config.Rule) error { } if r.Priority == nil { - r.Priority = ptr.FromValue(uint16(0)) + r.Priority = lo.ToPtr(uint16(0)) } if r.Block == nil { - r.Block = ptr.FromValue(false) + r.Block = lo.ToPtr(false) } m.mu.Lock() diff --git a/internal/matcher/addr_test.go b/internal/matcher/addr_test.go index 90810ac0..ae2e61e9 100644 --- a/internal/matcher/addr_test.go +++ b/internal/matcher/addr_test.go @@ -4,37 +4,37 @@ import ( "net" "testing" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/ptr" ) func TestAddrMatcher(t *testing.T) { matcher := NewAddrMatcher() rule1 := &config.Rule{ - Name: ptr.FromValue("rule1"), - Priority: ptr.FromValue(uint16(10)), + Name: lo.ToPtr("rule1"), + Priority: lo.ToPtr(uint16(10)), Match: &config.MatchAttrs{ Addrs: []config.AddrMatch{ { - CIDR: ptr.FromValue(config.MustParseCIDR("192.168.1.0/24")), - PortFrom: ptr.FromValue(uint16(80)), - PortTo: ptr.FromValue(uint16(80)), + CIDR: lo.ToPtr(config.MustParseCIDR("192.168.1.0/24")), + PortFrom: lo.ToPtr(uint16(80)), + PortTo: lo.ToPtr(uint16(80)), }, }, }, } rule2 := &config.Rule{ - Name: ptr.FromValue("rule2"), - Priority: ptr.FromValue(uint16(20)), + Name: lo.ToPtr("rule2"), + Priority: lo.ToPtr(uint16(20)), Match: &config.MatchAttrs{ Addrs: []config.AddrMatch{ { - CIDR: ptr.FromValue(config.MustParseCIDR("10.0.0.0/8")), - PortFrom: ptr.FromValue(uint16(0)), - PortTo: ptr.FromValue(uint16(65535)), + CIDR: lo.ToPtr(config.MustParseCIDR("10.0.0.0/8")), + PortFrom: lo.ToPtr(uint16(0)), + PortTo: lo.ToPtr(uint16(65535)), }, }, }, @@ -42,14 +42,14 @@ func TestAddrMatcher(t *testing.T) { // Overlapping lower priority rule rule3 := &config.Rule{ - Name: ptr.FromValue("rule3"), - Priority: ptr.FromValue(uint16(5)), + Name: lo.ToPtr("rule3"), + Priority: lo.ToPtr(uint16(5)), Match: &config.MatchAttrs{ Addrs: []config.AddrMatch{ { - CIDR: ptr.FromValue(config.MustParseCIDR("172.16.0.0/16")), - PortFrom: ptr.FromValue(uint16(0)), - PortTo: ptr.FromValue(uint16(65535)), + CIDR: lo.ToPtr(config.MustParseCIDR("172.16.0.0/16")), + PortFrom: lo.ToPtr(uint16(0)), + PortTo: lo.ToPtr(uint16(65535)), }, }, }, @@ -57,14 +57,14 @@ func TestAddrMatcher(t *testing.T) { // Overlapping lower priority rule rule4 := &config.Rule{ - Name: ptr.FromValue("rule4"), - Priority: ptr.FromValue(uint16(4)), + Name: lo.ToPtr("rule4"), + Priority: lo.ToPtr(uint16(4)), Match: &config.MatchAttrs{ Addrs: []config.AddrMatch{ { - CIDR: ptr.FromValue(config.MustParseCIDR("172.16.0.0/16")), - PortFrom: ptr.FromValue(uint16(443)), - PortTo: ptr.FromValue(uint16(443)), + CIDR: lo.ToPtr(config.MustParseCIDR("172.16.0.0/16")), + PortFrom: lo.ToPtr(uint16(443)), + PortTo: lo.ToPtr(uint16(443)), }, }, }, @@ -142,7 +142,7 @@ func TestAddrMatcher(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ip := net.ParseIP(tc.ip) port := tc.port - selector := &Selector{IP: &ip, Port: ptr.FromValue(uint16(port))} + selector := &Selector{IP: &ip, Port: lo.ToPtr(uint16(port))} output := matcher.Search(selector) tc.assert(t, output) }) diff --git a/internal/matcher/domain.go b/internal/matcher/domain.go index 7aa3a6e5..9d021054 100644 --- a/internal/matcher/domain.go +++ b/internal/matcher/domain.go @@ -5,8 +5,8 @@ import ( "strings" "sync" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/ptr" ) // node represents a single node in the radix tree implementation. @@ -72,11 +72,11 @@ func (t *DomainMatcher) Add(r *config.Rule) error { } if r.Priority == nil { - r.Priority = ptr.FromValue(uint16(0)) + r.Priority = lo.ToPtr(uint16(0)) } if r.Block == nil { - r.Block = ptr.FromValue(false) + r.Block = lo.ToPtr(false) } t.mu.Lock() diff --git a/internal/matcher/domain_test.go b/internal/matcher/domain_test.go index 69ad5704..dab97f17 100644 --- a/internal/matcher/domain_test.go +++ b/internal/matcher/domain_test.go @@ -3,9 +3,9 @@ package matcher import ( "testing" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/ptr" ) func TestDomainMatcher(t *testing.T) { @@ -13,24 +13,24 @@ func TestDomainMatcher(t *testing.T) { matcher := NewDomainMatcher() rule1 := &config.Rule{ - Name: ptr.FromValue("rule1"), - Priority: ptr.FromValue(uint16(10)), + Name: lo.ToPtr("rule1"), + Priority: lo.ToPtr(uint16(10)), Match: &config.MatchAttrs{ Domains: []string{"example.com"}, }, } rule2 := &config.Rule{ - Name: ptr.FromValue("rule2"), - Priority: ptr.FromValue(uint16(20)), + Name: lo.ToPtr("rule2"), + Priority: lo.ToPtr(uint16(20)), Match: &config.MatchAttrs{ Domains: []string{"*.google.com"}, }, } rule3 := &config.Rule{ - Name: ptr.FromValue("rule3"), - Priority: ptr.FromValue(uint16(5)), + Name: lo.ToPtr("rule3"), + Priority: lo.ToPtr(uint16(5)), Match: &config.MatchAttrs{ Domains: []string{"**.youtube.com"}, }, @@ -38,8 +38,8 @@ func TestDomainMatcher(t *testing.T) { // Additional rule for priority check rule4 := &config.Rule{ - Name: ptr.FromValue("rule4"), - Priority: ptr.FromValue(uint16(30)), + Name: lo.ToPtr("rule4"), + Priority: lo.ToPtr(uint16(30)), Match: &config.MatchAttrs{ Domains: []string{"mail.google.com"}, }, @@ -57,7 +57,7 @@ func TestDomainMatcher(t *testing.T) { }{ { name: "exact match", - selector: &Selector{Domain: ptr.FromValue("example.com")}, + selector: &Selector{Domain: lo.ToPtr("example.com")}, assert: func(t *testing.T, output *config.Rule) { assert.NotNil(t, output) assert.Equal(t, "rule1", *output.Name) @@ -65,7 +65,7 @@ func TestDomainMatcher(t *testing.T) { }, { name: "wildcard match", - selector: &Selector{Domain: ptr.FromValue("maps.google.com")}, + selector: &Selector{Domain: lo.ToPtr("maps.google.com")}, assert: func(t *testing.T, output *config.Rule) { assert.NotNil(t, output) assert.Equal(t, "rule2", *output.Name) @@ -73,7 +73,7 @@ func TestDomainMatcher(t *testing.T) { }, { name: "globstar match", - selector: &Selector{Domain: ptr.FromValue("foo.bar.youtube.com")}, + selector: &Selector{Domain: lo.ToPtr("foo.bar.youtube.com")}, assert: func(t *testing.T, output *config.Rule) { assert.NotNil(t, output) assert.Equal(t, "rule3", *output.Name) @@ -81,7 +81,7 @@ func TestDomainMatcher(t *testing.T) { }, { name: "wildcard higher priority check", - selector: &Selector{Domain: ptr.FromValue("mail.google.com")}, + selector: &Selector{Domain: lo.ToPtr("mail.google.com")}, assert: func(t *testing.T, output *config.Rule) { // Should pick rule4 (priority 30) over rule2 (priority 20) assert.NotNil(t, output) @@ -90,7 +90,7 @@ func TestDomainMatcher(t *testing.T) { }, { name: "no match", - selector: &Selector{Domain: ptr.FromValue("naver.com")}, + selector: &Selector{Domain: lo.ToPtr("naver.com")}, assert: func(t *testing.T, output *config.Rule) { assert.Nil(t, output) }, diff --git a/internal/matcher/matcher.go b/internal/matcher/matcher.go index 5d3718c1..3e269a25 100644 --- a/internal/matcher/matcher.go +++ b/internal/matcher/matcher.go @@ -4,8 +4,8 @@ import ( "fmt" "net" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/ptr" ) // ----------------------------------------------------------------------------- @@ -142,7 +142,7 @@ func GetHigherPriorityRule(r1, r2 *config.Rule) *config.Rule { return r1 } - if ptr.FromPtr(r1.Priority) >= ptr.FromPtr(r2.Priority) { + if lo.FromPtr(r1.Priority) >= lo.FromPtr(r2.Priority) { return r1 } return r2 diff --git a/internal/matcher/matcher_test.go b/internal/matcher/matcher_test.go index 4e9156c2..5079963c 100644 --- a/internal/matcher/matcher_test.go +++ b/internal/matcher/matcher_test.go @@ -3,14 +3,14 @@ package matcher import ( "testing" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/ptr" ) func TestGetHigherPriorityRule(t *testing.T) { - r1 := &config.Rule{Priority: ptr.FromValue(uint16(10))} - r2 := &config.Rule{Priority: ptr.FromValue(uint16(20))} + r1 := &config.Rule{Priority: lo.ToPtr(uint16(10))} + r2 := &config.Rule{Priority: lo.ToPtr(uint16(20))} tcs := []struct { name string diff --git a/internal/netutil/addr.go b/internal/netutil/addr.go index f7fae42c..24b85548 100644 --- a/internal/netutil/addr.go +++ b/internal/netutil/addr.go @@ -3,18 +3,27 @@ package netutil import ( "fmt" "net" + "os/exec" + "regexp" + "strconv" "time" ) type Destination struct { Domain string - Addrs []net.IPAddr + Addrs []net.IP Port int Timeout time.Duration + Iface *net.Interface + Gateway string +} + +func (d *Destination) String() string { + return net.JoinHostPort(d.Domain, strconv.Itoa(d.Port)) } func ValidateDestination( - dstAddrs []net.IPAddr, + dstAddrs []net.IP, dstPort int, listenAddr *net.TCPAddr, ) (bool, error) { @@ -27,7 +36,7 @@ func ValidateDestination( ifAddrs, err = net.InterfaceAddrs() for _, dstAddr := range dstAddrs { - ip := dstAddr.IP + ip := dstAddr if ip.IsLoopback() { return false, fmt.Errorf("loopback addr detected %v", ip.String()) } @@ -42,4 +51,121 @@ func ValidateDestination( } return true, err -} \ No newline at end of file +} + +// FindSafeSubnet scans the 10.0.0.0/8 range to find an unused /30 subnet +func FindSafeSubnet() (string, string, error) { + /* Retrieve all active interface addresses to prevent CIDR overlapping. + Checking against existing networks is faster than sending probe packets. + */ + addrs, err := net.InterfaceAddrs() + if err != nil { + return "", "", err + } + + var existingNets []*net.IPNet + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok { + existingNets = append(existingNets, ipnet) + } + } + + /* Iterate through the 10.0.0.0/8 private range with a /30 step. + A /30 subnet provides exactly two usable end-point IP addresses. + */ + for i := 0; i < 256; i++ { + for j := 0; j < 256; j++ { + // Construct candidate IP pair: 10.i.j.1 and 10.i.j.2 + local := net.IPv4(10, byte(i), byte(j), 1) + remote := net.IPv4(10, byte(i), byte(j), 2) + + conflict := false + for _, ipnet := range existingNets { + if ipnet.Contains(local) || ipnet.Contains(remote) { + conflict = true + break + } + } + + if !conflict { + return local.String(), remote.String(), nil + } + } + } + + return "", "", fmt.Errorf("failed to find an available address in 10.0.0.0/8") +} + +// GetDefaultInterfaceAndGateway returns the name of the default network interface and the gateway IP +func GetDefaultInterfaceAndGateway() (string, string, error) { + // Dial a public DNS server to determine the default interface + conn, err := net.Dial("udp", "8.8.8.8:53") + if err != nil { + return "", "", err + } + defer func() { _ = conn.Close() }() + + localAddr := conn.LocalAddr().(*net.UDPAddr) + + ifaces, err := net.Interfaces() + if err != nil { + return "", "", err + } + + var ifaceName string + for _, iface := range ifaces { + addrs, err := iface.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok { + if ipnet.IP.Equal(localAddr.IP) { + ifaceName = iface.Name + break + } + } + } + if ifaceName != "" { + break + } + } + + if ifaceName == "" { + return "", "", fmt.Errorf("default interface not found") + } + + // Get gateway by parsing route table + gateway, err := getDefaultGateway() + if err != nil { + return "", "", fmt.Errorf("failed to get default gateway: %w", err) + } + + return ifaceName, gateway, nil +} + +// getDefaultGateway parses the system route table to find the default gateway +func getDefaultGateway() (string, error) { + // Use netstat to get the default route on macOS + cmd := exec.Command("route", "-n", "get", "default") + out, err := cmd.Output() + if err != nil { + return "", err + } + + // Parse output to find gateway line + re := regexp.MustCompile(`gateway:\s+(\d+\.\d+\.\d+\.\d+)`) + matches := re.FindSubmatch(out) + if len(matches) < 2 { + return "", fmt.Errorf("could not parse gateway from route output") + } + + return string(matches[1]), nil +} + +// GetDefaultInterface returns the name of the default network interface +func GetDefaultInterface() (string, error) { + ifaceName, _, err := GetDefaultInterfaceAndGateway() + return ifaceName, err +} diff --git a/internal/netutil/conn.go b/internal/netutil/conn.go index 2f7a4d09..62754f91 100644 --- a/internal/netutil/conn.go +++ b/internal/netutil/conn.go @@ -1,18 +1,34 @@ package netutil import ( + "bufio" "context" "errors" "fmt" "io" "net" + "os" "sync" "syscall" + "time" "github.com/rs/zerolog" - "github.com/xvzc/SpoofDPI/internal/logging" ) +type TunnelDirType int + +const ( + TunnelDirOut TunnelDirType = iota + TunnelDirIn +) + +// TransferResult holds the result of a unidirectional tunnel transfer. +type TransferResult struct { + Written int64 + Dir TunnelDirType + Err error +} + // bufferPool is a package-level pool of 32KB buffers used by io.CopyBuffer // to reduce memory allocations and GC pressure in the tunnel hot path. var bufferPool = sync.Pool{ @@ -24,16 +40,16 @@ var bufferPool = sync.Pool{ }, } +// TunnelConns copies data from src to dst. +// It sends the result to resCh upon completion. +// It filters out benign errors like EOF, pipe closed, or read timeouts (for UDP). func TunnelConns( ctx context.Context, - logger zerolog.Logger, - errCh chan<- error, - dst net.Conn, // Destination connection (io.Writer) + resCh chan<- TransferResult, src net.Conn, // Source connection (io.Reader) + dst net.Conn, // Destination connection (io.Writer) + dir TunnelDirType, ) { - var n int64 - logger = logging.WithLocalScope(ctx, logger, "tunnel") - var once sync.Once closeOnce := func() { once.Do(func() { @@ -46,11 +62,6 @@ func TunnelConns( defer func() { stop() closeOnce() - - logger.Trace(). - Int64("len", n). - Str("route", fmt.Sprintf("%s -> %s", src.RemoteAddr(), dst.RemoteAddr())). - Msgf("done") }() bufPtr := bufferPool.Get().(*[]byte) @@ -58,13 +69,79 @@ func TunnelConns( // Copy data from src to dst using the borrowed buffer. n, err := io.CopyBuffer(dst, src, *bufPtr) + + // Filter benign errors. + // os.IsTimeout is checked to handle UDP idle timeouts gracefully. if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) && - !errors.Is(err, syscall.EPIPE) { - errCh <- err + !errors.Is(err, syscall.EPIPE) && !os.IsTimeout(err) { + resCh <- TransferResult{Written: n, Dir: dir, Err: err} return } - errCh <- nil + resCh <- TransferResult{Written: n, Dir: dir, Err: nil} +} + +// WaitAndLogTunnel aggregates results and logs the summary. +// errHandler processes the list of errors to determine the final error. +func WaitAndLogTunnel( + ctx context.Context, + logger zerolog.Logger, + resCh <-chan TransferResult, + startedAt time.Time, + route string, + errHandler func(errs []error) error, // [Modified] Accepts slice handler +) error { + var ( + outBytes int64 + inBytes int64 + errs []error // Collect all errors + ) + + // Wait for exactly 2 results. + for range 2 { + res := <-resCh + + switch res.Dir { + case TunnelDirOut: + outBytes = res.Written + case TunnelDirIn: + inBytes = res.Written + default: + return fmt.Errorf("invalid tunnel dir") + } + + if res.Err != nil { + errs = append(errs, res.Err) + } + } + + duration := float64(time.Since(startedAt).Microseconds()) / 1000.0 + logger.Trace(). + Int64("out", outBytes). + Int64("in", inBytes). + Str("took", fmt.Sprintf("%.3fms", duration)). + Str("route", route). + Int("errs", len(errs)). + Msg("tunnel closed") + + if errHandler != nil { + return errHandler(errs) + } + + if len(errs) > 0 { + return errs[0] + } + + return nil +} + +func DescribeRoute(src, dst net.Conn) string { + return fmt.Sprintf("%s(%s) -> %s(%s)", + src.RemoteAddr(), + src.RemoteAddr().Network(), + dst.RemoteAddr(), + dst.RemoteAddr().Network(), + ) } // CloseConns safely closes one or more io.Closer (like net.Conn). @@ -74,7 +151,6 @@ func TunnelConns( func CloseConns(closers ...io.Closer) { for _, c := range closers { if c != nil { - // Intentionally ignore the error. _ = c.Close() } } @@ -103,7 +179,7 @@ func SetTTL(conn net.Conn, isIPv4 bool, ttl uint8) error { var sysErr error - // Invoke Control to manipulate file descriptor directly + // Invoke Control to manipulate file descriptor directly. err = rawConn.Control(func(fd uintptr) { sysErr = syscall.SetsockoptInt(int(fd), level, opt, int(ttl)) }) @@ -113,3 +189,62 @@ func SetTTL(conn net.Conn, isIPv4 bool, ttl uint8) error { return sysErr } + +// BufferedConn wraps a net.Conn with a bufio.Reader to support peeking. +type BufferedConn struct { + r *bufio.Reader + net.Conn +} + +func NewBufferedConn(c net.Conn) *BufferedConn { + return &BufferedConn{ + r: bufio.NewReader(c), + Conn: c, + } +} + +func (b *BufferedConn) Read(p []byte) (int, error) { + return b.r.Read(p) +} + +func (b *BufferedConn) Peek(n int) ([]byte, error) { + return b.r.Peek(n) +} + +// TimeoutConn wraps a net.Conn to update the read deadline on every Read call. +// This is useful for UDP sessions which do not have a natural EOF. +type TimeoutConn struct { + net.Conn + Timeout time.Duration + LastActive time.Time + ExpiredAt time.Time // Calculated expiration time for cleanup +} + +func (c *TimeoutConn) Read(b []byte) (int, error) { + c.ExtendDeadline() + return c.Conn.Read(b) +} + +func (c *TimeoutConn) Write(b []byte) (int, error) { + c.ExtendDeadline() + return c.Conn.Write(b) +} + +// ExtendDeadline attempts to extend the connection's deadline. +// Returns false if the connection was already expired. +func (c *TimeoutConn) ExtendDeadline() bool { + now := time.Now() + + // Check if already expired + if !c.ExpiredAt.IsZero() && now.After(c.ExpiredAt) { + return false + } + + c.LastActive = now + if c.Timeout > 0 { + c.ExpiredAt = now.Add(c.Timeout) + _ = c.SetReadDeadline(c.ExpiredAt) + _ = c.SetWriteDeadline(c.ExpiredAt) + } + return true +} diff --git a/internal/netutil/conn_pool.go b/internal/netutil/conn_pool.go new file mode 100644 index 00000000..f730a85f --- /dev/null +++ b/internal/netutil/conn_pool.go @@ -0,0 +1,206 @@ +package netutil + +import ( + "container/list" + "net" + "sync" + "time" +) + +// ConnPool manages UDP connections with LRU eviction policy and idle timeout. +type ConnPool struct { + capacity int + timeout time.Duration + cache map[string]*list.Element + ll *list.List + mu sync.Mutex + stopCh chan struct{} + stopOnce sync.Once +} + +// PooledConn wraps net.Conn with LRU tracking and deadline management. +type PooledConn struct { + net.Conn + pool *ConnPool + key string + timeout time.Duration + expiredAt time.Time +} + +type connEntry struct { + key string + conn *PooledConn +} + +// NewConnPool creates a new pool with the specified capacity and timeout. +// Starts a background goroutine for expired connection cleanup. +func NewConnPool(capacity int, timeout time.Duration) *ConnPool { + p := &ConnPool{ + capacity: capacity, + timeout: timeout, + cache: make(map[string]*list.Element), + ll: list.New(), + stopCh: make(chan struct{}), + } + + // Cleanup interval: half of timeout, min 10s, max 60s + cleanupInterval := timeout / 2 + if cleanupInterval < 10*time.Second { + cleanupInterval = 10 * time.Second + } + if cleanupInterval > 60*time.Second { + cleanupInterval = 60 * time.Second + } + + go p.cleanupLoop(cleanupInterval) + return p +} + +// Add adds a connection to the pool and returns the wrapped connection. +// If capacity is full, evicts the least recently used connection. +func (p *ConnPool) Add(key string, rawConn net.Conn) *PooledConn { + p.mu.Lock() + defer p.mu.Unlock() + + // Evict if capacity is reached + if p.ll.Len() >= p.capacity { + p.evictOldest() + } + + now := time.Now() + expiredAt := now.Add(p.timeout) + + wrapper := &PooledConn{ + Conn: rawConn, + pool: p, + key: key, + timeout: p.timeout, + expiredAt: expiredAt, + } + + _ = rawConn.SetDeadline(expiredAt) + + elem := p.ll.PushFront(&connEntry{key: key, conn: wrapper}) + p.cache[key] = elem + + return wrapper +} + +// Remove closes and removes the connection from the pool. +func (p *ConnPool) Remove(key string) { + p.mu.Lock() + defer p.mu.Unlock() + + if elem, ok := p.cache[key]; ok { + p.removeElement(elem) + } +} + +// Size returns the number of connections in the pool. +func (p *ConnPool) Size() int { + p.mu.Lock() + defer p.mu.Unlock() + return p.ll.Len() +} + +// Stop stops the background cleanup goroutine. +func (p *ConnPool) Stop() { + p.stopOnce.Do(func() { + close(p.stopCh) + }) +} + +// CloseAll closes all connections in the pool. +func (p *ConnPool) CloseAll() { + p.mu.Lock() + defer p.mu.Unlock() + + elem := p.ll.Front() + for elem != nil { + next := elem.Next() + p.removeElement(elem) + elem = next + } +} + +func (p *ConnPool) cleanupLoop(interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-p.stopCh: + return + case <-ticker.C: + p.evictExpired() + } + } +} + +func (p *ConnPool) evictExpired() { + p.mu.Lock() + defer p.mu.Unlock() + + now := time.Now() + elem := p.ll.Back() + for elem != nil { + // Save next before potential removal + next := elem.Prev() + e := elem.Value.(*connEntry) + if now.After(e.conn.expiredAt) { + p.removeElement(elem) + } + elem = next + } +} + +func (p *ConnPool) evictOldest() { + if elem := p.ll.Back(); elem != nil { + p.removeElement(elem) + } +} + +func (p *ConnPool) removeElement(elem *list.Element) { + e := elem.Value.(*connEntry) + _ = e.conn.Conn.Close() + p.ll.Remove(elem) + delete(p.cache, e.key) +} + +func (p *ConnPool) touch(key string) { + p.mu.Lock() + defer p.mu.Unlock() + if elem, ok := p.cache[key]; ok { + p.ll.MoveToFront(elem) + } +} + +func (c *PooledConn) refreshDeadline() { + c.expiredAt = time.Now().Add(c.timeout) + _ = c.SetDeadline(c.expiredAt) + c.pool.touch(c.key) +} + +// Read reads data and refreshes the deadline on success. +func (c *PooledConn) Read(b []byte) (n int, err error) { + n, err = c.Conn.Read(b) + if n > 0 { + c.refreshDeadline() + } + return +} + +// Write writes data and refreshes the deadline on success. +func (c *PooledConn) Write(b []byte) (n int, err error) { + n, err = c.Conn.Write(b) + if n > 0 { + c.refreshDeadline() + } + return +} + +// Close removes the connection from the pool (underlying close handled by pool). +func (c *PooledConn) Close() error { + c.pool.Remove(c.key) + return nil +} diff --git a/internal/netutil/dial.go b/internal/netutil/dial.go index 895e377d..dd6293bb 100644 --- a/internal/netutil/dial.go +++ b/internal/netutil/dial.go @@ -19,24 +19,22 @@ type dialResult struct { func DialFastest( ctx context.Context, network string, - addrs []net.IPAddr, - port int, - timeout time.Duration, + dst *Destination, ) (net.Conn, error) { - if len(addrs) == 0 { + if len(dst.Addrs) == 0 { return nil, fmt.Errorf("no addresses provided to dial") } ctx, cancel := context.WithCancel(ctx) defer cancel() - results := make(chan dialResult, len(addrs)) + results := make(chan dialResult, len(dst.Addrs)) const maxConcurrency = 10 sem := make(chan struct{}, maxConcurrency) // semaphore go func() { - for _, addr := range addrs { + for _, addr := range dst.Addrs { // Get semaphore select { case sem <- struct{}{}: @@ -47,10 +45,15 @@ func DialFastest( go func(ip net.IP) { defer func() { <-sem }() // Return semaphore - targetAddr := net.JoinHostPort(ip.String(), strconv.Itoa(port)) + targetAddr := net.JoinHostPort(ip.String(), strconv.Itoa(dst.Port)) dialer := &net.Dialer{} - if timeout > 0 { - dialer.Deadline = time.Now().Add(timeout) + if dst.Timeout > 0 { + dialer.Deadline = time.Now().Add(dst.Timeout) + } + + // If Iface is specified, bind to the interface + if dst.Iface != nil { + bindToInterface(dialer, dst.Iface, ip) } conn, err := dialer.DialContext(ctx, network, targetAddr) @@ -62,14 +65,14 @@ func DialFastest( _ = conn.Close() // Close on context cancel } } - }(addr.IP) + }(addr) } }() var firstError error failureCount := 0 - for range addrs { + for range dst.Addrs { select { case res := <-results: if res.err == nil { diff --git a/internal/netutil/dial_darwin.go b/internal/netutil/dial_darwin.go new file mode 100644 index 00000000..35d26553 --- /dev/null +++ b/internal/netutil/dial_darwin.go @@ -0,0 +1,40 @@ +//go:build darwin + +package netutil + +import ( + "net" + "syscall" + + "golang.org/x/sys/unix" +) + +// bindToInterface sets the dialer's Control function to bind the socket +// to a specific network interface using IP_BOUND_IF on Darwin. +func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) { + addrs, _ := iface.Addrs() + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok { + if targetIP.To4() != nil && ipnet.IP.To4() != nil && !ipnet.IP.IsLoopback() { + ifaceIndex := iface.Index + dialer.Control = func(network, address string, c syscall.RawConn) error { + var setsockoptErr error + err := c.Control(func(fd uintptr) { + setsockoptErr = unix.SetsockoptInt( + int(fd), + unix.IPPROTO_IP, + unix.IP_BOUND_IF, + ifaceIndex, + ) + }) + if err != nil { + return err + } + + return setsockoptErr + } + break + } + } + } +} diff --git a/internal/netutil/dial_other.go b/internal/netutil/dial_other.go new file mode 100644 index 00000000..0d0d0fa9 --- /dev/null +++ b/internal/netutil/dial_other.go @@ -0,0 +1,11 @@ +//go:build !darwin + +package netutil + +import "net" + +// bindToInterface is a no-op on non-Darwin systems. +// Interface binding is handled differently or not supported on other platforms. +func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) { + // No-op: interface binding via IP_BOUND_IF is Darwin-specific +} diff --git a/internal/netutil/pac.go b/internal/netutil/pac.go new file mode 100644 index 00000000..67b27a11 --- /dev/null +++ b/internal/netutil/pac.go @@ -0,0 +1,34 @@ +package netutil + +import ( + "fmt" + "net" + "net/http" +) + +func RunPACServer(content string) (string, net.Listener, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return "", nil, err + } + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/x-ns-proxy-autoconfig") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(content)) + }) + + server := &http.Server{ + Handler: mux, + } + + go func() { + _ = server.Serve(listener) + }() + + addr := listener.Addr().(*net.TCPAddr) + url := fmt.Sprintf("http://127.0.0.1:%d/proxy.pac", addr.Port) + + return url, listener, nil +} diff --git a/internal/packet/LICENSE b/internal/packet/LICENSE deleted file mode 100644 index 8dada3ed..00000000 --- a/internal/packet/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/internal/packet/handle_linux.go b/internal/packet/handle_linux.go index 03247a5d..434a34b2 100644 --- a/internal/packet/handle_linux.go +++ b/internal/packet/handle_linux.go @@ -87,12 +87,29 @@ func (h *LinuxHandle) WritePacketData(data []byte) error { return unix.Sendto(h.fd, data, 0, addr) } -// LinkType logic for handling "any" interface (Linux SLL) vs Ethernet func (h *LinuxHandle) LinkType() layers.LinkType { + // 1. "Any" device (tcpdump -i any) uses Linux SLL (Cooked Mode) if h.ifIndex == 0 { - // "any" interface uses Linux SLL (Cooked Mode) return layers.LinkTypeLinuxSLL } + + // 2. Get interface information by Index + iface, err := net.InterfaceByIndex(h.ifIndex) + if err != nil { + // If fails to find interface, fallback to Ethernet + return layers.LinkTypeEthernet + } + + // 3. Check for VPN / TUN characteristics + // - Case A: No Hardware Address (MAC) -> Typical for TUN devices + // - Case B: Point-to-Point Flag -> Typical for VPN tunnels (WireGuard, OpenVPN) + if len(iface.HardwareAddr) == 0 || iface.Flags&net.FlagPointToPoint != 0 { + // Linux TUN devices provide Raw IP packets (No Link Header) + // This corresponds to DLT_RAW (12) or DLT_IPV4/IPV6 (101) + return layers.LinkTypeRaw + } + + // 4. Default to Ethernet return layers.LinkTypeEthernet } diff --git a/internal/packet/network_detector.go b/internal/packet/network_detector.go index 797dd6e8..c4fd4601 100644 --- a/internal/packet/network_detector.go +++ b/internal/packet/network_detector.go @@ -13,12 +13,12 @@ import ( "github.com/xvzc/SpoofDPI/internal/netutil" ) -var dnsServers = []net.IPAddr{ - {IP: net.ParseIP("8.8.8.8")}, - {IP: net.ParseIP("8.8.4.4")}, - {IP: net.ParseIP("1.1.1.1")}, - {IP: net.ParseIP("1.0.0.1")}, - {IP: net.ParseIP("9.9.9.9")}, +var dnsServers = []net.IP{ + net.ParseIP("8.8.8.8"), + net.ParseIP("8.8.4.4"), + net.ParseIP("1.1.1.1"), + net.ParseIP("1.0.0.1"), + net.ParseIP("9.9.9.9"), } type NetworkDetector struct { @@ -69,13 +69,46 @@ func (nd *NetworkDetector) Start(ctx context.Context) error { } }() - conn, err := netutil.DialFastest(ctx, "udp", dnsServers, 53, time.Duration(0)) + go func() { + // Wait for the packet capture to start + select { + case <-ctx.Done(): + return + case <-time.After(300 * time.Millisecond): + } + + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + nd.probe(ctx) + + for { + select { + case <-ctx.Done(): + return + case <-nd.found: + return + case <-ticker.C: + nd.probe(ctx) + } + } + }() + + return nil +} + +func (nd *NetworkDetector) probe(ctx context.Context) { + conn, err := netutil.DialFastest(ctx, "udp", &netutil.Destination{ + Addrs: dnsServers, + Port: 53, + Timeout: 2 * time.Second, + }) if err != nil { - return err + return } defer func() { _ = conn.Close() }() - return nil + _, _ = conn.Write([]byte(".")) } func (nd *NetworkDetector) processPacket(p gopacket.Packet) { @@ -202,6 +235,10 @@ func (nd *NetworkDetector) GetGatewayMAC() net.HardwareAddr { func (nd *NetworkDetector) WaitForGatewayMAC( ctx context.Context, ) (net.HardwareAddr, error) { + if nd.iface.HardwareAddr == nil { + return nil, nil + } + if nd.IsFound() { return nd.GetGatewayMAC(), nil } @@ -222,9 +259,11 @@ func findDefaultInterface(ctx context.Context) (*net.Interface, error) { conn, err := netutil.DialFastest( ctx, "udp", - dnsServers, - 53, - time.Duration(10)*time.Second, + &netutil.Destination{ + Addrs: dnsServers, + Port: 53, + Timeout: time.Duration(20) * time.Second, + }, ) if err != nil { return nil, fmt.Errorf( @@ -251,7 +290,10 @@ func findDefaultInterface(ctx context.Context) (*net.Interface, error) { ) } - for _, iface := range ifaces { + var defaultIface *net.Interface + + for i := range ifaces { + iface := ifaces[i] addrs, err := iface.Addrs() if err != nil { continue // Skip interfaces whose addresses cannot be retrieved @@ -261,11 +303,39 @@ func findDefaultInterface(ctx context.Context) (*net.Interface, error) { if ipnet, ok := addr.(*net.IPNet); ok { // Check if the IP used for Dial matches the interface's IP if ipnet.IP.Equal(localAddr.IP) { - // Return &iface (*net.Interface) instead of iface.Name (string) - return &iface, nil // Found the default interface + if len(iface.HardwareAddr) > 0 { + return &iface, nil // Found the default interface with MAC + } + defaultIface = &iface + } + } + } + } + + // fmt.Println("hello") + + // If we found a default interface but it has no MAC (e.g. VPN/TUN), + // try to find a physical interface as fallback. + if defaultIface != nil { + for i := range ifaces { + iface := ifaces[i] + if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 { + continue + } + if len(iface.HardwareAddr) == 0 { + continue + } + + addrs, _ := iface.Addrs() + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok { + if ipnet.IP.To4() != nil && !ipnet.IP.IsLoopback() { + return &iface, nil + } } } } + return defaultIface, nil } return nil, fmt.Errorf( diff --git a/internal/packet/sniffer.go b/internal/packet/sniffer.go index d69ed45a..5d36d236 100644 --- a/internal/packet/sniffer.go +++ b/internal/packet/sniffer.go @@ -8,7 +8,61 @@ import ( type Sniffer interface { StartCapturing() - RegisterUntracked(addrs []net.IPAddr, port int) + RegisterUntracked(addrs []net.IP) GetOptimalTTL(key string) uint8 Cache() cache.Cache } + +// estimateHops estimates the number of hops based on TTL. +// This logic is based on the hop counting mechanism from GoodbyeDPI. +// It returns 0 if the TTL is not recognizable. +func estimateHops(ttlLeft uint8) uint8 { + // Unrecognizable initial TTL + estimatedInitialHops := uint8(255) + switch { + case ttlLeft <= 64: + estimatedInitialHops = 64 + case ttlLeft <= 128: + estimatedInitialHops = 128 + default: + estimatedInitialHops = 255 + } + + return estimatedInitialHops - ttlLeft +} + +// isLocalIP checks if an IP address is in a local/private range. +// This is used to filter out outgoing packets from local machine. +// Private ranges: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, 127.0.0.0/8 +func isLocalIP(ip []byte) bool { + if len(ip) < 4 { + return false + } + + // 10.0.0.0/8 + if ip[0] == 10 { + return true + } + + // 127.0.0.0/8 (loopback) + if ip[0] == 127 { + return true + } + + // 172.16.0.0/12 (172.16.x.x - 172.31.x.x) + if ip[0] == 172 && ip[1] >= 16 && ip[1] <= 31 { + return true + } + + // 192.168.0.0/16 + if ip[0] == 192 && ip[1] == 168 { + return true + } + + // 169.254.0.0/16 (link-local) + if ip[0] == 169 && ip[1] == 254 { + return true + } + + return false +} diff --git a/internal/packet/tcp_sniffer.go b/internal/packet/tcp_sniffer.go index 9871afe2..3720002d 100644 --- a/internal/packet/tcp_sniffer.go +++ b/internal/packet/tcp_sniffer.go @@ -4,14 +4,12 @@ import ( "context" "fmt" "net" - "strconv" "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/rs/zerolog" "github.com/xvzc/SpoofDPI/internal/cache" "github.com/xvzc/SpoofDPI/internal/logging" - "github.com/xvzc/SpoofDPI/internal/session" ) var _ Sniffer = (*TCPSniffer)(nil) @@ -52,28 +50,22 @@ func (ts *TCPSniffer) StartCapturing() { packets := packetSource.Packets() // _ = ht.handle.SetBPFRawInstructionFilter(generateSynAckFilter()) _ = ts.handle.ClearBPF() - _ = ts.handle.SetBPFRawInstructionFilter(generateSynAckFilter()) + _ = ts.handle.SetBPFRawInstructionFilter(generateSynAckFilter(ts.handle.LinkType())) // Start a dedicated goroutine to process incoming packets. go func() { // Create a base context for this goroutine. - ctx := session.WithNewTraceID(context.Background()) for packet := range packets { - ts.processPacket(ctx, packet) + ts.processPacket(context.Background(), packet) } }() } // RegisterUntracked registers new IP addresses for tracking. // Addresses that are already being tracked are ignored. -func (ts *TCPSniffer) RegisterUntracked(addrs []net.IPAddr, port int) { - portStr := strconv.Itoa(port) +func (ts *TCPSniffer) RegisterUntracked(addrs []net.IP) { for _, v := range addrs { - ts.nhopCache.Set( - v.String()+":"+portStr, - ts.defaultTTL, - cache.Options().WithSkipExisting(true), - ) + ts.nhopCache.Set(v.String(), ts.defaultTTL, cache.Options().WithSkipExisting(true)) } } @@ -111,6 +103,11 @@ func (ts *TCPSniffer) processPacket(ctx context.Context, p gopacket.Packet) { // Handle IPv4 if ipLayer := p.Layer(layers.LayerTypeIPv4); ipLayer != nil { ip, _ := ipLayer.(*layers.IPv4) + // Skip packets from local/private IPs (outgoing packets) + if isLocalIP(ip.SrcIP) { + fmt.Println(ip.SrcIP) + return + } srcIP = ip.SrcIP.String() ttlLeft = ip.TTL } else if ipLayer := p.Layer(layers.LayerTypeIPv6); ipLayer != nil { @@ -124,75 +121,110 @@ func (ts *TCPSniffer) processPacket(ctx context.Context, p gopacket.Packet) { // Create the cache key: ServerIP:ServerPort // (The source of the SYN/ACK is the server) - key := fmt.Sprintf("%s:%d", srcIP, tcp.SrcPort) + key := srcIP // Calculate hop count from the TTL nhops := estimateHops(ttlLeft) - ok := ts.nhopCache.Set(key, nhops, nil) + ok := ts.nhopCache.Set(key, nhops, cache.Options().WithUpdateExistingOnly(true)) if ok { logger.Trace(). - Str("host_info", key). + Str("from", key). Uint8("nhops", nhops). Uint8("ttlLeft", ttlLeft). - Msgf("received syn+ack") - } -} - -// estimateHops estimates the number of hops based on TTL. -// This logic is based on the hop counting mechanism from GoodbyeDPI. -// It returns 0 if the TTL is not recognizable. -func estimateHops(ttlLeft uint8) uint8 { - // Unrecognizable initial TTL - estimatedInitialHops := uint8(255) - switch { - case ttlLeft <= 64: - estimatedInitialHops = 64 - case ttlLeft <= 128: - estimatedInitialHops = 128 - default: - estimatedInitialHops = 255 + Msgf("ttl(tcp) update") } - - return estimatedInitialHops - ttlLeft } // GenerateSynAckFilter creates a BPF program for "ip and tcp and (tcp[13] & 18 == 18)". // This captures only TCP SYN-ACK packets (IPv4). -func generateSynAckFilter() []BPFInstruction { - instructions := []BPFInstruction{ - // 1. Check EtherType == IPv4 (0x0800) - {Op: 0x28, Jt: 0, Jf: 0, K: 12}, - {Op: 0x15, Jt: 0, Jf: 8, K: 0x0800}, - - // 2. Check Protocol == TCP (6) - {Op: 0x30, Jt: 0, Jf: 0, K: 23}, - {Op: 0x15, Jt: 0, Jf: 6, K: 6}, - - // 3. Check Fragmentation - {Op: 0x28, Jt: 0, Jf: 0, K: 20}, - {Op: 0x45, Jt: 4, Jf: 0, K: 0x1fff}, - - // 4. Find TCP Header Start (IP Header Length to X) - // Loads byte at offset 14 (IP Header Start), gets IHL, multiplies by 4, stores in X. - {Op: 0xb1, Jt: 0, Jf: 0, K: 14}, - - // 5. Check TCP Flags (SYN+ACK) - // We want to load: Ethernet(14) + IP_Len(X) + TCP_Flags(13) - // Instruction is: Load [X + K] - // So K must be 14 + 13 = 27. - - // [FIX] K was 13, changed to 27 - {Op: 0x50, Jt: 0, Jf: 0, K: 27}, - - // Bitwise AND with 18 (SYN=2 | ACK=16) - {Op: 0x54, Jt: 0, Jf: 0, K: 18}, - - // Compare Result == 18 - {Op: 0x15, Jt: 0, Jf: 1, K: 18}, +// GenerateSynAckFilter creates a BPF program adapted to the LinkType. +// It supports Ethernet, Null (Loopback/VPN), and Raw IP link types. +func generateSynAckFilter(linkType layers.LinkType) []BPFInstruction { + var baseOffset uint32 + + // Determine the offset where the IP header begins + switch linkType { + case layers.LinkTypeEthernet: + baseOffset = 14 + case layers.LinkTypeNull, layers.LinkTypeLoop: // BSD Loopback / macOS utun + baseOffset = 4 + case layers.LinkTypeRaw: // Linux TUN + baseOffset = 0 + default: + // Fallback to Ethernet or handle error if necessary + baseOffset = 14 + } - // 6. Capture - {Op: 0x6, Jt: 0, Jf: 0, K: 0x00040000}, - {Op: 0x6, Jt: 0, Jf: 0, K: 0x00000000}, + instructions := []BPFInstruction{} + + // 1. Protocol Verification (IPv4) + if linkType == layers.LinkTypeEthernet { + // Check EtherType == IPv4 (0x0800) at offset 12 + instructions = append( + instructions, + BPFInstruction{Op: 0x28, Jt: 0, Jf: 0, K: 12}, // Ldh [12] + BPFInstruction{ + Op: 0x15, + Jt: 0, + Jf: 8, + K: 0x0800, + }, // Jeq 0x800, True, False(Skip to End) + ) + } else { + // Check IP Version == 4 at the base offset + // Load byte at baseOffset, mask 0xF0, check if 0x40 + instructions = append(instructions, + // BPFInstruction{Op: 0x30, Jt: 0, Jf: 0, K: baseOffset}, // Ldb [baseOffset] + BPFInstruction{Op: 0x54, Jt: 0, Jf: 0, K: 0xf0}, // And 0xf0 + BPFInstruction{Op: 0x15, Jt: 0, Jf: 8, K: 0x40}, // Jeq 0x40, True, False(Skip to End) + ) } + // 2. Check Protocol == TCP (6) + // Protocol field is at IP header + 9 bytes + instructions = append(instructions, + BPFInstruction{Op: 0x30, Jt: 0, Jf: 0, K: baseOffset + 9}, // Ldb [baseOffset + 9] + BPFInstruction{Op: 0x15, Jt: 0, Jf: 6, K: 6}, // Jeq 6, True, False + ) + + // 3. Check Fragmentation (Flags & Fragment Offset) + // At IP header + 6 bytes + instructions = append( + instructions, + BPFInstruction{Op: 0x28, Jt: 0, Jf: 0, K: baseOffset + 6}, // Ldh [baseOffset + 6] + BPFInstruction{ + Op: 0x45, + Jt: 4, + Jf: 0, + K: 0x1fff, + }, // Jset 0x1fff, True(Skip), False + ) + + // 4. Find TCP Header Start + // Load IP IHL from (baseOffset), multiply by 4 to get length, store in X + instructions = append(instructions, + BPFInstruction{Op: 0xb1, Jt: 0, Jf: 0, K: baseOffset}, // Ldxb 4*([baseOffset]&0xf) + ) + + // 5. Check TCP Flags (SYN+ACK) + // We need to load: baseOffset + IP_Len(X) + TCP_Flags(13) + // Instruction: Load [X + K] -> K = baseOffset + 13 + instructions = append( + instructions, + BPFInstruction{ + Op: 0x50, + Jt: 0, + Jf: 0, + K: baseOffset + 13, + }, // Ldb [X + baseOffset + 13] + BPFInstruction{Op: 0x54, Jt: 0, Jf: 0, K: 18}, // And 18 (SYN|ACK) + BPFInstruction{Op: 0x15, Jt: 0, Jf: 1, K: 18}, // Jeq 18, True, False + ) + + // 6. Capture + instructions = append(instructions, + BPFInstruction{Op: 0x6, Jt: 0, Jf: 0, K: 0x00040000}, // Ret capture_len + BPFInstruction{Op: 0x6, Jt: 0, Jf: 0, K: 0x00000000}, // Ret 0 + ) + return instructions } diff --git a/internal/packet/tcp_writer.go b/internal/packet/tcp_writer.go index 5aa4bc9e..b94fa025 100644 --- a/internal/packet/tcp_writer.go +++ b/internal/packet/tcp_writer.go @@ -124,12 +124,14 @@ func (tw *TCPWriter) createIPv4Layers( ) ([]gopacket.SerializableLayer, error) { var packetLayers []gopacket.SerializableLayer - eth := &layers.Ethernet{ - SrcMAC: srcMAC, - DstMAC: dstMAC, - EthernetType: layers.EthernetTypeIPv4, + if srcMAC != nil { + eth := &layers.Ethernet{ + SrcMAC: srcMAC, + DstMAC: dstMAC, + EthernetType: layers.EthernetTypeIPv4, + } + packetLayers = append(packetLayers, eth) } - packetLayers = append(packetLayers, eth) // define ip layer ipLayer := &layers.IPv4{ diff --git a/internal/packet/udp_sniffer.go b/internal/packet/udp_sniffer.go new file mode 100644 index 00000000..8f754d16 --- /dev/null +++ b/internal/packet/udp_sniffer.go @@ -0,0 +1,191 @@ +package packet + +import ( + "context" + "net" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/cache" + "github.com/xvzc/SpoofDPI/internal/logging" +) + +var _ Sniffer = (*UDPSniffer)(nil) + +type UDPSniffer struct { + logger zerolog.Logger + + nhopCache cache.Cache + defaultTTL uint8 + + handle Handle +} + +func NewUDPSniffer( + logger zerolog.Logger, + cache cache.Cache, + handle Handle, + defaultTTL uint8, +) *UDPSniffer { + return &UDPSniffer{ + logger: logger, + nhopCache: cache, + handle: handle, + defaultTTL: defaultTTL, + } +} + +// --- HopTracker Methods --- + +func (us *UDPSniffer) Cache() cache.Cache { + return us.nhopCache +} + +// StartCapturing begins monitoring for UDP packets in a background goroutine. +func (us *UDPSniffer) StartCapturing() { + // Create a new packet source from the handle. + packetSource := gopacket.NewPacketSource(us.handle, us.handle.LinkType()) + packets := packetSource.Packets() + + _ = us.handle.ClearBPF() + _ = us.handle.SetBPFRawInstructionFilter(generateUdpFilter(us.handle.LinkType())) + + // Start a dedicated goroutine to process incoming packets. + go func() { + for packet := range packets { + us.processPacket(context.Background(), packet) + } + }() +} + +// RegisterUntracked registers new IP addresses for tracking. +// Addresses that are already being tracked are ignored. +func (us *UDPSniffer) RegisterUntracked(addrs []net.IP) { + for _, v := range addrs { + us.nhopCache.Set(v.String(), us.defaultTTL, cache.Options().WithSkipExisting(true)) + } +} + +// GetOptimalTTL retrieves the estimated hop count for a given key from the cache. +// It returns the hop count and true if found, or 0 and false if not found. +func (us *UDPSniffer) GetOptimalTTL(key string) uint8 { + hopCount := uint8(255) + if oTTL, ok := us.nhopCache.Get(key); ok { + hopCount = oTTL.(uint8) + } + + return max(hopCount, 2) - 1 +} + +// processPacket analyzes a single packet to store hop counts. +func (us *UDPSniffer) processPacket(ctx context.Context, p gopacket.Packet) { + logger := logging.WithLocalScope(ctx, us.logger, "sniff") + + udpLayer := p.Layer(layers.LayerTypeUDP) + if udpLayer == nil { + return + } + + var srcIP string + var ttlLeft uint8 + + // Handle IPv4 + if ipLayer := p.Layer(layers.LayerTypeIPv4); ipLayer != nil { + ip, _ := ipLayer.(*layers.IPv4) + + // Skip packets from local/private IPs (outgoing packets) + if isLocalIP(ip.SrcIP) { + return + } + // Skip packets where dst is not local (outgoing packets including our fake packets) + if !isLocalIP(ip.DstIP) { + return + } + + srcIP = ip.SrcIP.String() + ttlLeft = ip.TTL + } else if ipLayer := p.Layer(layers.LayerTypeIPv6); ipLayer != nil { + // Handle IPv6 + ip, _ := ipLayer.(*layers.IPv6) + srcIP = ip.SrcIP.String() + ttlLeft = ip.HopLimit + } else { + return // No IP layer found + } + + key := srcIP + // Calculate hop count from the TTL + nhops := estimateHops(ttlLeft) + + stored, exists := us.nhopCache.Get(key) + + if us.nhopCache.Set(key, nhops, nil) { + if !exists || stored != nhops { + logger.Trace(). + Str("from", key). + Uint8("nhops", nhops). + Uint8("ttlLeft", ttlLeft). + Msgf("ttl(udp) update") + } + } +} + +// GenerateUdpFilter creates a BPF program for "ip and udp". +// It supports Ethernet, Null (Loopback/VPN), and Raw IP link types. +func generateUdpFilter(linkType layers.LinkType) []BPFInstruction { + var baseOffset uint32 + + // Determine the offset where the IP header begins + switch linkType { + case layers.LinkTypeEthernet: + baseOffset = 14 + case layers.LinkTypeNull, layers.LinkTypeLoop: // BSD Loopback / macOS utun + baseOffset = 4 + case layers.LinkTypeRaw: // Linux TUN + baseOffset = 0 + default: + // Fallback to Ethernet or handle error if necessary + baseOffset = 14 + } + + instructions := []BPFInstruction{} + + // 1. Protocol Verification (IPv4) + if linkType == layers.LinkTypeEthernet { + // Check EtherType == IPv4 (0x0800) at offset 12 + instructions = append( + instructions, + BPFInstruction{Op: 0x28, Jt: 0, Jf: 0, K: 12}, // Ldh [12] + BPFInstruction{ + Op: 0x15, + Jt: 0, + Jf: 3, + K: 0x0800, + }, // Jeq 0x800, True, False(Skip to End) + ) + } else { + // Check IP Version == 4 at the base offset + // Load byte at baseOffset, mask 0xF0, check if 0x40 + instructions = append(instructions, + BPFInstruction{Op: 0x30, Jt: 0, Jf: 0, K: baseOffset}, // Ldb [baseOffset] + BPFInstruction{Op: 0x54, Jt: 0, Jf: 0, K: 0xf0}, // And 0xf0 + BPFInstruction{Op: 0x15, Jt: 0, Jf: 3, K: 0x40}, // Jeq 0x40, True, False(Skip to End) + ) + } + + // 2. Check Protocol == UDP (17) + // Protocol field is at IP header + 9 bytes + instructions = append(instructions, + BPFInstruction{Op: 0x30, Jt: 0, Jf: 0, K: baseOffset + 9}, // Ldb [baseOffset + 9] + BPFInstruction{Op: 0x15, Jt: 0, Jf: 1, K: 17}, // Jeq 17, True, False + ) + + // 3. Capture + instructions = append(instructions, + BPFInstruction{Op: 0x6, Jt: 0, Jf: 0, K: 0x00040000}, // Ret capture_len + BPFInstruction{Op: 0x6, Jt: 0, Jf: 0, K: 0x00000000}, // Ret 0 + ) + + return instructions +} diff --git a/internal/packet/udp_writer.go b/internal/packet/udp_writer.go new file mode 100644 index 00000000..ff92bf97 --- /dev/null +++ b/internal/packet/udp_writer.go @@ -0,0 +1,197 @@ +package packet + +import ( + "context" + "errors" + "net" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/rs/zerolog" +) + +var _ Writer = (*UDPWriter)(nil) + +type UDPWriter struct { + logger zerolog.Logger + + handle Handle + iface *net.Interface + gatewayMAC net.HardwareAddr +} + +func NewUDPWriter( + logger zerolog.Logger, + handle Handle, + iface *net.Interface, + gatewayMAC net.HardwareAddr, +) *UDPWriter { + return &UDPWriter{ + logger: logger, + handle: handle, + iface: iface, + gatewayMAC: gatewayMAC, + } +} + +// --- Injector Methods --- + +// WriteCraftedPacket crafts and injects a full UDP packet from a payload. +// It uses the pre-configured gateway MAC address. +func (uw *UDPWriter) WriteCraftedPacket( + ctx context.Context, + src net.Addr, + dst net.Addr, + ttl uint8, + payload []byte, +) (int, error) { + // set variables for src/dst + srcMAC := uw.iface.HardwareAddr + dstMAC := uw.gatewayMAC + + srcUDP, ok := src.(*net.UDPAddr) + if !ok { + return 0, errors.New("src is not *net.UDPAddr") + } + + dstUDP, ok := dst.(*net.UDPAddr) + if !ok { + return 0, errors.New("dst is not *net.UDPAddr") + } + + srcPort := srcUDP.Port + dstPort := dstUDP.Port + + var err error + var packetLayers []gopacket.SerializableLayer + if dstUDP.IP.To4() != nil { + packetLayers, err = uw.createIPv4Layers( + srcMAC, + dstMAC, + srcUDP.IP, + dstUDP.IP, + srcPort, + dstPort, + ttl, + ) + } else { + packetLayers, err = uw.createIPv6Layers( + srcMAC, + dstMAC, + srcUDP.IP, + srcUDP.IP, + srcPort, + dstPort, + ttl, + ) + } + + if err != nil { + return 0, err + } + + // serialize the packet L2(optional) + L3 + L4 + payload + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, // Recalculate checksums + FixLengths: true, // Fix lengths + } + + packetLayers = append(packetLayers, gopacket.Payload(payload)) + + err = gopacket.SerializeLayers(buf, opts, packetLayers...) + if err != nil { + return 0, err + } + + // inject the raw L2 packet + if err := uw.handle.WritePacketData(buf.Bytes()); err != nil { + return 0, err + } + + return len(payload), nil +} + +func (uw *UDPWriter) createIPv4Layers( + srcMAC net.HardwareAddr, + dstMAC net.HardwareAddr, + srcIP net.IP, + dstIP net.IP, + srcPort int, + dstPort int, + ttl uint8, +) ([]gopacket.SerializableLayer, error) { + var packetLayers []gopacket.SerializableLayer + + if srcMAC != nil { + eth := &layers.Ethernet{ + SrcMAC: srcMAC, + DstMAC: dstMAC, + EthernetType: layers.EthernetTypeIPv4, + } + packetLayers = append(packetLayers, eth) + } + + // define ip layer + ipLayer := &layers.IPv4{ + Version: 4, + TTL: ttl, + Protocol: layers.IPProtocolUDP, + SrcIP: srcIP, + DstIP: dstIP, + } + packetLayers = append(packetLayers, ipLayer) + + // define udp layer + udpLayer := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + packetLayers = append(packetLayers, udpLayer) + + if err := udpLayer.SetNetworkLayerForChecksum(ipLayer); err != nil { + return nil, err + } + + return packetLayers, nil +} + +func (uw *UDPWriter) createIPv6Layers( + srcMAC net.HardwareAddr, + dstMAC net.HardwareAddr, + srcIP net.IP, + dstIP net.IP, + srcPort int, + dstPort int, + ttl uint8, +) ([]gopacket.SerializableLayer, error) { + var packetLayers []gopacket.SerializableLayer + + eth := &layers.Ethernet{ + SrcMAC: srcMAC, + DstMAC: dstMAC, + EthernetType: layers.EthernetTypeIPv6, + } + packetLayers = append(packetLayers, eth) + + ipLayer := &layers.IPv6{ + Version: 6, + HopLimit: ttl, + NextHeader: layers.IPProtocolUDP, + SrcIP: srcIP, + DstIP: dstIP, + } + packetLayers = append(packetLayers, ipLayer) + + udpLayer := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + packetLayers = append(packetLayers, udpLayer) + + if err := udpLayer.SetNetworkLayerForChecksum(ipLayer); err != nil { + return nil, err + } + + return packetLayers, nil +} diff --git a/internal/proto/http.go b/internal/proto/http.go index 7f718a8b..1ca656f0 100644 --- a/internal/proto/http.go +++ b/internal/proto/http.go @@ -68,8 +68,8 @@ func ReadHttpRequest(rdr io.Reader) (*HTTPRequest, error) { return NewHttpRequest(req), nil } -// ExtractDomain returns the host without port information -func (r *HTTPRequest) ExtractDomain() string { +// ExtractHost returns the host without port information +func (r *HTTPRequest) ExtractHost() string { host, _, err := net.SplitHostPort(r.Host) if err != nil { return r.Host diff --git a/internal/proto/socks5.go b/internal/proto/socks5.go index 92054979..f99a46e5 100644 --- a/internal/proto/socks5.go +++ b/internal/proto/socks5.go @@ -13,42 +13,34 @@ const ( SOCKSVersion = 0x05 // Auth - AuthNone = 0x00 - AuthGSSAPI = 0x01 - AuthUserPass = 0x02 - AuthNoAccept = 0xFF + SOCKS5AuthNone = 0x00 + SOCKS5AuthGSSAPI = 0x01 + SOCKS5AuthUserPass = 0x02 + SOCKS5AuthNoAccept = 0xFF // Command - CmdConnect = 0x01 - CmdBind = 0x02 - CmdUDPAssociate = 0x03 + SOCKS5CmdConnect = 0x01 + SOCKS5CmdBind = 0x02 + SOCKS5CmdUDPAssociate = 0x03 // ATYP - ATYPIPv4 = 0x01 - ATYPFQDN = 0x03 - ATYPIPv6 = 0x04 + SOCKS5AddrTypeIPv4 = 0x01 + SOCKS5AddrTypeFQDN = 0x03 + SOCKS5AddrTypeIPv6 = 0x04 // Reply codes - ReplyCodeSuccess = 0x00 - ReplyCodeGenFailure = 0x01 - ReplyCodeCmdNotSupport = 0x07 - ReplyCodeAddrNotSupport = 0x08 + SOCKS5RCodeSuccess = 0x00 + SOCKS5RCodeGenFailure = 0x01 + SOCKS5RCodeCmdNotSupported = 0x07 + SOCKS5RCodeAddrNotSupported = 0x08 ) type SOCKS5Request struct { - Cmd byte - Domain string - IP net.IP - Port int -} - -func (r *SOCKS5Request) Host() string { - ret := r.Domain - if ret == "" { - ret = r.IP.String() - } - - return ret + Cmd byte + ATYP byte + FQDN string + IP net.IP + Port int } // ReadSocks5Request parses the SOCKS5 request details. @@ -73,18 +65,20 @@ func ReadSocks5Request(conn net.Conn) (*SOCKS5Request, error) { var ip net.IP switch atyp { - case ATYPIPv4: + case SOCKS5AddrTypeIPv4: buf := make([]byte, 4) if _, err := io.ReadFull(conn, buf); err != nil { return nil, err } + ip = net.IP(buf) - case ATYPFQDN: + case SOCKS5AddrTypeFQDN: lenBuf := make([]byte, 1) if _, err := io.ReadFull(conn, lenBuf); err != nil { return nil, err } + domainLen := int(lenBuf[0]) domainBuf := make([]byte, domainLen) if _, err := io.ReadFull(conn, domainBuf); err != nil { @@ -92,7 +86,7 @@ func ReadSocks5Request(conn net.Conn) (*SOCKS5Request, error) { } domain = string(domainBuf) - case ATYPIPv6: + case SOCKS5AddrTypeIPv6: buf := make([]byte, 16) if _, err := io.ReadFull(conn, buf); err != nil { return nil, err @@ -110,10 +104,11 @@ func ReadSocks5Request(conn net.Conn) (*SOCKS5Request, error) { port := int(binary.BigEndian.Uint16(portBuf)) return &SOCKS5Request{ - Cmd: cmd, - Domain: domain, - IP: ip, - Port: port, + Cmd: cmd, + ATYP: atyp, + FQDN: domain, + IP: ip, + Port: port, }, nil } @@ -132,15 +127,15 @@ func NewSOCKS5Reply(rep byte) *SOCKS5Reply { } func SOCKS5SuccessResponse() *SOCKS5Reply { - return NewSOCKS5Reply(ReplyCodeSuccess) + return NewSOCKS5Reply(SOCKS5RCodeSuccess) } func SOCKS5FailureResponse() *SOCKS5Reply { - return NewSOCKS5Reply(ReplyCodeGenFailure) + return NewSOCKS5Reply(SOCKS5RCodeGenFailure) } func SOCKS5CommandNotSupportedResponse() *SOCKS5Reply { - return NewSOCKS5Reply(ReplyCodeCmdNotSupport) + return NewSOCKS5Reply(SOCKS5RCodeCmdNotSupported) } func (r *SOCKS5Reply) Bind(ip net.IP) *SOCKS5Reply { @@ -157,7 +152,7 @@ func (r *SOCKS5Reply) Port(port int) *SOCKS5Reply { func (r *SOCKS5Reply) Write(w io.Writer) error { buf := make([]byte, 0, 10) - buf = append(buf, SOCKSVersion, r.Rep, 0x00, ATYPIPv4) + buf = append(buf, SOCKSVersion, r.Rep, 0x00, SOCKS5AddrTypeIPv4) // Use To4() to ensure 4 bytes if it's an IPv4 address stored in IPv6 format if ip4 := r.BindIP.To4(); ip4 != nil { diff --git a/internal/proxy/http/connect.go b/internal/proxy/http/connect.go deleted file mode 100644 index be7fb7f7..00000000 --- a/internal/proxy/http/connect.go +++ /dev/null @@ -1,54 +0,0 @@ -package http - -import ( - "context" - "errors" - "fmt" - "io" - "net" - - "github.com/rs/zerolog" - "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/logging" - "github.com/xvzc/SpoofDPI/internal/netutil" - "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/proxy/tlsutil" -) - -type HTTPSHandler struct { - logger zerolog.Logger - bridge *tlsutil.TLSBridge -} - -func NewHTTPSHandler( - logger zerolog.Logger, - bridge *tlsutil.TLSBridge, -) *HTTPSHandler { - return &HTTPSHandler{ - logger: logger, - bridge: bridge, - } -} - -func (h *HTTPSHandler) HandleRequest( - ctx context.Context, - lConn net.Conn, - dst *netutil.Destination, - rule *config.Rule, -) error { - logger := logging.WithLocalScope(ctx, h.logger, "handshake") - - // 1. Send 200 Connection Established - if err := proto.HTTPConnectionEstablishedResponse().Write(lConn); err != nil { - if !netutil.IsConnectionResetByPeer(err) && !errors.Is(err, io.EOF) { - logger.Trace().Err(err).Msgf("proxy handshake error") - return fmt.Errorf("failed to handle proxy handshake: %w", err) - } - return nil - } - logger.Trace().Msgf("sent 200 connection established -> %s", lConn.RemoteAddr()) - - // 2. Delegate to Bridge - return h.bridge.Tunnel(ctx, lConn, dst, rule) -} - diff --git a/internal/proxy/http_proxy.go b/internal/proxy/http_proxy.go deleted file mode 100644 index 9e714d93..00000000 --- a/internal/proxy/http_proxy.go +++ /dev/null @@ -1,224 +0,0 @@ -package proxy - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net" - "time" - - "github.com/rs/zerolog" - "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/dns" - "github.com/xvzc/SpoofDPI/internal/logging" - "github.com/xvzc/SpoofDPI/internal/matcher" - "github.com/xvzc/SpoofDPI/internal/netutil" - "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/proxy/handler" - "github.com/xvzc/SpoofDPI/internal/ptr" - "github.com/xvzc/SpoofDPI/internal/session" -) - -type HTTPProxy struct { - logger zerolog.Logger - - resolver dns.Resolver - httpHandler *handler.HTTPHandler - httpsHandler *handler.HTTPSHandler - ruleMatcher matcher.RuleMatcher - serverOpts *config.ServerOptions - policyOpts *config.PolicyOptions -} - -func NewHTTPProxy( - logger zerolog.Logger, - resolver dns.Resolver, - httpHandler *handler.HTTPHandler, - httpsHandler *handler.HTTPSHandler, - ruleMatcher matcher.RuleMatcher, - serverOpts *config.ServerOptions, - policyOpts *config.PolicyOptions, -) ProxyServer { - return &HTTPProxy{ - logger: logger, - resolver: resolver, - httpHandler: httpHandler, - httpsHandler: httpsHandler, - ruleMatcher: ruleMatcher, - serverOpts: serverOpts, - policyOpts: policyOpts, - } -} - -func (p *HTTPProxy) ListenAndServe(ctx context.Context, wait chan struct{}) { - <-wait - - logger := p.logger.With().Ctx(ctx).Logger() - - listener, err := net.ListenTCP("tcp", p.serverOpts.ListenAddr) - if err != nil { - p.logger.Fatal(). - Err(err). - Msgf("error creating listener on %s", p.serverOpts.ListenAddr.String()) - } - - logger.Info(). - Msgf("created a listener on %s", p.serverOpts.ListenAddr) - - for { - conn, err := listener.Accept() - if err != nil { - p.logger.Error(). - Err(err). - Msgf("failed to accept new connection") - - continue - } - - go p.handleNewConnection(session.WithNewTraceID(context.Background()), conn) - } -} - -func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { - logger := logging.WithLocalScope(ctx, p.logger, "conn") - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - defer netutil.CloseConns(conn) - - req, err := proto.ReadHttpRequest(conn) - if err != nil { - if err != io.EOF { - logger.Warn().Err(err).Msg("failed to read http request") - } - - return - } - - if !req.IsValidMethod() { - logger.Warn().Str("method", req.Method).Msg("unsupported method. abort") - _ = proto.HTTPNotImplementedResponse().Write(conn) - - return - } - - domain := req.ExtractDomain() - dstPort, err := req.ExtractPort() - if err != nil { - logger.Warn().Str("host", req.Host).Msg("failed to extract port") - _ = proto.HTTPBadRequestResponse().Write(conn) - - return - } - - ctx = session.WithHostInfo(ctx, domain) - logger = logger.With().Ctx(ctx).Logger() - - logger.Debug(). - Str("method", req.Method). - Str("from", conn.RemoteAddr().String()). - Msg("new request") - - nameMatch := p.ruleMatcher.Search( - &matcher.Selector{Kind: matcher.MatchKindDomain, Domain: ptr.FromValue(domain)}, - ) - if nameMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - jsonAttrs, _ := json.Marshal(nameMatch) - logger.Trace().RawJSON("values", jsonAttrs).Msg("name match") - } - - t1 := time.Now() - rSet, err := p.resolver.Resolve(ctx, domain, nil, nameMatch) - dt := time.Since(t1).Milliseconds() - if err != nil { - _ = proto.HTTPBadGatewayResponse().Write(conn) - logging.ErrorUnwrapped(&logger, "dns lookup failed", err) - - return - } - - logger.Debug(). - Int("cnt", len(rSet.Addrs)). - Str("took", fmt.Sprintf("%dms", dt)). - Msgf("dns lookup ok") - - // Avoid recursively querying self. - ok, err := netutil.ValidateDestination(rSet.Addrs, dstPort, p.serverOpts.ListenAddr) - if err != nil { - logger.Debug().Err(err).Msg("error validating dst addrs") - if !ok { - _ = proto.HTTPForbiddenResponse().Write(conn) - } - } - - var selectors []*matcher.Selector - for _, v := range rSet.Addrs { - selectors = append(selectors, &matcher.Selector{ - Kind: matcher.MatchKindAddr, - IP: ptr.FromValue(v.IP), - Port: ptr.FromValue(uint16(dstPort)), - }) - } - - addrMatch := p.ruleMatcher.SearchAll(selectors) - if addrMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - jsonAttrs, _ := json.Marshal(addrMatch) - logger.Trace().RawJSON("values", jsonAttrs).Msg("addr match") - } - - bestMatch := matcher.GetHigherPriorityRule(addrMatch, nameMatch) - if bestMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - jsonAttrs, _ := json.Marshal(bestMatch) - logger.Trace().RawJSON("values", jsonAttrs).Msg("best match") - } - - if bestMatch != nil && *bestMatch.Block { - logger.Debug().Msg("request is blocked by policy") - return - } - - dst := &netutil.Destination{ - Domain: domain, - Addrs: rSet.Addrs, - Port: dstPort, - Timeout: *p.serverOpts.Timeout, - } - - var handleErr error - if req.IsConnectMethod() { - handleErr = p.httpsHandler.HandleRequest(ctx, conn, dst, bestMatch) - } else { - handleErr = p.httpHandler.HandleRequest(ctx, conn, req, dst, bestMatch) - } - - if handleErr == nil { // Early exit if no error found - return - } - - logger.Warn().Err(handleErr).Msg("error handling request") - if !errors.Is(handleErr, netutil.ErrBlocked) { // Early exit if not blocked - return - } - - // ┌─────────────┐ - // │ AUTO config │ - // └─────────────┘ - if bestMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - logger.Info().Msg("skipping auto-config (duplicate policy)") - return - } - - // Perform auto config if enabled and RuleTemplate is not nil - if *p.policyOpts.Auto && p.policyOpts.Template != nil { - newRule := p.policyOpts.Template.Clone() - newRule.Match = &config.MatchAttrs{Domains: []string{domain}} - - if err := p.ruleMatcher.Add(newRule); err != nil { - logger.Info().Err(err).Msg("failed to add config automatically") - } else { - logger.Info().Msg("automatically added to config") - } - } -} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go deleted file mode 100644 index b7a47163..00000000 --- a/internal/proxy/proxy.go +++ /dev/null @@ -1,9 +0,0 @@ -package proxy - -import ( - "context" -) - -type ProxyServer interface { - ListenAndServe(ctx context.Context, wait chan struct{}) -} diff --git a/internal/proxy/socks5/proxy.go b/internal/proxy/socks5/proxy.go deleted file mode 100644 index e0edb86e..00000000 --- a/internal/proxy/socks5/proxy.go +++ /dev/null @@ -1,259 +0,0 @@ -package socks5 - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net" - "time" - - "github.com/rs/zerolog" - "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/dns" - "github.com/xvzc/SpoofDPI/internal/logging" - "github.com/xvzc/SpoofDPI/internal/matcher" - "github.com/xvzc/SpoofDPI/internal/netutil" - "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/proxy" - "github.com/xvzc/SpoofDPI/internal/proxy/tlsutil" - "github.com/xvzc/SpoofDPI/internal/ptr" - "github.com/xvzc/SpoofDPI/internal/session" -) - -type SOCKS5Proxy struct { - logger zerolog.Logger - - resolver dns.Resolver - ruleMatcher matcher.RuleMatcher - serverOpts *config.ServerOptions - policyOpts *config.PolicyOptions - - tcpHandler *TCPHandler - udpHandler *UDPHandler -} - -func NewSOCKS5Proxy( - logger zerolog.Logger, - resolver dns.Resolver, - bridge *tlsutil.TLSBridge, - ruleMatcher matcher.RuleMatcher, - serverOpts *config.ServerOptions, - policyOpts *config.PolicyOptions, -) proxy.ProxyServer { - return &SOCKS5Proxy{ - logger: logger, - resolver: resolver, - ruleMatcher: ruleMatcher, - serverOpts: serverOpts, - policyOpts: policyOpts, - tcpHandler: NewTCPHandler( - logger, - bridge, - serverOpts, - ), - udpHandler: NewUDPHandler(logger), - } -} - -func (p *SOCKS5Proxy) ListenAndServe(ctx context.Context, wait chan struct{}) { - <-wait - - logger := p.logger.With().Ctx(ctx).Logger() - - // Using ListenTCP to match HTTPProxy style, though net.Listen is also fine - listener, err := net.ListenTCP("tcp", p.serverOpts.ListenAddr) - if err != nil { - p.logger.Fatal(). - Err(err). - Msgf("error creating socks5 listener on %s", p.serverOpts.ListenAddr.String()) - } - - logger.Info(). - Msgf("created a socks5 listener on %s", p.serverOpts.ListenAddr) - - for { - conn, err := listener.Accept() - if err != nil { - p.logger.Error(). - Err(err). - Msg("failed to accept new connection") - continue - } - - go p.handleConnection(session.WithNewTraceID(context.Background()), conn) - } -} - -func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { - logger := logging.WithLocalScope(ctx, p.logger, "socks5") - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - defer netutil.CloseConns(conn) - - // 1. Negotiation Phase - if err := p.negotiate(conn); err != nil { - logger.Debug().Err(err).Msg("socks5 negotiation failed") - return - } - - // 2. Request Phase - req, err := proto.ReadSocks5Request(conn) - if err != nil { - if err != io.EOF { - logger.Warn().Err(err).Msg("failed to read socks5 request") - } - return - } - - // Setup Logging Context - remoteInfo := req.Domain - if remoteInfo == "" { - remoteInfo = req.IP.String() - } - ctx = session.WithHostInfo(ctx, remoteInfo) - - switch req.Cmd { - case proto.CmdConnect: - rule, addrs, err := p.resolveAndMatch(ctx, req) - if err != nil { - return // resolveAndMatch logs error and writes failure response if needed - } - dst := &netutil.Destination{ - Domain: req.Domain, - Addrs: addrs, - Port: req.Port, - Timeout: *p.serverOpts.Timeout, - } - if err := p.tcpHandler.Handle(ctx, conn, req, dst, rule); err != nil { - return // Handler logs error - } - - p.handleAutoConfig(ctx, req, addrs, rule) - - case proto.CmdUDPAssociate: - // UDP Associate usually doesn't have destination info in the request - _ = p.udpHandler.Handle(ctx, conn, req, nil, nil) - default: - _ = proto.SOCKS5CommandNotSupportedResponse().Write(conn) - logger.Warn().Uint8("cmd", req.Cmd).Msg("unsupported socks5 command") - } -} - -func (p *SOCKS5Proxy) resolveAndMatch( - ctx context.Context, - req *proto.SOCKS5Request, -) (*config.Rule, []net.IPAddr, error) { - logger := zerolog.Ctx(ctx) - - // 1. Match Domain Rules (if domain provided) - var nameMatch *config.Rule - if req.Domain != "" { - nameMatch = p.ruleMatcher.Search( - &matcher.Selector{ - Kind: matcher.MatchKindDomain, - Domain: ptr.FromValue(req.Domain), - }, - ) - if nameMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - jsonAttrs, _ := json.Marshal(nameMatch) - logger.Trace().RawJSON("values", jsonAttrs).Msg("name match") - } - } - - // 2. DNS Resolution - t1 := time.Now() - var addrs []net.IPAddr - - if req.Domain != "" { - // Resolve Domain - rSet, err := p.resolver.Resolve(ctx, req.Domain, nil, nameMatch) - if err != nil { - logging.ErrorUnwrapped(logger, "dns lookup failed", err) - return nil, nil, err - } - addrs = rSet.Addrs - } else { - addrs = []net.IPAddr{{IP: req.IP}} - } - - dt := time.Since(t1).Milliseconds() - logger.Debug(). - Int("cnt", len(addrs)). - Str("took", fmt.Sprintf("%dms", dt)). - Msg("dns lookup ok") - - // 3. Match IP Rules - var selectors []*matcher.Selector - for _, v := range addrs { - selectors = append(selectors, &matcher.Selector{ - Kind: matcher.MatchKindAddr, - IP: ptr.FromValue(v.IP), - Port: ptr.FromValue(uint16(req.Port)), - }) - } - - addrMatch := p.ruleMatcher.SearchAll(selectors) - if addrMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - jsonAttrs, _ := json.Marshal(addrMatch) - logger.Trace().RawJSON("values", jsonAttrs).Msg("addr match") - } - - bestMatch := matcher.GetHigherPriorityRule(addrMatch, nameMatch) - return bestMatch, addrs, nil -} - -func (p *SOCKS5Proxy) handleAutoConfig( - ctx context.Context, - req *proto.SOCKS5Request, - addrs []net.IPAddr, - matchedRule *config.Rule, -) { - logger := zerolog.Ctx(ctx) - - if matchedRule != nil { - logger.Info(). - Interface("match", matchedRule.Match). - Str("name", *matchedRule.Name). - Msg("skipping auto-config (duplicate policy)") - return - } - - if *p.policyOpts.Auto && p.policyOpts.Template != nil { - newRule := p.policyOpts.Template.Clone() - targetDomain := req.Domain - if targetDomain == "" && len(addrs) > 0 { - targetDomain = addrs[0].IP.String() - } - - newRule.Match = &config.MatchAttrs{Domains: []string{targetDomain}} - - if err := p.ruleMatcher.Add(newRule); err != nil { - logger.Info().Err(err).Msg("failed to add config automatically") - } else { - logger.Info().Msg("automatically added to config") - } - } -} - -func (p *SOCKS5Proxy) negotiate(conn net.Conn) error { - header := make([]byte, 2) - if _, err := io.ReadFull(conn, header); err != nil { - return err - } - - if header[0] != proto.SOCKSVersion { - return fmt.Errorf("unsupported version: %d", header[0]) - } - - nMethods := int(header[1]) - methods := make([]byte, nMethods) - if _, err := io.ReadFull(conn, methods); err != nil { - return err - } - - // Respond: Version 5, Method NoAuth(0) - _, err := conn.Write([]byte{proto.SOCKSVersion, proto.AuthNone}) - return err -} diff --git a/internal/proxy/socks5/tcp.go b/internal/proxy/socks5/tcp.go deleted file mode 100644 index bcb53b48..00000000 --- a/internal/proxy/socks5/tcp.go +++ /dev/null @@ -1,77 +0,0 @@ -package socks5 - -import ( - "context" - "errors" - "net" - - "github.com/rs/zerolog" - "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/netutil" - "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/proxy/tlsutil" -) - -type TCPHandler struct { - logger zerolog.Logger - bridge *tlsutil.TLSBridge - serverOpts *config.ServerOptions -} - -func NewTCPHandler( - logger zerolog.Logger, - bridge *tlsutil.TLSBridge, - serverOpts *config.ServerOptions, -) *TCPHandler { - return &TCPHandler{ - logger: logger, - bridge: bridge, - serverOpts: serverOpts, - } -} - -func (h *TCPHandler) Handle( - ctx context.Context, - conn net.Conn, - req *proto.SOCKS5Request, - dst *netutil.Destination, - rule *config.Rule, -) error { - logger := h.logger.With().Ctx(ctx).Logger() - - // 1. Validate Destination - ok, err := netutil.ValidateDestination(dst.Addrs, dst.Port, h.serverOpts.ListenAddr) - if err != nil { - logger.Debug().Err(err).Msg("error determining if valid destination") - if !ok { - _ = proto.SOCKS5FailureResponse().Write(conn) - return err - } - } - - // 2. Check if blocked - if rule != nil && *rule.Block { - logger.Debug().Msg("request is blocked by policy") - _ = proto.SOCKS5FailureResponse().Write(conn) - return netutil.ErrBlocked - } - - // 3. Send Success Response - if err := proto.SOCKS5SuccessResponse().Bind(net.IPv4zero).Port(0).Write(conn); err != nil { - logger.Error().Err(err).Msg("failed to write socks5 success reply") - return err - } - - // 4. Delegate to Bridge - handleErr := h.bridge.Tunnel(ctx, conn, dst, rule) - if handleErr == nil { - return nil - } - - logger.Warn().Err(handleErr).Msg("error handling request") - if !errors.Is(handleErr, netutil.ErrBlocked) { - return handleErr - } - - return nil -} \ No newline at end of file diff --git a/internal/proxy/socks5/udp.go b/internal/proxy/socks5/udp.go deleted file mode 100644 index 9368449a..00000000 --- a/internal/proxy/socks5/udp.go +++ /dev/null @@ -1,195 +0,0 @@ -package socks5 - -import ( - "context" - "encoding/binary" - "fmt" - "io" - "net" - - "github.com/rs/zerolog" - "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/netutil" - "github.com/xvzc/SpoofDPI/internal/proto" -) - -type UDPHandler struct { - logger zerolog.Logger -} - -func NewUDPHandler(logger zerolog.Logger) *UDPHandler { - return &UDPHandler{ - logger: logger, - } -} - -func (h *UDPHandler) Handle( - ctx context.Context, - conn net.Conn, - req *proto.SOCKS5Request, - dst *netutil.Destination, - rule *config.Rule, -) error { - logger := h.logger.With().Ctx(ctx).Logger() - - // 1. Listen on a random UDP port - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - logger.Error().Err(err).Msg("failed to create udp listener") - _ = proto.SOCKS5FailureResponse().Write(conn) - return err - } - netutil.CloseConns(udpConn) - - lAddr := udpConn.LocalAddr().(*net.UDPAddr) - - logger.Debug(). - Str("bind_addr", lAddr.String()). - Msg("socks5 udp associate established") - - // 2. Reply with the bound address - if err := proto.SOCKS5SuccessResponse().Bind(lAddr.IP).Port(lAddr.Port).Write(conn); err != nil { - logger.Error().Err(err).Msg("failed to write socks5 success reply") - return err - } - - // 3. Keep TCP Alive & Relay - // We need to monitor TCP for closure. - done := make(chan struct{}) - go func() { - _, _ = io.Copy(io.Discard, conn) // Block until TCP closes - close(done) - }() - - go func() { - <-done - netutil.CloseConns(udpConn) - }() - - buf := make([]byte, 65535) - var clientAddr *net.UDPAddr - - for { - n, addr, err := udpConn.ReadFromUDP(buf) - if err != nil { - // Normal closure check - select { - case <-done: - return nil - default: - logger.Debug().Err(err).Msg("error reading from udp") - return err - } - } - - // Initial Client Identification - if clientAddr == nil { - clientAddr = addr - } - - if addr.IP.Equal(clientAddr.IP) && addr.Port == clientAddr.Port { - // Outbound: Client -> Proxy -> Target - targetAddr, payload, err := parseUDPHeader(buf[:n]) - if err != nil { - logger.Warn().Err(err).Msg("failed to parse socks5 udp header") - continue - } - - // We use the same UDP socket to send to target. - // The Target will reply to this socket. - resolvedAddr, err := net.ResolveUDPAddr("udp", targetAddr) - if err != nil { - logger.Warn(). - Err(err). - Str("addr", targetAddr). - Msg("failed to resolve udp target") - continue - } - - if _, err := udpConn.WriteTo(payload, resolvedAddr); err != nil { - logger.Warn().Err(err).Msg("failed to write udp to target") - } - } else { - // Inbound: Target -> Proxy -> Client - // Wrap with SOCKS5 Header - header := createUDPHeaderFromAddr(addr) - response := append(header, buf[:n]...) - - if _, err := udpConn.WriteToUDP(response, clientAddr); err != nil { - logger.Warn().Err(err).Msg("failed to write udp to client") - } - } - } -} - -func parseUDPHeader(b []byte) (string, []byte, error) { - if len(b) < 4 { - return "", nil, fmt.Errorf("header too short") - } - // RSV(2) FRAG(1) ATYP(1) - if b[0] != 0 || b[1] != 0 { - return "", nil, fmt.Errorf("invalid rsv") - } - frag := b[2] - if frag != 0 { - return "", nil, fmt.Errorf("fragmentation not supported") - } - - atyp := b[3] - var host string - var pos int - - switch atyp { - case proto.ATYPIPv4: - if len(b) < 10 { - return "", nil, fmt.Errorf("header too short for ipv4") - } - host = net.IP(b[4:8]).String() - pos = 8 - case proto.ATYPIPv6: - if len(b) < 22 { - return "", nil, fmt.Errorf("header too short for ipv6") - } - host = net.IP(b[4:20]).String() - pos = 20 - case proto.ATYPFQDN: - if len(b) < 5 { - return "", nil, fmt.Errorf("header too short for fqdn") - } - l := int(b[4]) - if len(b) < 5+l+2 { - return "", nil, fmt.Errorf("header too short for fqdn data") - } - host = string(b[5 : 5+l]) - pos = 5 + l - default: - return "", nil, fmt.Errorf("unsupported atyp: %d", atyp) - } - - port := binary.BigEndian.Uint16(b[pos : pos+2]) - pos += 2 - - addr := net.JoinHostPort(host, fmt.Sprintf("%d", port)) - return addr, b[pos:], nil -} - -func createUDPHeaderFromAddr(addr *net.UDPAddr) []byte { - // RSV(2) FRAG(1) ATYP(1) ... - buf := make([]byte, 0, 24) - buf = append(buf, 0, 0, 0) // RSV, FRAG - - ip4 := addr.IP.To4() - if ip4 != nil { - buf = append(buf, proto.ATYPIPv4) - buf = append(buf, ip4...) - } else { - buf = append(buf, proto.ATYPIPv6) - buf = append(buf, addr.IP.To16()...) - } - - portBuf := make([]byte, 2) - binary.BigEndian.PutUint16(portBuf, uint16(addr.Port)) - buf = append(buf, portBuf...) - - return buf -} diff --git a/internal/proxy/socks5_proxy.go b/internal/proxy/socks5_proxy.go deleted file mode 100644 index 2e14b422..00000000 --- a/internal/proxy/socks5_proxy.go +++ /dev/null @@ -1,258 +0,0 @@ -package proxy - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net" - "time" - - "github.com/rs/zerolog" - "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/dns" - "github.com/xvzc/SpoofDPI/internal/logging" - "github.com/xvzc/SpoofDPI/internal/matcher" - "github.com/xvzc/SpoofDPI/internal/netutil" - "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/proxy/handler" - "github.com/xvzc/SpoofDPI/internal/ptr" - "github.com/xvzc/SpoofDPI/internal/session" -) - -type SOCKS5Proxy struct { - logger zerolog.Logger - - resolver dns.Resolver - ruleMatcher matcher.RuleMatcher - serverOpts *config.ServerOptions - policyOpts *config.PolicyOptions - - tcpHandler *handler.TCPHandler - udpHandler *handler.UDPHandler -} - -func NewSOCKS5Proxy( - logger zerolog.Logger, - resolver dns.Resolver, - bridge *handler.Bridge, - ruleMatcher matcher.RuleMatcher, - serverOpts *config.ServerOptions, - policyOpts *config.PolicyOptions, -) ProxyServer { - return &SOCKS5Proxy{ - logger: logger, - resolver: resolver, - ruleMatcher: ruleMatcher, - serverOpts: serverOpts, - policyOpts: policyOpts, - tcpHandler: handler.NewTCPHandler( - logger, - bridge, - serverOpts, - ), - udpHandler: handler.NewUDPHandler(logger), - } -} - -func (p *SOCKS5Proxy) ListenAndServe(ctx context.Context, wait chan struct{}) { - <-wait - - logger := p.logger.With().Ctx(ctx).Logger() - - // Using ListenTCP to match HTTPProxy style, though net.Listen is also fine - listener, err := net.ListenTCP("tcp", p.serverOpts.ListenAddr) - if err != nil { - p.logger.Fatal(). - Err(err). - Msgf("error creating socks5 listener on %s", p.serverOpts.ListenAddr.String()) - } - - logger.Info(). - Msgf("created a socks5 listener on %s", p.serverOpts.ListenAddr) - - for { - conn, err := listener.Accept() - if err != nil { - p.logger.Error(). - Err(err). - Msg("failed to accept new connection") - continue - } - - go p.handleConnection(session.WithNewTraceID(context.Background()), conn) - } -} - -func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { - logger := logging.WithLocalScope(ctx, p.logger, "socks5") - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - defer netutil.CloseConns(conn) - - // 1. Negotiation Phase - if err := p.negotiate(conn); err != nil { - logger.Debug().Err(err).Msg("socks5 negotiation failed") - return - } - - // 2. Request Phase - req, err := proto.ReadSocks5Request(conn) - if err != nil { - if err != io.EOF { - logger.Warn().Err(err).Msg("failed to read socks5 request") - } - return - } - - // Setup Logging Context - remoteInfo := req.Domain - if remoteInfo == "" { - remoteInfo = req.IP.String() - } - ctx = session.WithHostInfo(ctx, remoteInfo) - - switch req.Cmd { - case proto.CmdConnect: - rule, addrs, err := p.resolveAndMatch(ctx, req) - if err != nil { - return // resolveAndMatch logs error and writes failure response if needed - } - dst := &netutil.Destination{ - Domain: req.Domain, - Addrs: addrs, - Port: req.Port, - Timeout: *p.serverOpts.Timeout, - } - if err := p.tcpHandler.Handle(ctx, conn, req, dst, rule); err != nil { - return // Handler logs error - } - - p.handleAutoConfig(ctx, req, addrs, rule) - - case proto.CmdUDPAssociate: - // UDP Associate usually doesn't have destination info in the request - _ = p.udpHandler.Handle(ctx, conn, req, nil, nil) - default: - _ = proto.SOCKS5CommandNotSupportedResponse().Write(conn) - logger.Warn().Uint8("cmd", req.Cmd).Msg("unsupported socks5 command") - } -} - -func (p *SOCKS5Proxy) resolveAndMatch( - ctx context.Context, - req *proto.SOCKS5Request, -) (*config.Rule, []net.IPAddr, error) { - logger := zerolog.Ctx(ctx) - - // 1. Match Domain Rules (if domain provided) - var nameMatch *config.Rule - if req.Domain != "" { - nameMatch = p.ruleMatcher.Search( - &matcher.Selector{ - Kind: matcher.MatchKindDomain, - Domain: ptr.FromValue(req.Domain), - }, - ) - if nameMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - jsonAttrs, _ := json.Marshal(nameMatch) - logger.Trace().RawJSON("values", jsonAttrs).Msg("name match") - } - } - - // 2. DNS Resolution - t1 := time.Now() - var addrs []net.IPAddr - - if req.Domain != "" { - // Resolve Domain - rSet, err := p.resolver.Resolve(ctx, req.Domain, nil, nameMatch) - if err != nil { - logging.ErrorUnwrapped(logger, "dns lookup failed", err) - return nil, nil, err - } - addrs = rSet.Addrs - } else { - addrs = []net.IPAddr{{IP: req.IP}} - } - - dt := time.Since(t1).Milliseconds() - logger.Debug(). - Int("cnt", len(addrs)). - Str("took", fmt.Sprintf("%dms", dt)). - Msg("dns lookup ok") - - // 3. Match IP Rules - var selectors []*matcher.Selector - for _, v := range addrs { - selectors = append(selectors, &matcher.Selector{ - Kind: matcher.MatchKindAddr, - IP: ptr.FromValue(v.IP), - Port: ptr.FromValue(uint16(req.Port)), - }) - } - - addrMatch := p.ruleMatcher.SearchAll(selectors) - if addrMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - jsonAttrs, _ := json.Marshal(addrMatch) - logger.Trace().RawJSON("values", jsonAttrs).Msg("addr match") - } - - bestMatch := matcher.GetHigherPriorityRule(addrMatch, nameMatch) - return bestMatch, addrs, nil -} - -func (p *SOCKS5Proxy) handleAutoConfig( - ctx context.Context, - req *proto.SOCKS5Request, - addrs []net.IPAddr, - matchedRule *config.Rule, -) { - logger := zerolog.Ctx(ctx) - - if matchedRule != nil { - logger.Info(). - Interface("match", matchedRule.Match). - Str("name", *matchedRule.Name). - Msg("skipping auto-config (duplicate policy)") - return - } - - if *p.policyOpts.Auto && p.policyOpts.Template != nil { - newRule := p.policyOpts.Template.Clone() - targetDomain := req.Domain - if targetDomain == "" && len(addrs) > 0 { - targetDomain = addrs[0].IP.String() - } - - newRule.Match = &config.MatchAttrs{Domains: []string{targetDomain}} - - if err := p.ruleMatcher.Add(newRule); err != nil { - logger.Info().Err(err).Msg("failed to add config automatically") - } else { - logger.Info().Msg("automatically added to config") - } - } -} - -func (p *SOCKS5Proxy) negotiate(conn net.Conn) error { - header := make([]byte, 2) - if _, err := io.ReadFull(conn, header); err != nil { - return err - } - - if header[0] != proto.SOCKSVersion { - return fmt.Errorf("unsupported version: %d", header[0]) - } - - nMethods := int(header[1]) - methods := make([]byte, nMethods) - if _, err := io.ReadFull(conn, methods); err != nil { - return err - } - - // Respond: Version 5, Method NoAuth(0) - _, err := conn.Write([]byte{proto.SOCKSVersion, proto.AuthNone}) - return err -} diff --git a/internal/proxy/tlsutil/bridge.go b/internal/proxy/tlsutil/bridge.go deleted file mode 100644 index 2f669734..00000000 --- a/internal/proxy/tlsutil/bridge.go +++ /dev/null @@ -1,130 +0,0 @@ -package tlsutil - -import ( - "context" - "fmt" - "net" - - "github.com/rs/zerolog" - "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/desync" - "github.com/xvzc/SpoofDPI/internal/logging" - "github.com/xvzc/SpoofDPI/internal/netutil" - "github.com/xvzc/SpoofDPI/internal/packet" - "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" -) - -type TLSBridge struct { - logger zerolog.Logger - desyncer *desync.TLSDesyncer - sniffer packet.Sniffer - httpsOpts *config.HTTPSOptions -} - -func NewTLSBridge( - logger zerolog.Logger, - desyncer *desync.TLSDesyncer, - sniffer packet.Sniffer, - httpsOpts *config.HTTPSOptions, -) *TLSBridge { - return &TLSBridge{ - logger: logger, - desyncer: desyncer, - sniffer: sniffer, - httpsOpts: httpsOpts, - } -} - -// Tunnel creates a bi-directional tunnel between lConn and dst. -// It detects the first packet from lConn. If it's a ClientHello, it applies the desync strategy. -func (b *TLSBridge) Tunnel( - ctx context.Context, - lConn net.Conn, - dst *netutil.Destination, - rule *config.Rule, -) error { - httpsOpts := b.httpsOpts - if rule != nil { - httpsOpts = httpsOpts.Merge(rule.HTTPS) - } - - if b.sniffer != nil && ptr.FromPtr(httpsOpts.FakeCount) > 0 { - b.sniffer.RegisterUntracked(dst.Addrs, dst.Port) - } - - logger := logging.WithLocalScope(ctx, b.logger, "https") - - rConn, err := netutil.DialFastest(ctx, "tcp", dst.Addrs, dst.Port, dst.Timeout) - if err != nil { - return err - } - defer netutil.CloseConns(rConn) - - logger.Debug().Msgf("new remote conn -> %s", rConn.RemoteAddr()) - - // Read the first message from the client (expected to be ClientHello) - tlsMsg, err := proto.ReadTLSMessage(lConn) - if err != nil { - logger.Trace().Err(err).Msgf("failed to read first message from client") - return nil // Client might have closed connection or sent garbage - } - - logger.Debug(). - Int("len", tlsMsg.Len()). - Msgf("client hello received <- %s", lConn.RemoteAddr()) - - if !tlsMsg.IsClientHello() { - logger.Trace().Int("len", tlsMsg.Len()).Msg("not a client hello. aborting") - return nil - } - - // Send ClientHello to the remote server (with desync if configured) - n, err := b.sendClientHello(ctx, rConn, tlsMsg, httpsOpts) - if err != nil { - return fmt.Errorf("failed to send client hello: %w", err) - } - - logger.Debug(). - Int("len", n). - Msgf("sent client hello -> %s", rConn.RemoteAddr()) - - // Start bi-directional tunneling - errCh := make(chan error, 2) - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - go netutil.TunnelConns(ctx, logger, errCh, rConn, lConn) - go netutil.TunnelConns(ctx, logger, errCh, lConn, rConn) - - for range 2 { - e := <-errCh - if e == nil { - continue - } - - if netutil.IsConnectionResetByPeer(e) { - return netutil.ErrBlocked - } - - return fmt.Errorf( - "unsuccessful tunnel %s -> %s: %w", - lConn.RemoteAddr(), - rConn.RemoteAddr(), - e, - ) - } - - return nil -} - -func (b *TLSBridge) sendClientHello( - ctx context.Context, - conn net.Conn, - msg *proto.TLSMessage, - httpsOpts *config.HTTPSOptions, -) (int, error) { - logger := logging.WithLocalScope(ctx, b.logger, "client_hello") - return b.desyncer.Send(ctx, logger, conn, msg, httpsOpts) -} diff --git a/internal/proxy/http/handler.go b/internal/server/http/http.go similarity index 73% rename from internal/proxy/http/handler.go rename to internal/server/http/http.go index d847610d..80ee9b73 100644 --- a/internal/proxy/http/handler.go +++ b/internal/server/http/http.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "time" "github.com/rs/zerolog" "github.com/xvzc/SpoofDPI/internal/config" @@ -31,7 +32,7 @@ func (h *HTTPHandler) HandleRequest( ) error { logger := logging.WithLocalScope(ctx, h.logger, "http") - rConn, err := netutil.DialFastest(ctx, "tcp", dst.Addrs, dst.Port, dst.Timeout) + rConn, err := netutil.DialFastest(ctx, "tcp", dst) if err != nil { _ = proto.HTTPBadGatewayResponse().Write(lConn) return err @@ -49,31 +50,22 @@ func (h *HTTPHandler) HandleRequest( return fmt.Errorf("failed to send request: %w", err) } - errCh := make(chan error, 2) + // Start bi-directional tunneling + resCh := make(chan netutil.TransferResult, 2) ctx, cancel := context.WithCancel(ctx) defer cancel() - go netutil.TunnelConns(ctx, logger, errCh, rConn, lConn) - go netutil.TunnelConns(ctx, logger, errCh, lConn, rConn) - - for range 2 { - e := <-errCh - if e == nil { - continue - } - - if netutil.IsConnectionResetByPeer(e) { - return netutil.ErrBlocked - } - - return fmt.Errorf( - "unsuccessful tunnel %s -> %s: %w", - lConn.RemoteAddr(), - rConn.RemoteAddr(), - e, - ) - } - - return nil + startedAt := time.Now() + go netutil.TunnelConns(ctx, resCh, lConn, rConn, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, rConn, lConn, netutil.TunnelDirIn) + + return netutil.WaitAndLogTunnel( + ctx, + logger, + resCh, + startedAt, + netutil.DescribeRoute(lConn, rConn), + nil, + ) } diff --git a/internal/server/http/https.go b/internal/server/http/https.go new file mode 100644 index 00000000..8b47a54b --- /dev/null +++ b/internal/server/http/https.go @@ -0,0 +1,162 @@ +package http + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "slices" + "time" + + "github.com/rs/zerolog" + "github.com/samber/lo" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/desync" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/packet" + "github.com/xvzc/SpoofDPI/internal/proto" +) + +type HTTPSHandler struct { + logger zerolog.Logger + desyncer *desync.TLSDesyncer + sniffer packet.Sniffer + httpsOpts *config.HTTPSOptions +} + +func NewHTTPSHandler( + logger zerolog.Logger, + desyncer *desync.TLSDesyncer, + sniffer packet.Sniffer, + httpsOpts *config.HTTPSOptions, +) *HTTPSHandler { + return &HTTPSHandler{ + logger: logger, + desyncer: desyncer, + sniffer: sniffer, + httpsOpts: httpsOpts, + } +} + +func (h *HTTPSHandler) HandleRequest( + ctx context.Context, + lConn net.Conn, + dst *netutil.Destination, + rule *config.Rule, +) error { + opts := h.httpsOpts.Clone() + if rule != nil { + opts = opts.Merge(rule.HTTPS) + } + + logger := logging.WithLocalScope(ctx, h.logger, "handshake") + + // 1. Send 200 Connection Established + if err := proto.HTTPConnectionEstablishedResponse().Write(lConn); err != nil { + if !netutil.IsConnectionResetByPeer(err) && !errors.Is(err, io.EOF) { + logger.Trace().Err(err).Msgf("proxy handshake error") + return fmt.Errorf("failed to handle proxy handshake: %w", err) + } + return nil + } + logger.Trace().Msgf("sent 200 connection established -> %s", lConn.RemoteAddr()) + + // 2. Tunnel + return h.tunnel(ctx, lConn, dst, opts) +} + +func (h *HTTPSHandler) tunnel( + ctx context.Context, + lConn net.Conn, + dst *netutil.Destination, + opts *config.HTTPSOptions, +) error { + if h.sniffer != nil && lo.FromPtr(opts.FakeCount) > 0 { + h.sniffer.RegisterUntracked(dst.Addrs) + } + + logger := logging.WithLocalScope(ctx, h.logger, "https") + + rConn, err := netutil.DialFastest(ctx, "tcp", dst) + if err != nil { + return err + } + defer netutil.CloseConns(rConn) + + logger.Debug().Msgf("new remote conn -> %s", rConn.RemoteAddr()) + + // Read the first message from the client (expected to be ClientHello) + tlsMsg, err := proto.ReadTLSMessage(lConn) + if err != nil { + if err == io.EOF || err.Error() == "unexpected EOF" { + return nil + } + logger.Trace().Err(err).Msgf("failed to read first message from client") + return err + } + + logger.Debug(). + Int("len", tlsMsg.Len()). + Msgf("client hello received <- %s", lConn.RemoteAddr()) + + if !tlsMsg.IsClientHello() { + logger.Trace().Int("len", tlsMsg.Len()).Msg("not a client hello. aborting") + return nil + } + + // Send ClientHello to the remote server (with desync if configured) + n, err := h.sendClientHello(ctx, rConn, tlsMsg, opts) + if err != nil { + return fmt.Errorf("failed to send client hello: %w", err) + } + + logger.Debug(). + Int("len", n). + Msgf("sent client hello -> %s", rConn.RemoteAddr()) + + // Start bi-directional tunneling + resCh := make(chan netutil.TransferResult, 2) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + startedAt := time.Now() + go netutil.TunnelConns(ctx, resCh, lConn, rConn, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, rConn, lConn, netutil.TunnelDirIn) + + handleErrs := func(errs []error) error { + if len(errs) == 0 { + return nil + } + + if slices.ContainsFunc(errs, netutil.IsConnectionResetByPeer) { + return netutil.ErrBlocked + } + + return errs[0] + } + + return netutil.WaitAndLogTunnel( + ctx, + logger, + resCh, + startedAt, + netutil.DescribeRoute(lConn, rConn), + handleErrs, + ) +} + +func (h *HTTPSHandler) sendClientHello( + ctx context.Context, + rConn net.Conn, + msg *proto.TLSMessage, + opts *config.HTTPSOptions, +) (int, error) { + if lo.FromPtr(opts.Skip) { + return rConn.Write(msg.Raw()) + } + + return h.desyncer.Desync(ctx, h.logger, rConn, msg, opts) +} diff --git a/internal/server/http/network.go b/internal/server/http/network.go new file mode 100644 index 00000000..d2d78955 --- /dev/null +++ b/internal/server/http/network.go @@ -0,0 +1,13 @@ +//go:build !darwin && !linux + +package http + +import "github.com/rs/zerolog" + +func SetSystemProxy(logger zerolog.Logger, port uint16) error { + return nil +} + +func UnsetSystemProxy(logger zerolog.Logger) error { + return nil +} diff --git a/internal/server/http/network_darwin.go b/internal/server/http/network_darwin.go new file mode 100644 index 00000000..71e90fd3 --- /dev/null +++ b/internal/server/http/network_darwin.go @@ -0,0 +1,116 @@ +//go:build darwin + +package http + +import ( + "errors" + "fmt" + "net" + "os/exec" + "strconv" + "strings" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/netutil" +) + +const ( + permissionErrorHelpText = "By default SpoofDPI tries to set itself up as a system-wide proxy server.\n" + + "Doing so may require root access on machines with\n" + + "'Settings > Privacy & Security > Advanced > Require" + + " an administrator password to access system-wide settings' enabled.\n" + + "If you do not want SpoofDPI to act as a system-wide proxy, provide" + + " -system-proxy=false." +) + +var pacListener net.Listener + +func SetSystemProxy(logger zerolog.Logger, port uint16) error { + network, err := getDefaultNetwork() + if err != nil { + return err + } + + portStr := strconv.Itoa(int(port)) + pacContent := fmt.Sprintf(`function FindProxyForURL(url, host) { + return "PROXY 127.0.0.1:%s; DIRECT"; +}`, portStr) + + pacURL, l, err := netutil.RunPACServer(pacContent) + if err != nil { + return fmt.Errorf("error creating pac server: %w", err) + } + pacListener = l + + // Enable Auto Proxy Configuration + // networksetup -setautoproxyurl + if err := networkSetup("-setautoproxyurl", network, pacURL); err != nil { + return fmt.Errorf("setting autoproxyurl: %w", err) + } + + // networksetup -setproxyautodiscovery + if err := networkSetup("-setproxyautodiscovery", network, "on"); err != nil { + return fmt.Errorf("setting proxyautodiscovery: %w", err) + } + + return nil +} + +func UnsetSystemProxy(logger zerolog.Logger) error { + if pacListener != nil { + _ = pacListener.Close() + pacListener = nil + } + + network, err := getDefaultNetwork() + if err != nil { + return err + } + + // Disable Auto Proxy Configuration + if err := networkSetup("-setautoproxystate", network, "off"); err != nil { + return fmt.Errorf("unsetting autoproxystate: %w", err) + } + + if err := networkSetup("-setproxyautodiscovery", network, "off"); err != nil { + return fmt.Errorf("unsetting proxyautodiscovery: %w", err) + } + + return nil +} + +func getDefaultNetwork() (string, error) { + const cmd = "networksetup -listnetworkserviceorder | grep" + + " `(route -n get default | grep 'interface' || route -n get -inet6 default | grep 'interface') | cut -d ':' -f2`" + + " -B 1 | head -n 1 | cut -d ' ' -f 2-" + + out, err := exec.Command("sh", "-c", cmd).Output() + if err != nil { + return "", err + } + + network := strings.TrimSpace(string(out)) + if network == "" { + return "", errors.New("no available networks") + } + return network, nil +} + +func networkSetup(args ...string) error { + cmd := exec.Command("networksetup", args...) + out, err := cmd.CombinedOutput() + if err != nil { + msg := string(out) + if isPermissionError(err) { + msg += permissionErrorHelpText + } + return fmt.Errorf("%s", msg) + } + return nil +} + +func isPermissionError(err error) bool { + var exitErr *exec.ExitError + ok := errors.As(err, &exitErr) + return ok && exitErr.ExitCode() == 14 +} diff --git a/internal/server/http/network_linux.go b/internal/server/http/network_linux.go new file mode 100644 index 00000000..918c0c91 --- /dev/null +++ b/internal/server/http/network_linux.go @@ -0,0 +1,14 @@ +//go:build linux + +package http + +import "github.com/rs/zerolog" + +func SetSystemProxy(logger zerolog.Logger, port uint16) error { + // Not implemented for Linux yet + return nil +} + +func UnsetSystemProxy(logger zerolog.Logger) error { + return nil +} diff --git a/internal/proxy/http/proxy.go b/internal/server/http/proxy.go similarity index 66% rename from internal/proxy/http/proxy.go rename to internal/server/http/proxy.go index 5bb1b88b..b4e71b43 100644 --- a/internal/proxy/http/proxy.go +++ b/internal/server/http/proxy.go @@ -2,22 +2,20 @@ package http import ( "context" - "encoding/json" "errors" "fmt" "io" "net" - "time" "github.com/rs/zerolog" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/dns" "github.com/xvzc/SpoofDPI/internal/logging" "github.com/xvzc/SpoofDPI/internal/matcher" "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/proxy" - "github.com/xvzc/SpoofDPI/internal/ptr" + "github.com/xvzc/SpoofDPI/internal/server" "github.com/xvzc/SpoofDPI/internal/session" ) @@ -30,6 +28,8 @@ type HTTPProxy struct { ruleMatcher matcher.RuleMatcher serverOpts *config.ServerOptions policyOpts *config.PolicyOptions + + listener net.Listener } func NewHTTPProxy( @@ -40,7 +40,7 @@ func NewHTTPProxy( ruleMatcher matcher.RuleMatcher, serverOpts *config.ServerOptions, policyOpts *config.PolicyOptions, -) proxy.ProxyServer { +) server.Server { return &HTTPProxy{ logger: logger, resolver: resolver, @@ -52,24 +52,27 @@ func NewHTTPProxy( } } -func (p *HTTPProxy) ListenAndServe(ctx context.Context, wait chan struct{}) { - <-wait - - logger := p.logger.With().Ctx(ctx).Logger() - +func (p *HTTPProxy) Start(ctx context.Context, ready chan<- struct{}) error { listener, err := net.ListenTCP("tcp", p.serverOpts.ListenAddr) if err != nil { - p.logger.Fatal(). - Err(err). - Msgf("error creating listener on %s", p.serverOpts.ListenAddr.String()) + return fmt.Errorf( + "error creating listener on %s: %w", + p.serverOpts.ListenAddr.String(), + err, + ) } + p.listener = listener - logger.Info(). - Msgf("created a listener on %s", p.serverOpts.ListenAddr) + if ready != nil { + close(ready) + } for { conn, err := listener.Accept() if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil // Normal shutdown + } p.logger.Error(). Err(err). Msgf("failed to accept new connection") @@ -81,8 +84,27 @@ func (p *HTTPProxy) ListenAndServe(ctx context.Context, wait chan struct{}) { } } +func (p *HTTPProxy) Stop() error { + if p.listener != nil { + return p.listener.Close() + } + return nil +} + +func (p *HTTPProxy) SetNetworkConfig() error { + return SetSystemProxy(p.logger, uint16(p.serverOpts.ListenAddr.Port)) +} + +func (p *HTTPProxy) UnsetNetworkConfig() error { + return UnsetSystemProxy(p.logger) +} + +func (p *HTTPProxy) Addr() string { + return p.serverOpts.ListenAddr.String() +} + func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { - logger := logging.WithLocalScope(ctx, p.logger, "conn") + logger := logging.WithLocalScope(ctx, p.logger, "conn-init") ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -97,6 +119,9 @@ func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { return } + logger.Debug().Str("from", conn.RemoteAddr().String()).Str("host", req.Host). + Msg("new request") + if !req.IsValidMethod() { logger.Warn().Str("method", req.Method).Msg("unsupported method. abort") _ = proto.HTTPNotImplementedResponse().Write(conn) @@ -104,7 +129,7 @@ func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { return } - domain := req.ExtractDomain() + host := req.ExtractHost() dstPort, err := req.ExtractPort() if err != nil { logger.Warn().Str("host", req.Host).Msg("failed to extract port") @@ -113,39 +138,35 @@ func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { return } - ctx = session.WithHostInfo(ctx, domain) - logger = logger.With().Ctx(ctx).Logger() - logger.Debug(). Str("method", req.Method). Str("from", conn.RemoteAddr().String()). Msg("new request") - nameMatch := p.ruleMatcher.Search( - &matcher.Selector{Kind: matcher.MatchKindDomain, Domain: ptr.FromValue(domain)}, - ) - if nameMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - jsonAttrs, _ := json.Marshal(nameMatch) - logger.Trace().RawJSON("values", jsonAttrs).Msg("name match") - } + var addrs []net.IP + var nameMatch *config.Rule + if net.ParseIP(host) != nil { + addrs = []net.IP{net.ParseIP(host)} + logger.Trace().Msgf("skipping dns lookup for non-domain host %q", host) + } else { + nameMatch = p.ruleMatcher.Search( + &matcher.Selector{Kind: matcher.MatchKindDomain, Domain: lo.ToPtr(host)}, + ) - t1 := time.Now() - rSet, err := p.resolver.Resolve(ctx, domain, nil, nameMatch) - dt := time.Since(t1).Milliseconds() - if err != nil { - _ = proto.HTTPBadGatewayResponse().Write(conn) - logging.ErrorUnwrapped(&logger, "dns lookup failed", err) + rSet, err := p.resolver.Resolve(ctx, host, nil, nameMatch) + if err != nil { + _ = proto.HTTPBadGatewayResponse().Write(conn) + // logging.ErrorUnwrapped is not available, using standard error logging + logger.Error().Err(err).Msgf("dns lookup failed for %s", host) - return - } + return + } - logger.Debug(). - Int("cnt", len(rSet.Addrs)). - Str("took", fmt.Sprintf("%dms", dt)). - Msgf("dns lookup ok") + addrs = rSet.Addrs + } // Avoid recursively querying self. - ok, err := netutil.ValidateDestination(rSet.Addrs, dstPort, p.serverOpts.ListenAddr) + ok, err := netutil.ValidateDestination(addrs, dstPort, p.serverOpts.ListenAddr) if err != nil { logger.Debug().Err(err).Msg("error validating dst addrs") if !ok { @@ -154,24 +175,19 @@ func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { } var selectors []*matcher.Selector - for _, v := range rSet.Addrs { + for _, v := range addrs { selectors = append(selectors, &matcher.Selector{ Kind: matcher.MatchKindAddr, - IP: ptr.FromValue(v.IP), - Port: ptr.FromValue(uint16(dstPort)), + IP: lo.ToPtr(v), + Port: lo.ToPtr(uint16(dstPort)), }) } addrMatch := p.ruleMatcher.SearchAll(selectors) - if addrMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - jsonAttrs, _ := json.Marshal(addrMatch) - logger.Trace().RawJSON("values", jsonAttrs).Msg("addr match") - } bestMatch := matcher.GetHigherPriorityRule(addrMatch, nameMatch) if bestMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - jsonAttrs, _ := json.Marshal(bestMatch) - logger.Trace().RawJSON("values", jsonAttrs).Msg("best match") + logger.Trace().RawJSON("summary", bestMatch.JSON()).Msg("match") } if bestMatch != nil && *bestMatch.Block { @@ -180,8 +196,8 @@ func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { } dst := &netutil.Destination{ - Domain: domain, - Addrs: rSet.Addrs, + Domain: host, // Updated from Domain to Host + Addrs: addrs, Port: dstPort, Timeout: *p.serverOpts.Timeout, } @@ -213,7 +229,7 @@ func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { // Perform auto config if enabled and RuleTemplate is not nil if *p.policyOpts.Auto && p.policyOpts.Template != nil { newRule := p.policyOpts.Template.Clone() - newRule.Match = &config.MatchAttrs{Domains: []string{domain}} + newRule.Match = &config.MatchAttrs{Domains: []string{host}} if err := p.ruleMatcher.Add(newRule); err != nil { logger.Info().Err(err).Msg("failed to add config automatically") diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 00000000..94d84b3f --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,17 @@ +package server + +import "context" + +// Server represents a core component that processes network traffic +type Server interface { + // Start begins the execution of the server module + Start(ctx context.Context, ready chan<- struct{}) error + SetNetworkConfig() error + UnsetNetworkConfig() error + + // Stop gracefully terminates the server and releases resources + Stop() error + + // Addr returns the network address or interface name the server is bound to + Addr() string +} diff --git a/internal/server/socks5/bind.go b/internal/server/socks5/bind.go new file mode 100644 index 00000000..2e90f884 --- /dev/null +++ b/internal/server/socks5/bind.go @@ -0,0 +1,96 @@ +package socks5 + +import ( + "context" + "net" + "time" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/proto" +) + +type BindHandler struct { + logger zerolog.Logger +} + +func NewBindHandler(logger zerolog.Logger) *BindHandler { + return &BindHandler{ + logger: logger, + } +} + +func (h *BindHandler) Handle( + ctx context.Context, + conn net.Conn, + req *proto.SOCKS5Request, +) error { + logger := logging.WithLocalScope(ctx, h.logger, "bind") + + // 1. Listen on a random TCP port + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + logger.Error().Err(err).Msg("failed to create bind listener") + _ = proto.SOCKS5FailureResponse().Write(conn) + return err + } + defer func() { _ = listener.Close() }() + + logger.Debug(). + Str("addr", listener.Addr().String()). + Str("network", listener.Addr().Network()). + Msg("new listener") + + lAddr := listener.Addr().(*net.TCPAddr) + + // 2. First Reply: Send the address/port we are listening on + if err := proto.SOCKS5SuccessResponse().Bind(lAddr.IP).Port(lAddr.Port).Write(conn); err != nil { + logger.Error().Err(err).Msg("failed to write first bind reply") + return err + } + + logger.Debug(). + Str("bind_addr", lAddr.String()). + Msg("waiting for incoming connection") + + // 3. Accept Incoming Connection + // The client should now tell the application server to connect to lAddr. + remoteConn, err := listener.Accept() + if err != nil { + logger.Error().Err(err).Msg("failed to accept incoming connection") + _ = proto.SOCKS5FailureResponse().Write(conn) + return err + } + defer netutil.CloseConns(remoteConn) + + rAddr := remoteConn.RemoteAddr().(*net.TCPAddr) + + logger.Debug(). + Str("remote_addr", rAddr.String()). + Msg("accepted incoming connection") + + // 4. Second Reply: Send the address/port of the connecting host + if err := proto.SOCKS5SuccessResponse().Bind(rAddr.IP).Port(rAddr.Port).Write(conn); err != nil { + logger.Error().Err(err).Msg("failed to write second bind reply") + return err + } + + // 5. Start bi-directional tunneling + resCh := make(chan netutil.TransferResult, 2) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + startedAt := time.Now() + go netutil.TunnelConns(ctx, resCh, remoteConn, conn, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, conn, remoteConn, netutil.TunnelDirIn) + + return netutil.WaitAndLogTunnel( + ctx, + logger, + resCh, + startedAt, + netutil.DescribeRoute(conn, remoteConn), + nil, + ) +} diff --git a/internal/server/socks5/connect.go b/internal/server/socks5/connect.go new file mode 100644 index 00000000..ac3b09c4 --- /dev/null +++ b/internal/server/socks5/connect.go @@ -0,0 +1,200 @@ +package socks5 + +import ( + "context" + "fmt" + "io" + "net" + "time" + + "github.com/rs/zerolog" + "github.com/samber/lo" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/desync" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/packet" + "github.com/xvzc/SpoofDPI/internal/proto" +) + +type ConnectHandler struct { + logger zerolog.Logger + desyncer *desync.TLSDesyncer + sniffer packet.Sniffer + serverOpts *config.ServerOptions + httpsOpts *config.HTTPSOptions +} + +func NewConnectHandler( + logger zerolog.Logger, + desyncer *desync.TLSDesyncer, + sniffer packet.Sniffer, + serverOpts *config.ServerOptions, + httpsOpts *config.HTTPSOptions, +) *ConnectHandler { + return &ConnectHandler{ + logger: logger, + desyncer: desyncer, + sniffer: sniffer, + serverOpts: serverOpts, + httpsOpts: httpsOpts, + } +} + +func (h *ConnectHandler) Handle( + ctx context.Context, + lConn net.Conn, + req *proto.SOCKS5Request, + dst *netutil.Destination, + rule *config.Rule, +) error { + opts := h.httpsOpts.Clone() + if rule != nil { + opts = opts.Merge(rule.HTTPS) + } + + logger := logging.WithLocalScope(ctx, h.logger, "connect") + + // 1. Validate Destination + ok, err := netutil.ValidateDestination(dst.Addrs, dst.Port, h.serverOpts.ListenAddr) + if err != nil { + logger.Debug().Err(err).Msg("error determining if valid destination") + if !ok { + _ = proto.SOCKS5FailureResponse().Write(lConn) + return err + } + } + + // 2. Check if blocked + if rule != nil && *rule.Block { + logger.Debug().Msg("request is blocked by policy") + _ = proto.SOCKS5FailureResponse().Write(lConn) + return netutil.ErrBlocked + } + + // 3. Send Success Response + if err := proto.SOCKS5SuccessResponse().Bind(net.IPv4zero).Port(0).Write(lConn); err != nil { + logger.Error().Err(err).Msg("failed to write socks5 success reply") + return err + } + + // logger := logging.WithLocalScope(ctx, h.logger, "connect(tcp)") + + rConn, err := netutil.DialFastest(ctx, "tcp", dst) + if err != nil { + return err + } + defer netutil.CloseConns(rConn) + + logger.Debug().Msgf("new remote conn -> %s", rConn.RemoteAddr()) + + // Wrap lConn with a buffered reader to peek for TLS + bufConn := netutil.NewBufferedConn(lConn) + + // Peek first byte to check for TLS Handshake (0x16) + // We try to peek 1 byte. + b, err := bufConn.Peek(1) + if err == nil && b[0] == byte(proto.TLSHandshake) { // 0x16 + + if h.sniffer != nil && lo.FromPtr(opts.FakeCount) > 0 { + h.sniffer.RegisterUntracked(dst.Addrs) + } + + return h.handleHTTPS(ctx, bufConn, rConn, opts) + } + + // If not TLS, fall back to pure TCP tunnel + logger.Debug().Msg("not a tls handshake. fallback to pure tcp") + + resCh := make(chan netutil.TransferResult, 2) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + startedAt := time.Now() + go netutil.TunnelConns(ctx, resCh, rConn, bufConn, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, bufConn, rConn, netutil.TunnelDirIn) + + return netutil.WaitAndLogTunnel( + ctx, + logger, + resCh, + startedAt, + netutil.DescribeRoute(bufConn, rConn), + nil, + ) +} + +func (h *ConnectHandler) handleHTTPS( + ctx context.Context, + lConn net.Conn, // This is expected to be the BufferedConn + rConn net.Conn, + opts *config.HTTPSOptions, +) error { + logger := logging.WithLocalScope(ctx, h.logger, "connect(tls)") + + // Read the first message from the client (expected to be ClientHello) + tlsMsg, err := proto.ReadTLSMessage(lConn) + if err != nil { + if err == io.EOF || err.Error() == "unexpected EOF" { + return nil + } + logger.Trace().Err(err).Msgf("failed to read first message from client") + return err + } + + // It starts with 0x16, but is it a ClientHello? + if !tlsMsg.IsClientHello() { + logger.Debug(). + Int("len", tlsMsg.Len()). + Msg("not a client hello. fallback to pure tcp") + + // Forward the initial bytes we read + if _, err := rConn.Write(tlsMsg.Raw()); err != nil { + return fmt.Errorf("failed to write initial bytes to remote: %w", err) + } + } else { + logger.Debug(). + Int("len", tlsMsg.Len()). + Msgf("client hello received <- %s", lConn.RemoteAddr()) + + // Send ClientHello to the remote server (with desync if configured) + n, err := h.sendClientHello(ctx, rConn, tlsMsg, opts) + if err != nil { + return fmt.Errorf("failed to send client hello: %w", err) + } + + logger.Debug(). + Int("len", n). + Msgf("sent client hello -> %s", rConn.RemoteAddr()) + } + + resCh := make(chan netutil.TransferResult, 2) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + startedAt := time.Now() + go netutil.TunnelConns(ctx, resCh, rConn, lConn, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, lConn, rConn, netutil.TunnelDirIn) + + return netutil.WaitAndLogTunnel( + ctx, + logger, + resCh, + startedAt, + netutil.DescribeRoute(lConn, rConn), + nil, + ) +} + +func (h *ConnectHandler) sendClientHello( + ctx context.Context, + rConn net.Conn, + msg *proto.TLSMessage, + opts *config.HTTPSOptions, +) (int, error) { + if lo.FromPtr(opts.Skip) { + return rConn.Write(msg.Raw()) + } + + return h.desyncer.Desync(ctx, h.logger, rConn, msg, opts) +} diff --git a/internal/server/socks5/network.go b/internal/server/socks5/network.go new file mode 100644 index 00000000..2bba7628 --- /dev/null +++ b/internal/server/socks5/network.go @@ -0,0 +1,13 @@ +//go:build !darwin && !linux + +package socks5 + +import "github.com/rs/zerolog" + +func SetSystemProxy(logger zerolog.Logger, port uint16) error { + return nil +} + +func UnsetSystemProxy(logger zerolog.Logger) error { + return nil +} diff --git a/internal/server/socks5/network_darwin.go b/internal/server/socks5/network_darwin.go new file mode 100644 index 00000000..c23e07fb --- /dev/null +++ b/internal/server/socks5/network_darwin.go @@ -0,0 +1,116 @@ +//go:build darwin + +package socks5 + +import ( + "errors" + "fmt" + "net" + "os/exec" + "strconv" + "strings" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/netutil" +) + +const ( + permissionErrorHelpText = "By default SpoofDPI tries to set itself up as a system-wide proxy server.\n" + + "Doing so may require root access on machines with\n" + + "'Settings > Privacy & Security > Advanced > Require" + + " an administrator password to access system-wide settings' enabled.\n" + + "If you do not want SpoofDPI to act as a system-wide proxy, provide" + + " -system-proxy=false." +) + +var pacListener net.Listener + +func SetSystemProxy(logger zerolog.Logger, port uint16) error { + network, err := getDefaultNetwork() + if err != nil { + return err + } + + portStr := strconv.Itoa(int(port)) + pacContent := fmt.Sprintf(`function FindProxyForURL(url, host) { + return "SOCKS5 127.0.0.1:%s; DIRECT"; +}`, portStr) + + pacURL, l, err := netutil.RunPACServer(pacContent) + if err != nil { + return fmt.Errorf("error creating pac server: %w", err) + } + pacListener = l + + // Enable Auto Proxy Configuration + // networksetup -setautoproxyurl + if err := networkSetup("-setautoproxyurl", network, pacURL); err != nil { + return fmt.Errorf("setting autoproxyurl: %w", err) + } + + // networksetup -setproxyautodiscovery + if err := networkSetup("-setproxyautodiscovery", network, "on"); err != nil { + return fmt.Errorf("setting proxyautodiscovery: %w", err) + } + + return nil +} + +func UnsetSystemProxy(logger zerolog.Logger) error { + if pacListener != nil { + _ = pacListener.Close() + pacListener = nil + } + + network, err := getDefaultNetwork() + if err != nil { + return err + } + + // Disable Auto Proxy Configuration + if err := networkSetup("-setautoproxystate", network, "off"); err != nil { + return fmt.Errorf("unsetting autoproxystate: %w", err) + } + + if err := networkSetup("-setproxyautodiscovery", network, "off"); err != nil { + return fmt.Errorf("unsetting proxyautodiscovery: %w", err) + } + + return nil +} + +func getDefaultNetwork() (string, error) { + const cmd = "networksetup -listnetworkserviceorder | grep" + + " `(route -n get default | grep 'interface' || route -n get -inet6 default | grep 'interface') | cut -d ':' -f2`" + + " -B 1 | head -n 1 | cut -d ' ' -f 2-" + + out, err := exec.Command("sh", "-c", cmd).Output() + if err != nil { + return "", err + } + + network := strings.TrimSpace(string(out)) + if network == "" { + return "", errors.New("no available networks") + } + return network, nil +} + +func networkSetup(args ...string) error { + cmd := exec.Command("networksetup", args...) + out, err := cmd.CombinedOutput() + if err != nil { + msg := string(out) + if isPermissionError(err) { + msg += permissionErrorHelpText + } + return fmt.Errorf("%s", msg) + } + return nil +} + +func isPermissionError(err error) bool { + var exitErr *exec.ExitError + ok := errors.As(err, &exitErr) + return ok && exitErr.ExitCode() == 14 +} diff --git a/internal/server/socks5/network_linux.go b/internal/server/socks5/network_linux.go new file mode 100644 index 00000000..1350894c --- /dev/null +++ b/internal/server/socks5/network_linux.go @@ -0,0 +1,14 @@ +//go:build linux + +package socks5 + +import "github.com/rs/zerolog" + +func SetSystemProxy(logger zerolog.Logger, port uint16) error { + // Not implemented for Linux yet + return nil +} + +func UnsetSystemProxy(logger zerolog.Logger) error { + return nil +} diff --git a/internal/server/socks5/server.go b/internal/server/socks5/server.go new file mode 100644 index 00000000..de4c6e2e --- /dev/null +++ b/internal/server/socks5/server.go @@ -0,0 +1,296 @@ +package socks5 + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + + "github.com/rs/zerolog" + "github.com/samber/lo" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/dns" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/matcher" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/proto" + "github.com/xvzc/SpoofDPI/internal/server" + "github.com/xvzc/SpoofDPI/internal/session" +) + +type SOCKS5Proxy struct { + logger zerolog.Logger + + resolver dns.Resolver + ruleMatcher matcher.RuleMatcher + connectHandler *ConnectHandler + bindHandler *BindHandler + udpAssociateHandler *UdpAssociateHandler + + serverOpts *config.ServerOptions + policyOpts *config.PolicyOptions + + listener net.Listener +} + +func NewSOCKS5Proxy( + logger zerolog.Logger, + resolver dns.Resolver, + ruleMatcher matcher.RuleMatcher, + connectHandler *ConnectHandler, + bindHandler *BindHandler, + udpAssociateHandler *UdpAssociateHandler, + serverOpts *config.ServerOptions, + policyOpts *config.PolicyOptions, +) server.Server { + return &SOCKS5Proxy{ + logger: logger, + resolver: resolver, + ruleMatcher: ruleMatcher, + connectHandler: connectHandler, + bindHandler: bindHandler, + udpAssociateHandler: udpAssociateHandler, + serverOpts: serverOpts, + policyOpts: policyOpts, + } +} + +func (p *SOCKS5Proxy) Start(ctx context.Context, ready chan<- struct{}) error { + listener, err := net.ListenTCP("tcp", p.serverOpts.ListenAddr) + if err != nil { + return fmt.Errorf( + "error creating listener on %s: %w", + p.serverOpts.ListenAddr.String(), + err, + ) + } + p.listener = listener + + if ready != nil { + close(ready) + } + + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + p.logger.Error(). + Err(err). + Msg("failed to accept new connection") + continue + } + + go p.handleConnection(session.WithNewTraceID(ctx), conn) + } +} + +func (p *SOCKS5Proxy) Stop() error { + if p.listener != nil { + return p.listener.Close() + } + return nil +} + +func (p *SOCKS5Proxy) SetNetworkConfig() error { + return SetSystemProxy(p.logger, uint16(p.serverOpts.ListenAddr.Port)) +} + +func (p *SOCKS5Proxy) UnsetNetworkConfig() error { + return UnsetSystemProxy(p.logger) +} + +func (p *SOCKS5Proxy) Addr() string { + return p.serverOpts.ListenAddr.String() +} + +func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { + logger := logging.WithLocalScope(ctx, p.logger, "socks5") + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + defer netutil.CloseConns(conn) + + // 1. Negotiation Phase + if err := p.negotiate(logger, conn); err != nil { + logger.Debug().Err(err).Msg("negotiation failed") + return + } + + // 2. Request Phase + req, err := proto.ReadSocks5Request(conn) + if err != nil { + if err != io.EOF { + logger.Warn().Err(err).Msg("failed to read request") + } + return + } + + // ctx = session.WithHostInfo(ctx, req.Host()) + // logger = logger.With().Ctx(ctx).Logger() + + logger.Trace(). + Uint8("cmd", req.Cmd). + Int("port", req.Port). + Str("fqdn", req.FQDN). + Str("ip", req.IP.String()). + Msg("new request") + + var addrs []net.IP + var nameMatch *config.Rule + + if req.IP != nil { + addrs = []net.IP{req.IP} + } else if req.ATYP == proto.SOCKS5AddrTypeFQDN && len(req.FQDN) > 1 { + nameMatch = p.ruleMatcher.Search( + &matcher.Selector{ + Kind: matcher.MatchKindDomain, + Domain: lo.ToPtr(req.FQDN), // req.Domain -> req.FQDN + }, + ) + + // Resolve Domain + rSet, err := p.resolver.Resolve(ctx, req.FQDN, nil, nameMatch) + if err != nil { + logger.Error().Str("domain", req.FQDN).Err(err).Msgf("dns lookup failed") + return + } + + addrs = rSet.Addrs + } else { + logger.Trace().Msg("no addrs specified for this request. skipping") + } + + var selectors []*matcher.Selector + for _, v := range addrs { + selectors = append(selectors, &matcher.Selector{ + Kind: matcher.MatchKindAddr, + IP: lo.ToPtr(v), + Port: lo.ToPtr(uint16(req.Port)), + }) + } + + addrMatch := p.ruleMatcher.SearchAll(selectors) + + bestMatch := matcher.GetHigherPriorityRule(addrMatch, nameMatch) + if bestMatch != nil && logger.GetLevel() == zerolog.TraceLevel { + logger.Trace().RawJSON("summary", bestMatch.JSON()).Msg("match") + } + + switch req.Cmd { + case proto.SOCKS5CmdConnect: + dst := &netutil.Destination{ + Domain: req.FQDN, + Addrs: addrs, + Port: req.Port, + Timeout: *p.serverOpts.Timeout, + } + if err = p.connectHandler.Handle(ctx, conn, req, dst, bestMatch); err != nil { + return // Handler logs error + } + + case proto.SOCKS5CmdBind: + // Bind command usually implies user wants the server to listen. + // Destination address in request is usually zero or the IP of the client, + // but SOCKS5 spec says "DST.ADDR and DST.PORT fields of the BIND request contains + // the address and port of the party the client expects to connect to the application server." + // For our basic BindHandler, we might not strictly validate this yet. + if err = p.bindHandler.Handle(ctx, conn, req); err != nil { + return + } + + case proto.SOCKS5CmdUDPAssociate: + // UDP Associate usually doesn't have destination info in the request + if err = p.udpAssociateHandler.Handle(ctx, conn, req, nil, nil); err != nil { + logger.Error().Err(err).Msg("failed to handle udp_associate") + return + } + default: + err = proto.SOCKS5CommandNotSupportedResponse().Write(conn) + logger.Warn().Uint8("cmd", req.Cmd).Msg("unsupported command") + } + + if err == nil { + return + } + + logger.Error().Err(err).Msg("failed to handle") + if errors.Is(err, netutil.ErrBlocked) { + p.handleAutoConfig(ctx, req, addrs, bestMatch) + } +} + +func (p *SOCKS5Proxy) handleAutoConfig( + ctx context.Context, + req *proto.SOCKS5Request, + addrs []net.IP, + matchedRule *config.Rule, +) { + logger := zerolog.Ctx(ctx) + + if matchedRule != nil { + logger.Trace().Msg("skipping auto-policy for this request (duplicate policy)") + return + } + + if *p.policyOpts.Auto && p.policyOpts.Template != nil { + newRule := p.policyOpts.Template.Clone() + targetDomain := req.FQDN // req.Domain -> req.FQDN + if targetDomain == "" && len(addrs) > 0 { + targetDomain = addrs[0].String() + } + + newRule.Match = &config.MatchAttrs{Domains: []string{targetDomain}} + + if err := p.ruleMatcher.Add(newRule); err != nil { + logger.Info().Err(err).Msg("failed to add config automatically") + } else { + logger.Info().Msg("automatically added to config") + } + } +} + +func (p *SOCKS5Proxy) negotiate(logger zerolog.Logger, conn net.Conn) error { + header := make([]byte, 2) + if _, err := io.ReadFull(conn, header); err != nil { + return err + } + + if header[0] != proto.SOCKSVersion { + // Check if the first byte is 'C'(67), and the second byte is 'O'(79) + // indicating a potential HTTP CONNECT method + if len(header) > 1 && header[0] == 67 && header[1] == 79 { + // Reconstruct the stream using the already read header and the remaining connection + // This allows http.ReadRequest to parse the full request line including the method + mr := io.MultiReader(bytes.NewReader(header), conn) + bufReader := bufio.NewReader(mr) + + // Parse the HTTP request headers without waiting for EOF + // ReadRequest reads only the header section and stops + req, err := http.ReadRequest(bufReader) + if err != nil { + return fmt.Errorf("invalid request(unknown): %w", err) + } + + // req.Host contains the target domain (e.g., "google.com:443") + return fmt.Errorf("invalid request: http connect to %s", req.Host) + } + + return fmt.Errorf("invalid version: %d", header[0]) + } + + nMethods := int(header[1]) + methods := make([]byte, nMethods) + if _, err := io.ReadFull(conn, methods); err != nil { + return err + } + + // Respond: Version 5, Method NoAuth(0) + _, err := conn.Write([]byte{proto.SOCKSVersion, proto.SOCKS5AuthNone}) + return err +} diff --git a/internal/server/socks5/udp_associate.go b/internal/server/socks5/udp_associate.go new file mode 100644 index 00000000..c16d23a7 --- /dev/null +++ b/internal/server/socks5/udp_associate.go @@ -0,0 +1,244 @@ +package socks5 + +import ( + "context" + "encoding/binary" + "fmt" + "io" + "net" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/proto" +) + +type UdpAssociateHandler struct { + logger zerolog.Logger + pool *netutil.ConnPool +} + +func NewUdpAssociateHandler( + logger zerolog.Logger, + pool *netutil.ConnPool, +) *UdpAssociateHandler { + return &UdpAssociateHandler{ + logger: logger, + pool: pool, + } +} + +func (h *UdpAssociateHandler) Handle( + ctx context.Context, + lConn net.Conn, + req *proto.SOCKS5Request, + dst *netutil.Destination, + rule *config.Rule, +) error { + logger := logging.WithLocalScope(ctx, h.logger, "udp_associate") + + // 1. Listen on a random UDP port + lAddrTCP := lConn.LocalAddr().(*net.TCPAddr) // SOCKS5 listens on TCP + lNewConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: lAddrTCP.IP, Port: 0}) + if err != nil { + logger.Error().Err(err).Msg("failed to create udp listener") + _ = proto.SOCKS5FailureResponse().Write(lConn) + return err + } + defer netutil.CloseConns(lNewConn) + + logger.Debug(). + Str("addr", lNewConn.LocalAddr().String()). + Str("network", lNewConn.LocalAddr().Network()). + Msg("new conn") + + lAddr := lNewConn.LocalAddr().(*net.UDPAddr) + + logger.Debug(). + Str("bind_addr", lAddr.String()). + Msg("socks5 udp associate established") + + // 2. Reply with the bound address + if err := proto.SOCKS5SuccessResponse().Bind(lAddr.IP).Port(lAddr.Port).Write(lConn); err != nil { + logger.Error().Err(err).Msg("failed to write socks5 success reply") + return err + } + + // 3. Keep TCP Alive & Relay + // According to [RFC1928](https://datatracker.ietf.org/doc/html/rfc1928#section-6), + // > A UDP association terminates when the TCP connection that the UDP + // > ASSOCIATE request arrived on terminates. + // Therefore, we need to monitor TCP for closure. + done := make(chan struct{}) + go func() { + _, _ = io.Copy(io.Discard, lConn) // Block until TCP closes + close(done) + }() + + buf := make([]byte, 65535) + var clientAddr *net.UDPAddr + + for { + // Wait for data + n, addr, err := lNewConn.ReadFromUDP(buf) + if err != nil { + // Normal closure check + select { + case <-done: + return nil + default: + if err != io.EOF { + logger.Debug().Err(err).Msg("error reading from udp") + } + return err + } + } + + // Initial Client Identification + if clientAddr == nil { + clientAddr = addr + } + + // Only accept packets from the client that established the association + if !addr.IP.Equal(clientAddr.IP) || addr.Port != clientAddr.Port { + continue + } + + // Outbound: Client -> Proxy -> Target + targetAddrStr, payload, err := parseUDPHeader(buf[:n]) + if err != nil { + logger.Warn().Err(err).Msg("failed to parse socks5 udp header") + continue + } + + // Key: Client Addr -> Target Addr + key := clientAddr.String() + ">" + targetAddrStr + + // Resolve address to construct Destination + uAddr, err := net.ResolveUDPAddr("udp", targetAddrStr) + if err != nil { + logger.Warn(). + Err(err). + Str("addr", targetAddrStr). + Msg("failed to resolve udp target") + continue + } + + dst := &netutil.Destination{ + Addrs: []net.IP{uAddr.IP}, + Port: uAddr.Port, + } + + rawConn, err := netutil.DialFastest(ctx, "udp", dst) + if err != nil { + logger.Warn().Err(err).Str("addr", targetAddrStr).Msg("failed to dial udp target") + continue + } + + // Add to pool (pool handles LRU eviction and deadline) + conn := h.pool.Add(key, rawConn) + + // Start a goroutine to read from the target and forward to the client + go func(targetConn *netutil.PooledConn, clientAddr *net.UDPAddr) { + respBuf := make([]byte, 65535) + for { + n, _, err := targetConn.Conn.(*net.UDPConn).ReadFromUDP(respBuf) + if err != nil { + // Connection closed or network issues + return + } + + // Inbound: Target -> Proxy -> Client + // Wrap with SOCKS5 Header + remoteAddr := targetConn.Conn.(*net.UDPConn).RemoteAddr().(*net.UDPAddr) + header := createUDPHeaderFromAddr(remoteAddr) + response := append(header, respBuf[:n]...) + + if _, err := lNewConn.WriteToUDP(response, clientAddr); err != nil { + // If we can't write back to the client, it might be gone or network issue. + // Exit this goroutine to avoid busy looping. + logger.Warn().Err(err).Msg("failed to write udp to client") + return + } + } + }(conn, clientAddr) + + // Write payload to target + if _, err := conn.Write(payload); err != nil { + logger.Warn().Err(err).Msg("failed to write udp to target") + } + } +} + +func parseUDPHeader(b []byte) (string, []byte, error) { + if len(b) < 4 { + return "", nil, fmt.Errorf("header too short") + } + // RSV(2) FRAG(1) ATYP(1) + if b[0] != 0 || b[1] != 0 { + return "", nil, fmt.Errorf("invalid rsv") + } + frag := b[2] + if frag != 0 { + return "", nil, fmt.Errorf("fragmentation not supported") + } + + atyp := b[3] + var host string + var pos int + + switch atyp { + case proto.SOCKS5AddrTypeIPv4: + if len(b) < 10 { + return "", nil, fmt.Errorf("header too short for ipv4") + } + host = net.IP(b[4:8]).String() + pos = 8 + case proto.SOCKS5AddrTypeIPv6: + if len(b) < 22 { + return "", nil, fmt.Errorf("header too short for ipv6") + } + host = net.IP(b[4:20]).String() + pos = 20 + case proto.SOCKS5AddrTypeFQDN: + if len(b) < 5 { + return "", nil, fmt.Errorf("header too short for fqdn") + } + l := int(b[4]) + if len(b) < 5+l+2 { + return "", nil, fmt.Errorf("header too short for fqdn data") + } + host = string(b[5 : 5+l]) + pos = 5 + l + default: + return "", nil, fmt.Errorf("unsupported atyp: %d", atyp) + } + + port := binary.BigEndian.Uint16(b[pos : pos+2]) + pos += 2 + + addr := net.JoinHostPort(host, fmt.Sprintf("%d", port)) + return addr, b[pos:], nil +} + +func createUDPHeaderFromAddr(addr *net.UDPAddr) []byte { + // RSV(2) FRAG(1) ATYP(1) ... + buf := make([]byte, 0, 24) + buf = append(buf, 0, 0, 0) // RSV, FRAG + + ip4 := addr.IP.To4() + if ip4 != nil { + buf = append(buf, proto.SOCKS5AddrTypeIPv4) + buf = append(buf, ip4...) + } else { + buf = append(buf, proto.SOCKS5AddrTypeIPv6) + buf = append(buf, addr.IP.To16()...) + } + + portBuf := make([]byte, 2) + binary.BigEndian.PutUint16(portBuf, uint16(addr.Port)) + buf = append(buf, portBuf...) + + return buf +} diff --git a/internal/server/tun/network.go b/internal/server/tun/network.go new file mode 100644 index 00000000..da6e6488 --- /dev/null +++ b/internal/server/tun/network.go @@ -0,0 +1,15 @@ +//go:build !darwin + +package tun + +func SetRouting(iface string, subnets []string) error { + return nil +} + +func UnsetRouting(iface string, subnets []string) error { + return nil +} + +func SetInterfaceAddress(iface string, local string, remote string) error { + return nil +} diff --git a/internal/server/tun/network_darwin.go b/internal/server/tun/network_darwin.go new file mode 100644 index 00000000..41e2a35a --- /dev/null +++ b/internal/server/tun/network_darwin.go @@ -0,0 +1,117 @@ +package tun + +import ( + "fmt" + "os/exec" + "strings" +) + +// SetRoute configures network routes for specified subnets +func SetRoute(iface string, subnets []string) error { + for _, subnet := range subnets { + targets := []string{subnet} + + /* Expand default route into two /1 subnets to override the default gateway + without removing the existing 0.0.0.0/0 entry. + */ + if subnet == "0.0.0.0/0" { + targets = []string{"0.0.0.0/1", "128.0.0.0/1"} + } + + for _, target := range targets { + cmd := exec.Command("route", "-n", "add", "-net", target, "-interface", iface) + out, err := cmd.CombinedOutput() + if err != nil { + // Check if it's a permission error + if strings.Contains(string(out), "must be root") { + return fmt.Errorf( + "permission denied: must run as root to modify routing table (sudo required)", + ) + } + return fmt.Errorf( + "failed to add route for %s on %s: %s: %w", + target, + iface, + string(out), + err, + ) + } + } + } + return nil +} + +// UnsetRoute removes previously configured network routes +func UnsetRoute(iface string, subnets []string) error { + for _, subnet := range subnets { + targets := []string{subnet} + + if subnet == "0.0.0.0/0" { + targets = []string{"0.0.0.0/1", "128.0.0.0/1"} + } + + for _, target := range targets { + /* Delete specific routes to revert traffic flow to the original gateway. */ + cmd := exec.Command("route", "-n", "delete", "-net", target, "-interface", iface) + if out, err := cmd.CombinedOutput(); err != nil { + _ = out + } + } + } + return nil +} + +// SetInterfaceAddress configures the TUN interface with local and remote endpoints +func SetInterfaceAddress(iface string, local string, remote string) error { + cmd := exec.Command("ifconfig", iface, local, remote, "up") + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to set interface address: %s: %w", string(out), err) + } + return nil +} + +// UnsetGatewayRoute removes the gateway route +func UnsetGatewayRoute(gateway, iface string) error { + // Remove the direct route to the gateway + cmd := exec.Command("route", "-n", "delete", "-host", gateway, "-interface", iface) + if out, err := cmd.CombinedOutput(); err != nil { + _ = out + } + + // Also try to remove the 0.0.0.0/2 route if it exists (cleanup from previous versions) + cmd = exec.Command("route", "-n", "delete", "-net", "0.0.0.0/32", gateway) + if out, err := cmd.CombinedOutput(); err != nil { + _ = out + } + + return nil +} + +// SetGatewayRoute adds a host route to the gateway via the specified interface +// This ensures traffic destined for the gateway goes through the physical interface +func SetGatewayRoute(gateway, iface string) error { + // First, get the gateway's subnet to add a direct route + cmd := exec.Command("route", "-n", "add", "-host", gateway, "-interface", iface) + out, err := cmd.CombinedOutput() + if err != nil { + // Ignore "File exists" error - route already exists + if !strings.Contains(string(out), "File exists") { + return fmt.Errorf("failed to add gateway route: %s: %w", string(out), err) + } + } + + // Also add a less specific route that uses the gateway for 0.0.0.0/2 + // This provides a path for IP_BOUND_IF sockets on en0 to reach external hosts + // The 0/2 route is less specific than 0/1 but will be used when bound to en0 + cmd = exec.Command("route", "-n", "add", "-net", "0.0.0.0/32", gateway) + out, err = cmd.CombinedOutput() + if err != nil { + // Ignore "File exists" error + if !strings.Contains(string(out), "File exists") { + // This is optional, don't fail if it doesn't work + _ = out + } + } + + return nil +} diff --git a/internal/server/tun/server.go b/internal/server/tun/server.go new file mode 100644 index 00000000..8dceefdc --- /dev/null +++ b/internal/server/tun/server.go @@ -0,0 +1,372 @@ +package tun + +import ( + "context" + "errors" + "fmt" + "io" + "io/fs" + "net" + "os" + "strconv" + + "github.com/rs/zerolog" + "github.com/samber/lo" + "github.com/songgao/water" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/matcher" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/server" + "github.com/xvzc/SpoofDPI/internal/session" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +// Ensure tcpip is used to avoid "imported and not used" error +var _ tcpip.NetworkProtocolNumber = ipv4.ProtocolNumber + +type TunServer struct { + logger zerolog.Logger + config *config.Config + matcher matcher.RuleMatcher // For IP-based rule matching + + tcpHandler *TCPHandler + udpHandler *UDPHandler + + iface *water.Interface + defaultIface string + defaultGateway string +} + +func NewTunServer( + logger zerolog.Logger, + config *config.Config, + matcher matcher.RuleMatcher, + tcpHandler *TCPHandler, + udpHandler *UDPHandler, +) server.Server { + return &TunServer{ + logger: logger, + config: config, + matcher: matcher, + tcpHandler: tcpHandler, + udpHandler: udpHandler, + } +} + +func (s *TunServer) Start(ctx context.Context, ready chan<- struct{}) error { + iface, err := NewTunDevice() + if err != nil { + return fmt.Errorf("failed to create tun device: %w", err) + } + s.iface = iface + + if ready != nil { + close(ready) + } + + return s.handle(ctx, iface) +} + +func (s *TunServer) Stop() error { + if s.iface != nil { + return s.iface.Close() + } + return nil +} + +func (s *TunServer) SetNetworkConfig() error { + if s.iface == nil { + return fmt.Errorf("tun device not initialized") + } + + // Find default interface and gateway before modifying routes + defaultIface, defaultGateway, err := netutil.GetDefaultInterfaceAndGateway() + if err != nil { + return fmt.Errorf("failed to get default interface: %w", err) + } + s.logger.Info(). + Str("interface", defaultIface). + Str("gateway", defaultGateway). + Msg("determined default interface and gateway") + s.defaultIface = defaultIface + s.defaultGateway = defaultGateway + + // Update handlers with network info + s.tcpHandler.SetNetworkInfo(defaultIface, defaultGateway) + s.udpHandler.SetNetworkInfo(defaultIface, defaultGateway) + + local, remote, err := netutil.FindSafeSubnet() + if err != nil { + return fmt.Errorf("failed to find safe subnet: %w", err) + } + + if err := SetInterfaceAddress(s.iface.Name(), local, remote); err != nil { + return fmt.Errorf("failed to set interface address: %w", err) + } + + // Add route for the TUN interface subnet to ensure packets can return + // This is crucial for the TUN interface to receive packets destined for its own subnet + if err := SetRoute(s.iface.Name(), []string{local + "/30"}); err != nil { + return fmt.Errorf("failed to set local route: %w", err) + } + + // Add a host route to the gateway via the physical interface + // This ensures SpoofDPI's outbound traffic goes through en0, not utun8 + if err := SetGatewayRoute(defaultGateway, defaultIface); err != nil { + s.logger.Warn().Err(err).Msg("failed to set gateway route") + } + + return SetRoute(s.iface.Name(), []string{"0.0.0.0/0"}) // Default Route +} + +func (s *TunServer) UnsetNetworkConfig() error { + if s.iface == nil { + return nil + } + + // Remove the gateway route + if s.defaultGateway != "" && s.defaultIface != "" { + if err := UnsetGatewayRoute(s.defaultGateway, s.defaultIface); err != nil { + s.logger.Warn().Err(err).Msg("failed to unset gateway route") + } + } + + return UnsetRoute(s.iface.Name(), []string{"0.0.0.0/0"}) // Default Route +} + +func (s *TunServer) Addr() string { + if s.iface != nil { + return s.iface.Name() + } + return "tun" +} + +// matchRuleByAddr extracts IP and port from net.Addr and performs rule matching +func (s *TunServer) matchRuleByAddr(addr net.Addr) *config.Rule { + if s.matcher == nil { + return nil + } + + host, portStr, err := net.SplitHostPort(addr.String()) + if err != nil { + return nil + } + + ip := net.ParseIP(host) + if ip == nil { + return nil + } + + port, _ := strconv.Atoi(portStr) + + selector := &matcher.Selector{ + Kind: matcher.MatchKindAddr, + IP: lo.ToPtr(ip), + Port: lo.ToPtr(uint16(port)), + } + + return s.matcher.Search(selector) +} + +func (s *TunServer) handle(ctx context.Context, iface *water.Interface) error { + logger := logging.WithLocalScope(ctx, s.logger, "tun") + + // 1. Create gVisor stack + stk := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{ + tcp.NewProtocol, + udp.NewProtocol, + }, + }) + + // 2. Create channel endpoint + ep := channel.New(256, 1500, "") + + const nicID = 1 + if err := stk.CreateNIC(nicID, ep); err != nil { + return fmt.Errorf("failed to create NIC: %v", err) + } + + // 3. Enable Promiscuous mode & Spoofing + stk.SetPromiscuousMode(nicID, true) + stk.SetSpoofing(nicID, true) + + // 3.5. Add default route to the stack + // Define a subnet that matches all IPv4 addresses (0.0.0.0/0) + defaultSubnet, _ := tcpip.NewSubnet( + tcpip.AddrFrom4([4]byte{0, 0, 0, 0}), + tcpip.MaskFrom("\x00\x00\x00\x00"), + ) + + stk.SetRouteTable([]tcpip.Route{ + { + Destination: defaultSubnet, + NIC: nicID, + }, + }) + + // 4. Register TCP Forwarder + tcpFwd := tcp.NewForwarder(stk, 0, 65535, func(r *tcp.ForwarderRequest) { + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + logger.Error().Msgf("failed to create endpoint: %v", err) + r.Complete(true) + return + } + r.Complete(false) + + conn := gonet.NewTCPConn(&wq, ep) + + // Match rule by IP before passing to handler + rule := s.matchRuleByAddr(conn.LocalAddr()) + go s.tcpHandler.Handle(session.WithNewTraceID(context.Background()), conn, rule) + }) + stk.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket) + + // 5. Register UDP Forwarder + udpFwd := udp.NewForwarder(stk, func(r *udp.ForwarderRequest) bool { + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + logger.Error().Msgf("failed to create udp endpoint: %v", err) + return true + } + + conn := gonet.NewUDPConn(&wq, ep) + + // Match rule by IP before passing to handler + rule := s.matchRuleByAddr(conn.LocalAddr()) + go s.udpHandler.Handle(session.WithNewTraceID(context.Background()), conn, rule) + return true + }) + stk.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket) + + // 6. Start packet pump + go s.tunToStack(ctx, logger, iface, ep) + go s.stackToTun(ctx, logger, iface, ep) + + <-ctx.Done() + return nil +} + +func (s *TunServer) tunToStack( + ctx context.Context, + logger zerolog.Logger, + iface *water.Interface, + ep *channel.Endpoint, +) { + buf := make([]byte, 2000) + for { + n, err := iface.Read(buf) + if err != nil { + if errors.Is(err, fs.ErrClosed) || errors.Is(err, os.ErrClosed) { + return + } + + select { + case <-ctx.Done(): + return + default: + if err != io.EOF { + logger.Error().Err(err).Msg("failed to read from tun") + } + return + } + } + + if n < 1 { + continue + } + + version := (buf[0] >> 4) + if version != 4 { + logger.Trace().Int("version", int(version)).Msg("skipping non-ipv4 packet") + continue + } + + // Parse source and destination IP for debugging + // if n >= 20 { + // srcIP := net.IP(buf[12:16]) + // dstIP := net.IP(buf[16:20]) + // protocol := buf[9] + // logger.Trace(). + // Str("src", srcIP.String()). + // Str("dst", dstIP.String()). + // Uint8("proto", protocol). + // Int("len", n). + // Msg("injecting packet to stack") + // } + + payload := buffer.MakeWithData(append([]byte(nil), buf[:n]...)) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: payload, + }) + ep.InjectInbound(ipv4.ProtocolNumber, pkt) + pkt.DecRef() + } +} + +type notifier struct { + ch chan<- struct{} +} + +func (n *notifier) WriteNotify() { + select { + case n.ch <- struct{}{}: + default: + } +} + +func (s *TunServer) stackToTun( + ctx context.Context, + logger zerolog.Logger, + iface *water.Interface, + ep *channel.Endpoint, +) { + ch := make(chan struct{}, 1) + n := ¬ifier{ch: ch} + ep.AddNotify(n) + + for { + select { + case <-ctx.Done(): + return + default: + } + + pkt := ep.Read() + if pkt == nil { + select { + case <-ch: + continue + case <-ctx.Done(): + return + } + } + + views := pkt.ToView().AsSlice() + if len(views) > 0 { + _, _ = iface.Write(views) + } + pkt.DecRef() + } +} + +func NewTunDevice() (*water.Interface, error) { + config := water.Config{ + DeviceType: water.TUN, + } + return water.New(config) +} diff --git a/internal/server/tun/tcp.go b/internal/server/tun/tcp.go new file mode 100644 index 00000000..00f8cb06 --- /dev/null +++ b/internal/server/tun/tcp.go @@ -0,0 +1,225 @@ +package tun + +import ( + "context" + "fmt" + "net" + "strconv" + "time" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/desync" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/matcher" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/packet" + "github.com/xvzc/SpoofDPI/internal/proto" +) + +type TCPHandler struct { + logger zerolog.Logger + domainMatcher matcher.RuleMatcher // For TLS domain matching only + httpsOpts *config.HTTPSOptions + desyncer *desync.TLSDesyncer + sniffer packet.Sniffer // For TTL tracking + iface string + gateway string +} + +func NewTCPHandler( + logger zerolog.Logger, + domainMatcher matcher.RuleMatcher, + httpsOpts *config.HTTPSOptions, + desyncer *desync.TLSDesyncer, + sniffer packet.Sniffer, + iface string, + gateway string, +) *TCPHandler { + return &TCPHandler{ + logger: logger, + domainMatcher: domainMatcher, + httpsOpts: httpsOpts, + desyncer: desyncer, + sniffer: sniffer, + iface: iface, + gateway: gateway, + } +} + +func (h *TCPHandler) Handle(ctx context.Context, lConn net.Conn, rule *config.Rule) { + logger := logging.WithLocalScope(ctx, h.logger, "tcp") + + defer netutil.CloseConns(lConn) + + // Set a read deadline for the first byte to avoid hanging indefinitely + _ = lConn.SetReadDeadline(time.Now().Add(1 * time.Second)) + + lBufferedConn := netutil.NewBufferedConn(lConn) + buf, err := lBufferedConn.Peek(1) + if err != nil { + return + } + + // Reset deadline + _ = lConn.SetReadDeadline(time.Time{}) + + // Parse destination from local address (which is the original destination in TUN) + host, portStr, err := net.SplitHostPort(lConn.LocalAddr().String()) + if err != nil { + return + } + port, _ := strconv.Atoi(portStr) + + ip := net.ParseIP(host) + var iface *net.Interface + if h.iface != "" { + iface, _ = net.InterfaceByName(h.iface) + logger.Debug().Str("iface", h.iface).Msg("using interface for dial") + } else { + logger.Debug().Msg("no interface specified for dial") + } + + dst := &netutil.Destination{ + Domain: host, + Port: port, + Addrs: []net.IP{}, + Iface: iface, + Gateway: h.gateway, + } + if ip != nil { + dst.Addrs = append(dst.Addrs, ip) + } + + // Check if it's a TLS Handshake (Content Type 0x16) + if buf[0] == 0x16 { + logger.Debug().Msg("detected tls handshake") + if err := h.handleTLS(ctx, logger, lBufferedConn, dst, rule); err != nil { + logger.Debug().Err(err).Msg("tls handler failed") + } + return + } + + // Handle as plain TCP + rConn, err := netutil.DialFastest(ctx, "tcp", dst) + if err != nil { + return + } + + logger.Debug().Msgf("new remote conn -> %s", rConn.RemoteAddr()) + + resCh := make(chan netutil.TransferResult, 2) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + startedAt := time.Now() + go netutil.TunnelConns(ctx, resCh, lBufferedConn, rConn, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, rConn, lBufferedConn, netutil.TunnelDirIn) + + err = netutil.WaitAndLogTunnel( + ctx, + logger, + resCh, + startedAt, + netutil.DescribeRoute(lConn, rConn), + nil, + ) + if err != nil { + logger.Error().Err(err).Msg("error handling request") + } +} + +func (h *TCPHandler) handleTLS( + ctx context.Context, + logger zerolog.Logger, + lConn net.Conn, + dst *netutil.Destination, + addrRule *config.Rule, // Rule matched by IP in server.go +) error { + // Read ClientHello + tlsMsg, err := proto.ReadTLSMessage(lConn) + if err != nil { + return err + } + + if !tlsMsg.IsClientHello() { + return fmt.Errorf("not a client hello") + } + + // Extract SNI + start, end, err := tlsMsg.ExtractSNIOffset() + if err != nil { + return fmt.Errorf("failed to extract sni: %w", err) + } + dst.Domain = string(tlsMsg.Raw()[start:end]) + + logger.Trace().Str("value", dst.Domain).Msgf("extracted sni feild") + + // Match Rules + opts := h.httpsOpts.Clone() + + // First, apply IP-based rule if matched in server.go + if addrRule != nil { + logger.Trace().RawJSON("summary", addrRule.JSON()).Msg("addr match") + opts = opts.Merge(addrRule.HTTPS) + } + + // Then, try domain-based matching (TLS-specific) + if h.domainMatcher != nil { + domainSelector := &matcher.Selector{ + Kind: matcher.MatchKindDomain, + Domain: &dst.Domain, + } + if domainRule := h.domainMatcher.Search(domainSelector); domainRule != nil { + logger.Trace().RawJSON("summary", domainRule.JSON()).Msg("domain match") + // Domain rule takes priority if it has higher priority + finalRule := matcher.GetHigherPriorityRule(addrRule, domainRule) + if finalRule == domainRule { + opts = h.httpsOpts.Clone().Merge(domainRule.HTTPS) + } + } + } + + // Dial Remote + if h.sniffer != nil { + h.sniffer.RegisterUntracked(dst.Addrs) + } + rConn, err := netutil.DialFastest(ctx, "tcp", dst) + if err != nil { + return err + } + defer netutil.CloseConns(rConn) + + logger.Debug(). + Msgf("new remote conn (%s -> %s)", lConn.RemoteAddr(), rConn.RemoteAddr()) + + // Send ClientHello with Desync + if _, err := h.desyncer.Desync(ctx, logger, rConn, tlsMsg, opts); err != nil { + return err + } + + // Tunnel rest + resCh := make(chan netutil.TransferResult, 2) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + startedAt := time.Now() + go netutil.TunnelConns(ctx, resCh, lConn, rConn, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, rConn, lConn, netutil.TunnelDirIn) + + return netutil.WaitAndLogTunnel( + ctx, + logger, + resCh, + startedAt, + netutil.DescribeRoute(lConn, rConn), + nil, + ) +} + +func (h *TCPHandler) SetNetworkInfo(iface, gateway string) { + h.iface = iface + h.gateway = gateway +} diff --git a/internal/server/tun/udp.go b/internal/server/tun/udp.go new file mode 100644 index 00000000..0a54def3 --- /dev/null +++ b/internal/server/tun/udp.go @@ -0,0 +1,119 @@ +package tun + +import ( + "context" + "net" + "strconv" + "time" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/desync" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/netutil" +) + +type UDPHandler struct { + logger zerolog.Logger + udpOpts *config.UDPOptions + desyncer *desync.UDPDesyncer + pool *netutil.ConnPool + iface string + gateway string +} + +func NewUDPHandler( + logger zerolog.Logger, + desyncer *desync.UDPDesyncer, + udpOpts *config.UDPOptions, + pool *netutil.ConnPool, +) *UDPHandler { + return &UDPHandler{ + logger: logger, + desyncer: desyncer, + udpOpts: udpOpts, + pool: pool, + } +} + +func (h *UDPHandler) SetNetworkInfo(iface, gateway string) { + h.iface = iface + h.gateway = gateway +} + +func (h *UDPHandler) Handle(ctx context.Context, lConn net.Conn, rule *config.Rule) { + logger := logging.WithLocalScope(ctx, h.logger, "udp") + + defer netutil.CloseConns(lConn) + + host, portStr, err := net.SplitHostPort(lConn.LocalAddr().String()) + if err != nil { + return + } + port, _ := strconv.Atoi(portStr) + + var iface *net.Interface + if h.iface != "" { + iface, _ = net.InterfaceByName(h.iface) + } + + dst := &netutil.Destination{ + Domain: host, + Port: port, + Iface: iface, + Gateway: h.gateway, + } + if ip := net.ParseIP(host); ip != nil { + dst.Addrs = []net.IP{ip} + } + + // Apply rule if matched in server.go + opts := h.udpOpts.Clone() + if rule != nil { + logger.Trace().RawJSON("summary", rule.JSON()).Msg("match") + opts = opts.Merge(rule.UDP) + } + + // Key for connection pool + key := lConn.RemoteAddr().String() + ">" + lConn.LocalAddr().String() + + // Dial remote connection + rawConn, err := netutil.DialFastest(ctx, "udp", dst) + if err != nil { + return + } + + // Add to pool (pool handles LRU eviction and deadline) + rConn := h.pool.Add(key, rawConn) + + // Wrap lConn with TimeoutConn as well + timeout := *opts.Timeout + lConnWrapped := &netutil.TimeoutConn{Conn: lConn, Timeout: timeout} + + // Desync + _, _ = h.desyncer.Desync(ctx, lConnWrapped, rConn, opts) + + logger.Debug(). + Msgf("new remote conn (%s -> %s)", lConn.RemoteAddr(), rConn.RemoteAddr()) + + resCh := make(chan netutil.TransferResult, 2) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + startedAt := time.Now() + go netutil.TunnelConns(ctx, resCh, lConnWrapped, rConn, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, rConn, lConnWrapped, netutil.TunnelDirIn) + + err = netutil.WaitAndLogTunnel( + ctx, + logger, + resCh, + startedAt, + netutil.DescribeRoute(lConnWrapped, rConn), + nil, + ) + if err != nil { + logger.Error().Err(err).Msg("error handling request") + } +} diff --git a/internal/session/session.go b/internal/session/session.go index 1d9e3f06..280fcb28 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -34,7 +34,8 @@ func TraceIDFrom(ctx context.Context) (string, bool) { if ok { return traceID, true } - return "", false + + return "0000000000000000", false } // WithHostInfo returns a new context carrying the given domain name string. diff --git a/internal/system/sysproxy.go b/internal/system/sysproxy.go deleted file mode 100644 index 1419fdab..00000000 --- a/internal/system/sysproxy.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build !darwin && !linux - -package system - -import "github.com/rs/zerolog" - -func SetProxy(logger zerolog.Logger, port uint16) error { - return nil -} - -func UnsetProxy(logger zerolog.Logger) error { - return nil -} diff --git a/internal/system/sysproxy_darwin.go b/internal/system/sysproxy_darwin.go deleted file mode 100644 index 77973e1d..00000000 --- a/internal/system/sysproxy_darwin.go +++ /dev/null @@ -1,100 +0,0 @@ -//go:build darwin - -package system - -import ( - "errors" - "fmt" - "os/exec" - "strconv" - "strings" - - "github.com/rs/zerolog" -) - -const ( - getDefaultNetworkCMD = "networksetup -listnetworkserviceorder | grep" + - " `(route -n get default | grep 'interface' || route -n get -inet6 default | grep 'interface') | cut -d ':' -f2`" + - " -B 1 | head -n 1 | cut -d ' ' -f 2-" - permissionErrorHelpText = "By default SpoofDPI tries to set itself up as a system-wide proxy server.\n" + - "Doing so may require root access on machines with\n" + - "'Settings > Privacy & Security > Advanced > Require" + - " an administrator password to access system-wide settings' enabled.\n" + - "If you do not want SpoofDPI to act as a system-wide proxy, provide" + - " -system-proxy=false." -) - -func SetProxy(logger zerolog.Logger, port uint16) error { - network, err := getDefaultNetwork() - if err != nil { - return err - } - - return setProxyInternal(getProxyTypes(), network, "127.0.0.1", int(port)) -} - -func UnsetProxy(logger zerolog.Logger) error { - network, err := getDefaultNetwork() - if err != nil { - return err - } - - return unsetProxyInternal(getProxyTypes(), network) -} - -func getDefaultNetwork() (string, error) { - network, err := exec.Command("sh", "-c", getDefaultNetworkCMD).Output() - if err != nil { - return "", err - } else if len(network) == 0 { - return "", errors.New("no available networks") - } - return strings.TrimSpace(string(network)), nil -} - -func getProxyTypes() []string { - return []string{"webproxy", "securewebproxy"} -} - -func setProxyInternal(proxyTypes []string, network, domain string, port int) error { - args := []string{"", network, domain, strconv.FormatUint(uint64(port), 10)} - - for _, proxyType := range proxyTypes { - args[0] = "-set" + proxyType - if err := networkSetup(args); err != nil { - return fmt.Errorf("setting %s: %w", proxyType, err) - } - } - return nil -} - -func unsetProxyInternal(proxyTypes []string, network string) error { - args := []string{"", network, "off"} - - for _, proxyType := range proxyTypes { - args[0] = "-set" + proxyType + "state" - if err := networkSetup(args); err != nil { - return fmt.Errorf("unsetting %s: %w", proxyType, err) - } - } - return nil -} - -func networkSetup(args []string) error { - cmd := exec.Command("networksetup", args...) - out, err := cmd.CombinedOutput() - if err != nil { - msg := string(out) - if isPermissionError(err) { - msg += permissionErrorHelpText - } - return fmt.Errorf("%s", msg) - } - return nil -} - -func isPermissionError(err error) bool { - var exitErr *exec.ExitError - ok := errors.As(err, &exitErr) - return ok && exitErr.ExitCode() == 14 -} diff --git a/internal/system/sysproxy_linux.go b/internal/system/sysproxy_linux.go deleted file mode 100644 index 1916a7cd..00000000 --- a/internal/system/sysproxy_linux.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build linux - -package system - -import "github.com/rs/zerolog" - -func SetProxy(logger zerolog.Logger, port uint16) error { - logger.Info().Msgf("automatic system-wide proxy setup is not implemented on Linux") - return nil -} - -func UnsetProxy(logger zerolog.Logger) error { - return nil -} From d23b69a65368d4216a3379eaf23b7e63c64109bd Mon Sep 17 00:00:00 2001 From: xvzc Date: Fri, 16 Jan 2026 23:13:13 +0900 Subject: [PATCH 08/39] feat: support socks5 and tun for linux --- cmd/spoofdpi/main.go | 62 +++-- cmd/spoofdpi/main_test.go | 39 ++- internal/config/cli.go | 89 +++--- internal/config/cli_test.go | 77 +++--- internal/config/config.go | 67 ++--- internal/config/config_test.go | 26 +- internal/config/parse.go | 8 +- internal/config/toml_test.go | 26 +- internal/config/types.go | 167 ++++++----- internal/config/types_test.go | 147 +++++----- internal/config/validate.go | 2 +- internal/desync/tls.go | 2 +- internal/dns/https.go | 21 +- internal/dns/route.go | 28 +- internal/dns/system.go | 12 +- internal/dns/udp.go | 21 +- internal/netutil/addr.go | 21 -- internal/netutil/bind.go | 12 + .../netutil/{dial_darwin.go => bind_bsd.go} | 21 +- internal/netutil/bind_linux.go | 38 +++ internal/netutil/conn.go | 14 +- internal/netutil/dial.go | 8 +- internal/netutil/dial_other.go | 11 - internal/netutil/gateway.go | 8 + internal/netutil/gateway_bsd.go | 28 ++ internal/netutil/gateway_linux.go | 29 ++ internal/packet/tcp_sniffer.go | 31 ++- internal/server/http/https.go | 37 ++- internal/server/http/network.go | 6 +- internal/server/http/network_linux.go | 14 - internal/server/http/{proxy.go => server.go} | 21 +- internal/server/socks5/connect.go | 40 +-- internal/server/socks5/network.go | 6 +- internal/server/socks5/network_linux.go | 14 - internal/server/socks5/server.go | 19 +- internal/server/tun/network.go | 14 +- .../tun/{network_darwin.go => network_bsd.go} | 2 + internal/server/tun/network_linux.go | 260 ++++++++++++++++++ internal/server/tun/server.go | 5 +- internal/server/tun/tcp.go | 50 ++-- internal/server/tun/udp.go | 35 ++- 41 files changed, 1046 insertions(+), 492 deletions(-) create mode 100644 internal/netutil/bind.go rename internal/netutil/{dial_darwin.go => bind_bsd.go} (65%) create mode 100644 internal/netutil/bind_linux.go delete mode 100644 internal/netutil/dial_other.go create mode 100644 internal/netutil/gateway.go create mode 100644 internal/netutil/gateway_bsd.go create mode 100644 internal/netutil/gateway_linux.go delete mode 100644 internal/server/http/network_linux.go rename internal/server/http/{proxy.go => server.go} (91%) delete mode 100644 internal/server/socks5/network_linux.go rename internal/server/tun/{network_darwin.go => network_bsd.go} (99%) create mode 100644 internal/server/tun/network_linux.go diff --git a/cmd/spoofdpi/main.go b/cmd/spoofdpi/main.go index 84bd475d..5836009d 100644 --- a/cmd/spoofdpi/main.go +++ b/cmd/spoofdpi/main.go @@ -42,11 +42,11 @@ func main() { } func runApp(ctx context.Context, configDir string, cfg *config.Config) { - if !*cfg.General.Silent { + if !*cfg.App.Silent { printBanner() } - logging.SetGlobalLogger(ctx, *cfg.General.LogLevel) + logging.SetGlobalLogger(ctx, *cfg.App.LogLevel) logger := log.Logger.With().Ctx(ctx).Logger() logger.Info().Str("version", version).Msg("started spoofdpi") @@ -74,7 +74,7 @@ func runApp(ctx context.Context, configDir string, cfg *config.Config) { <-ready // System Proxy Config - if *cfg.General.SetNetworkConfig { + if *cfg.App.SetNetworkConfig { if err := srv.SetNetworkConfig(); err != nil { logger.Fatal().Err(err).Msg("failed to set system network config") } @@ -108,13 +108,23 @@ func runApp(ctx context.Context, configDir string, cfg *config.Config) { Bool("auto", *cfg.Policy.Auto). Msgf("policy") - if *cfg.Server.Timeout > 0 { + if *cfg.Conn.DNSTimeout > 0 { logger.Info(). - Str("value", fmt.Sprintf("%dms", cfg.Server.Timeout)). - Msgf("connection timeout") + Str("value", fmt.Sprintf("%dms", cfg.Conn.DNSTimeout.Milliseconds())). + Msgf("dns connection timeout") + } + if *cfg.Conn.TCPTimeout > 0 { + logger.Info(). + Str("value", fmt.Sprintf("%dms", cfg.Conn.TCPTimeout.Milliseconds())). + Msgf("tcp connection timeout") + } + if *cfg.Conn.UDPTimeout > 0 { + logger.Info(). + Str("value", fmt.Sprintf("%dms", cfg.Conn.UDPTimeout.Milliseconds())). + Msgf("udp connection timeout") } - logger.Info().Msgf("server-mode; %s", cfg.Server.Mode.String()) + logger.Info().Msgf("app-mode; %s", cfg.App.Mode.String()) logger.Info().Msgf("server started on %s", srv.Addr()) @@ -142,9 +152,17 @@ func runApp(ctx context.Context, configDir string, cfg *config.Config) { func createResolver(logger zerolog.Logger, cfg *config.Config) dns.Resolver { // create a TTL cache for storing DNS records. - udpResolver := dns.NewUDPResolver(logging.WithScope(logger, "dns"), cfg.DNS.Clone()) + udpResolver := dns.NewUDPResolver( + logging.WithScope(logger, "dns"), + cfg.DNS.Clone(), + cfg.Conn.Clone(), + ) - dohResolver := dns.NewHTTPSResolver(logging.WithScope(logger, "dns"), cfg.DNS.Clone()) + dohResolver := dns.NewHTTPSResolver( + logging.WithScope(logger, "dns"), + cfg.DNS.Clone(), + cfg.Conn.Clone(), + ) sysResolver := dns.NewSystemResolver( logging.WithScope(logger, "dns"), @@ -238,7 +256,7 @@ func createPacketObjects( logging.WithScope(logger, "pkt"), hopCache, tcpHandle, - uint8(*cfg.Server.DefaultTTL), + uint8(*cfg.Conn.DefaultFakeTTL), ) tcpSniffer.StartCapturing() @@ -254,7 +272,7 @@ func createPacketObjects( logging.WithScope(logger, "pkt"), hopCache, udpHandle, - uint8(*cfg.Server.DefaultTTL), + uint8(*cfg.Conn.DefaultFakeTTL), ) udpSniffer.StartCapturing() @@ -306,14 +324,15 @@ func createServer( tcpSniffer, ) - switch *cfg.Server.Mode { - case config.ServerModeHTTP: + switch *cfg.App.Mode { + case config.AppModeHTTP: httpHandler := http.NewHTTPHandler(logging.WithScope(logger, "hnd")) httpsHandler := http.NewHTTPSHandler( logging.WithScope(logger, "hnd"), desyncer, tcpSniffer, cfg.HTTPS.Clone(), + cfg.Conn.Clone(), ) return http.NewHTTPProxy( @@ -322,15 +341,17 @@ func createServer( httpHandler, httpsHandler, ruleMatcher, - cfg.Server.Clone(), + cfg.App.Clone(), + cfg.Conn.Clone(), cfg.Policy.Clone(), ), nil - case config.ServerModeSOCKS5: + case config.AppModeSOCKS5: connectHandler := socks5.NewConnectHandler( logging.WithScope(logger, "hnd"), desyncer, tcpSniffer, - cfg.Server.Clone(), + cfg.App.Clone(), + cfg.Conn.Clone(), cfg.HTTPS.Clone(), ) udpAssociateHandler := socks5.NewUdpAssociateHandler( @@ -346,14 +367,16 @@ func createServer( connectHandler, bindHandler, udpAssociateHandler, - cfg.Server.Clone(), + cfg.App.Clone(), + cfg.Conn.Clone(), cfg.Policy.Clone(), ), nil - case config.ServerModeTUN: + case config.AppModeTUN: tcpHandler := tun.NewTCPHandler( logging.WithScope(logger, "hnd"), ruleMatcher, // For domain-based TLS matching cfg.HTTPS.Clone(), + cfg.Conn.Clone(), desyncer, tcpSniffer, // For TTL tracking "", // iface and gateway will be set later @@ -370,6 +393,7 @@ func createServer( logging.WithScope(logger, "hnd"), udpDesyncer, cfg.UDP.Clone(), + cfg.Conn.Clone(), netutil.NewConnPool(4096, 60*time.Second), ) @@ -381,7 +405,7 @@ func createServer( udpHandler, ), nil default: - return nil, fmt.Errorf("unknown server mode: %s", *cfg.Server.Mode) + return nil, fmt.Errorf("unknown server mode: %s", *cfg.App.Mode) } } diff --git a/cmd/spoofdpi/main_test.go b/cmd/spoofdpi/main_test.go index 22101337..9334512e 100644 --- a/cmd/spoofdpi/main_test.go +++ b/cmd/spoofdpi/main_test.go @@ -22,6 +22,11 @@ func TestCreateResolver(t *testing.T) { QType: lo.ToPtr(config.DNSQueryIPv4), Cache: lo.ToPtr(true), } + cfg.Conn = &config.ConnOptions{ + DNSTimeout: lo.ToPtr(time.Duration(0)), + TCPTimeout: lo.ToPtr(time.Duration(0)), + UDPTimeout: lo.ToPtr(time.Duration(0)), + } logger := zerolog.Nop() resolver := createResolver(logger, cfg) @@ -30,15 +35,21 @@ func TestCreateResolver(t *testing.T) { } func TestCreateProxy_NoPcap(t *testing.T) { - // Setup configuration that doesn't require PCAP (root privileges) + // Setup configuration that dAppModeHTTP PCAP (root privileges) cfg := config.NewConfig() - // Server Config - cfg.Server = &config.ServerOptions{ - Mode: lo.ToPtr(config.ServerModeHTTP), + // App Config + cfg.App = &config.AppOptions{ + Mode: lo.ToPtr(config.AppModeHTTP), ListenAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, - DefaultTTL: lo.ToPtr(uint8(64)), - Timeout: lo.ToPtr(time.Duration(0)), + } + + // Conn Config + cfg.Conn = &config.ConnOptions{ + DefaultFakeTTL: lo.ToPtr(uint8(64)), + DNSTimeout: lo.ToPtr(time.Duration(0)), + TCPTimeout: lo.ToPtr(time.Duration(0)), + UDPTimeout: lo.ToPtr(time.Duration(0)), } // HTTPS Config (Ensure FakeCount is 0 to disable PCAP) @@ -76,12 +87,18 @@ func TestCreateProxy_NoPcap(t *testing.T) { func TestCreateProxy_WithPolicy(t *testing.T) { cfg := config.NewConfig() - // Server Config - cfg.Server = &config.ServerOptions{ - Mode: lo.ToPtr(config.ServerModeHTTP), + // App Config + cfg.App = &config.AppOptions{ + Mode: lo.ToPtr(config.AppModeHTTP), ListenAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, - DefaultTTL: lo.ToPtr(uint8(64)), - Timeout: lo.ToPtr(time.Duration(0)), + } + + // Conn Config + cfg.Conn = &config.ConnOptions{ + DefaultFakeTTL: lo.ToPtr(uint8(64)), + DNSTimeout: lo.ToPtr(time.Duration(0)), + TCPTimeout: lo.ToPtr(time.Duration(0)), + UDPTimeout: lo.ToPtr(time.Duration(0)), } // HTTPS Config diff --git a/internal/config/cli.go b/internal/config/cli.go index 08197aad..bf035404 100644 --- a/internal/config/cli.go +++ b/internal/config/cli.go @@ -36,6 +36,20 @@ func CreateCommand( return err }, Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "app-mode", + Usage: fmt.Sprintf(`<"http"|"socks5"|"tun"> + Specifies the proxy mode. (default: %q)`, + defaultCfg.App.Mode.String(), + ), + OnlyOnce: true, + Validator: checkAppMode, + Action: func(ctx context.Context, cmd *cli.Command, v string) error { + argsCfg.App.Mode = lo.ToPtr(MustParseServerModeType(v)) + return nil + }, + }, + &cli.BoolFlag{ Name: "clean", Usage: ` @@ -54,29 +68,15 @@ func CreateCommand( }, &cli.Int64Flag{ - Name: "default-ttl", + Name: "default-fake-ttl", Usage: fmt.Sprintf(` - Default TTL value for manipulated packets. (default: %v)`, - *defaultCfg.Server.DefaultTTL, + Default TTL value for fake packets. (default: %v)`, + *defaultCfg.Conn.DefaultFakeTTL, ), OnlyOnce: true, Validator: checkUint8NonZero, Action: func(ctx context.Context, cmd *cli.Command, v int64) error { - argsCfg.Server.DefaultTTL = lo.ToPtr(uint8(v)) - return nil - }, - }, - - &cli.StringFlag{ - Name: "server-mode", - Usage: fmt.Sprintf(`<"http"|"socks5"|"tun"> - Specifies the proxy mode. (default: %q)`, - defaultCfg.Server.Mode.String(), - ), - OnlyOnce: true, - Validator: checkServerMode, - Action: func(ctx context.Context, cmd *cli.Command, v string) error { - argsCfg.Server.Mode = lo.ToPtr(MustParseServerModeType(v)) + argsCfg.Conn.DefaultFakeTTL = lo.ToPtr(uint8(v)) return nil }, }, @@ -157,6 +157,25 @@ func CreateCommand( }, }, + &cli.Int64Flag{ + Name: "dns-timeout", + Usage: fmt.Sprintf(` + Timeout for dns connection in milliseconds. + No effect when the value is 0 (default: %v, max: %v)`, + defaultCfg.Conn.DNSTimeout.Milliseconds(), + math.MaxUint16, + ), + Value: 0, + OnlyOnce: true, + Validator: checkUint16, + Action: func(ctx context.Context, cmd *cli.Command, v int64) error { + argsCfg.Conn.DNSTimeout = lo.ToPtr( + time.Duration(v * int64(time.Millisecond)), + ) + return nil + }, + }, + &cli.Int64Flag{ Name: "https-fake-count", Usage: fmt.Sprintf(` @@ -280,14 +299,18 @@ func CreateCommand( &cli.Int64Flag{ Name: "udp-timeout", Usage: fmt.Sprintf(` - UDP session timeout in milliseconds. (default: %v)`, - *defaultCfg.UDP.Timeout, + Timeout for udp connection in milliseconds. + No effect when the value is 0 (default: %v, max: %v)`, + defaultCfg.Conn.UDPTimeout.Milliseconds(), + math.MaxUint16, ), Value: 0, OnlyOnce: true, Validator: checkUint16, Action: func(ctx context.Context, cmd *cli.Command, v int64) error { - argsCfg.UDP.Timeout = lo.ToPtr(time.Duration(v) * time.Millisecond) + argsCfg.Conn.UDPTimeout = lo.ToPtr( + time.Duration(v * int64(time.Millisecond)), + ) return nil }, }, @@ -302,7 +325,7 @@ func CreateCommand( if v == "" { return nil } - argsCfg.Server.ListenAddr = lo.ToPtr(MustParseTCPAddr(v)) + argsCfg.App.ListenAddr = lo.ToPtr(MustParseTCPAddr(v)) return nil }, }, @@ -314,7 +337,7 @@ func CreateCommand( OnlyOnce: true, Validator: checkLogLevel, Action: func(ctx context.Context, cmd *cli.Command, v string) error { - argsCfg.General.LogLevel = lo.ToPtr(MustParseLogLevel(v)) + argsCfg.App.LogLevel = lo.ToPtr(MustParseLogLevel(v)) return nil }, }, @@ -336,11 +359,11 @@ func CreateCommand( Name: "silent", Usage: fmt.Sprintf(` Do not show the banner at start up (default: %v)`, - *defaultCfg.General.Silent, + *defaultCfg.App.Silent, ), OnlyOnce: true, Action: func(ctx context.Context, cmd *cli.Command, v bool) error { - argsCfg.General.Silent = lo.ToPtr(v) + argsCfg.App.Silent = lo.ToPtr(v) return nil }, }, @@ -349,28 +372,28 @@ func CreateCommand( Name: "network-config", Usage: fmt.Sprintf(` Automatically set system-wide proxy configuration (default: %v)`, - *defaultCfg.General.SetNetworkConfig, + *defaultCfg.App.SetNetworkConfig, ), OnlyOnce: true, Action: func(ctx context.Context, cmd *cli.Command, v bool) error { - argsCfg.General.SetNetworkConfig = lo.ToPtr(v) + argsCfg.App.SetNetworkConfig = lo.ToPtr(v) return nil }, }, &cli.Int64Flag{ - Name: "timeout", + Name: "tcp-timeout", Usage: fmt.Sprintf(` Timeout for tcp connection in milliseconds. No effect when the value is 0 (default: %v, max: %v)`, - *defaultCfg.Server.Timeout, + defaultCfg.Conn.TCPTimeout.Milliseconds(), math.MaxUint16, ), Value: 0, OnlyOnce: true, Validator: checkUint16, Action: func(ctx context.Context, cmd *cli.Command, v int64) error { - argsCfg.Server.Timeout = lo.ToPtr( + argsCfg.Conn.TCPTimeout = lo.ToPtr( time.Duration(v * int64(time.Millisecond)), ) return nil @@ -419,12 +442,12 @@ func CreateCommand( finalCfg := defaultCfg.Merge(tomlCfg.Merge(argsCfg)) - if finalCfg.Server.ListenAddr == nil { + if finalCfg.App.ListenAddr == nil { port := 8080 - if *finalCfg.Server.Mode == ServerModeSOCKS5 { + if *finalCfg.App.Mode == AppModeSOCKS5 { port = 1080 } - finalCfg.Server.ListenAddr = &net.TCPAddr{ + finalCfg.App.ListenAddr = &net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), Port: port, } diff --git a/internal/config/cli_test.go b/internal/config/cli_test.go index 70ced9eb..3c0ee8e7 100644 --- a/internal/config/cli_test.go +++ b/internal/config/cli_test.go @@ -24,12 +24,14 @@ func TestCreateCommand_Flags(t *testing.T) { args: []string{"spoofdpi", "--clean"}, assert: func(t *testing.T, cfg *Config) { // Verify defaults are preserved - assert.Equal(t, zerolog.InfoLevel, *cfg.General.LogLevel) - assert.False(t, *cfg.General.Silent) - assert.False(t, *cfg.General.SetNetworkConfig) - assert.Equal(t, "127.0.0.1:8080", cfg.Server.ListenAddr.String()) - assert.Equal(t, uint8(64), *cfg.Server.DefaultTTL) - assert.Equal(t, time.Duration(0), *cfg.Server.Timeout) + assert.Equal(t, zerolog.InfoLevel, *cfg.App.LogLevel) + assert.False(t, *cfg.App.Silent) + assert.False(t, *cfg.App.SetNetworkConfig) + assert.Equal(t, "127.0.0.1:8080", cfg.App.ListenAddr.String()) + assert.Equal(t, uint8(8), *cfg.Conn.DefaultFakeTTL) + assert.Equal(t, int64(5000), cfg.Conn.DNSTimeout.Milliseconds()) + assert.Equal(t, int64(10000), cfg.Conn.TCPTimeout.Milliseconds()) + assert.Equal(t, int64(25000), cfg.Conn.UDPTimeout.Milliseconds()) assert.Equal(t, "8.8.8.8:53", cfg.DNS.Addr.String()) assert.Equal(t, DNSModeUDP, *cfg.DNS.Mode) assert.Equal(t, "https://dns.google/dns-query", *cfg.DNS.HTTPSURL) @@ -54,8 +56,10 @@ func TestCreateCommand_Flags(t *testing.T) { "--silent", "--network-config", "--listen-addr", "127.0.0.1:9090", - "--default-ttl", "128", - "--timeout", "5000", + "--default-fake-ttl", "128", + "--dns-timeout", "5000", + "--tcp-timeout", "5000", + "--udp-timeout", "5000", "--dns-addr", "1.1.1.1:53", "--dns-mode", "https", "--dns-https-url", "https://cloudflare-dns.com/dns-query", @@ -69,19 +73,20 @@ func TestCreateCommand_Flags(t *testing.T) { "--https-skip", "--udp-fake-count", "5", "--udp-fake-packet", "0x01, 0x02", - "--udp-timeout", "1000", "--policy-auto", }, assert: func(t *testing.T, cfg *Config) { // General - assert.Equal(t, zerolog.DebugLevel, *cfg.General.LogLevel) - assert.True(t, *cfg.General.Silent) - assert.True(t, *cfg.General.SetNetworkConfig) + assert.Equal(t, zerolog.DebugLevel, *cfg.App.LogLevel) + assert.True(t, *cfg.App.Silent) + assert.True(t, *cfg.App.SetNetworkConfig) // Server - assert.Equal(t, "127.0.0.1:9090", cfg.Server.ListenAddr.String()) - assert.Equal(t, uint8(128), *cfg.Server.DefaultTTL) - assert.Equal(t, 5000*time.Millisecond, *cfg.Server.Timeout) + assert.Equal(t, "127.0.0.1:9090", cfg.App.ListenAddr.String()) + assert.Equal(t, uint8(128), *cfg.Conn.DefaultFakeTTL) + assert.Equal(t, 5000*time.Millisecond, *cfg.Conn.DNSTimeout) + assert.Equal(t, 5000*time.Millisecond, *cfg.Conn.TCPTimeout) + assert.Equal(t, 5000*time.Millisecond, *cfg.Conn.UDPTimeout) // DNS assert.Equal(t, "1.1.1.1:53", cfg.DNS.Addr.String()) @@ -101,7 +106,7 @@ func TestCreateCommand_Flags(t *testing.T) { // UDP assert.Equal(t, 5, *cfg.UDP.FakeCount) assert.Equal(t, []byte{0x01, 0x02}, cfg.UDP.FakePacket) - assert.Equal(t, 1000*time.Millisecond, *cfg.UDP.Timeout) + assert.Equal(t, []byte{0x01, 0x02}, cfg.UDP.FakePacket) // Policy assert.True(t, *cfg.Policy.Auto) @@ -118,7 +123,7 @@ func TestCreateCommand_Flags(t *testing.T) { "--https-split-mode", "random", }, assert: func(t *testing.T, cfg *Config) { - assert.Equal(t, zerolog.ErrorLevel, *cfg.General.LogLevel) + assert.Equal(t, zerolog.ErrorLevel, *cfg.App.LogLevel) assert.Equal(t, DNSModeSystem, *cfg.DNS.Mode) assert.Equal(t, DNSQueryAll, *cfg.DNS.QType) assert.Equal(t, HTTPSSplitModeRandom, *cfg.HTTPS.SplitMode) @@ -132,9 +137,9 @@ func TestCreateCommand_Flags(t *testing.T) { "--listen-addr", "[::1]:1080", }, assert: func(t *testing.T, cfg *Config) { - assert.Equal(t, "[::1]:1080", cfg.Server.ListenAddr.String()) + assert.Equal(t, "[::1]:1080", cfg.App.ListenAddr.String()) ip := net.ParseIP("::1") - assert.True(t, cfg.Server.ListenAddr.IP.Equal(ip)) + assert.True(t, cfg.App.ListenAddr.IP.Equal(ip)) }, }, { @@ -142,11 +147,11 @@ func TestCreateCommand_Flags(t *testing.T) { args: []string{ "spoofdpi", "--clean", - "--server-mode", "socks5", + "--app-mode", "socks5", }, assert: func(t *testing.T, cfg *Config) { - assert.Equal(t, "127.0.0.1:1080", cfg.Server.ListenAddr.String()) - assert.Equal(t, ServerModeSOCKS5, *cfg.Server.Mode) + assert.Equal(t, "127.0.0.1:1080", cfg.App.ListenAddr.String()) + assert.Equal(t, AppModeSOCKS5, *cfg.App.Mode) }, }, } @@ -182,8 +187,10 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { [server] listen-addr = "127.0.0.1:8080" - timeout = 1000 - default-ttl = 100 + dns-timeout = 1000 + tcp-timeout = 1000 + udp-timeout = 1000 + default-fake-ttl = 100 [dns] addr = "8.8.8.8:53" @@ -248,8 +255,10 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { "--silent=false", "--network-config=false", "--listen-addr", "127.0.0.1:9090", - "--timeout", "2000", - "--default-ttl", "200", + "--dns-timeout", "2000", + "--tcp-timeout", "2000", + "--udp-timeout", "2000", + "--default-fake-ttl", "200", "--dns-addr", "1.1.1.1:53", "--dns-cache=false", "--dns-mode", "udp", @@ -272,14 +281,16 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { // Verify Overrides // General - assert.Equal(t, zerolog.ErrorLevel, *capturedCfg.General.LogLevel) - assert.False(t, *capturedCfg.General.Silent) - assert.False(t, *capturedCfg.General.SetNetworkConfig) + assert.Equal(t, zerolog.ErrorLevel, *capturedCfg.App.LogLevel) + assert.False(t, *capturedCfg.App.Silent) + assert.False(t, *capturedCfg.App.SetNetworkConfig) // Server - assert.Equal(t, "127.0.0.1:9090", capturedCfg.Server.ListenAddr.String()) - assert.Equal(t, 2000*time.Millisecond, *capturedCfg.Server.Timeout) - assert.Equal(t, uint8(200), *capturedCfg.Server.DefaultTTL) + assert.Equal(t, "127.0.0.1:9090", capturedCfg.App.ListenAddr.String()) + assert.Equal(t, 2000*time.Millisecond, *capturedCfg.Conn.DNSTimeout) + assert.Equal(t, 2000*time.Millisecond, *capturedCfg.Conn.TCPTimeout) + assert.Equal(t, 2000*time.Millisecond, *capturedCfg.Conn.UDPTimeout) + assert.Equal(t, uint8(200), *capturedCfg.Conn.DefaultFakeTTL) // DNS assert.Equal(t, "1.1.1.1:53", capturedCfg.DNS.Addr.String()) @@ -299,7 +310,7 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { // UDP assert.Equal(t, 20, *capturedCfg.UDP.FakeCount) assert.Equal(t, []byte{0xcc, 0xdd}, capturedCfg.UDP.FakePacket) - assert.Equal(t, time.Duration(0), *capturedCfg.UDP.Timeout) + assert.Equal(t, []byte{0xcc, 0xdd}, capturedCfg.UDP.FakePacket) // Policy assert.False(t, *capturedCfg.Policy.Auto) diff --git a/internal/config/config.go b/internal/config/config.go index 2f58acb0..06fdc77c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -22,12 +22,12 @@ type cloner[T any] interface { var _ merger[*Config] = (*Config)(nil) type Config struct { - General *GeneralOptions `toml:"general"` - Server *ServerOptions `toml:"server"` - DNS *DNSOptions `toml:"dns"` - HTTPS *HTTPSOptions `toml:"https"` - UDP *UDPOptions `toml:"udp"` - Policy *PolicyOptions `toml:"policy"` + App *AppOptions `toml:"general"` + Conn *ConnOptions `toml:"connection"` + DNS *DNSOptions `toml:"dns"` + HTTPS *HTTPSOptions `toml:"https"` + UDP *UDPOptions `toml:"udp"` + Policy *PolicyOptions `toml:"policy"` } func (c *Config) UnmarshalTOML(data any) (err error) { @@ -36,8 +36,8 @@ func (c *Config) UnmarshalTOML(data any) (err error) { return fmt.Errorf("non-table type config file") } - c.General = findStructFrom[GeneralOptions](m, "general", &err) - c.Server = findStructFrom[ServerOptions](m, "server", &err) + c.App = findStructFrom[AppOptions](m, "general", &err) + c.Conn = findStructFrom[ConnOptions](m, "connection", &err) c.DNS = findStructFrom[DNSOptions](m, "dns", &err) c.HTTPS = findStructFrom[HTTPSOptions](m, "https", &err) c.UDP = findStructFrom[UDPOptions](m, "udp", &err) @@ -48,12 +48,12 @@ func (c *Config) UnmarshalTOML(data any) (err error) { func NewConfig() *Config { return &Config{ - General: &GeneralOptions{}, - Server: &ServerOptions{}, - DNS: &DNSOptions{}, - HTTPS: &HTTPSOptions{}, - UDP: &UDPOptions{}, - Policy: &PolicyOptions{}, + App: &AppOptions{}, + Conn: &ConnOptions{}, + DNS: &DNSOptions{}, + HTTPS: &HTTPSOptions{}, + UDP: &UDPOptions{}, + Policy: &PolicyOptions{}, } } @@ -63,12 +63,12 @@ func (c *Config) Clone() *Config { } return &Config{ - General: c.General.Clone(), - Server: c.Server.Clone(), - DNS: c.DNS.Clone(), - HTTPS: c.HTTPS.Clone(), - UDP: c.UDP.Clone(), - Policy: c.Policy.Clone(), + App: c.App.Clone(), + Conn: c.Conn.Clone(), + DNS: c.DNS.Clone(), + HTTPS: c.HTTPS.Clone(), + UDP: c.UDP.Clone(), + Policy: c.Policy.Clone(), } } @@ -82,12 +82,12 @@ func (origin *Config) Merge(overrides *Config) *Config { } return &Config{ - General: origin.General.Merge(overrides.General), - Server: origin.Server.Merge(overrides.Server), - DNS: origin.DNS.Merge(overrides.DNS), - HTTPS: origin.HTTPS.Merge(overrides.HTTPS), - UDP: origin.UDP.Merge(overrides.UDP), - Policy: origin.Policy.Merge(overrides.Policy), + App: origin.App.Merge(overrides.App), + Conn: origin.Conn.Merge(overrides.Conn), + DNS: origin.DNS.Merge(overrides.DNS), + HTTPS: origin.HTTPS.Merge(overrides.HTTPS), + UDP: origin.UDP.Merge(overrides.UDP), + Policy: origin.Policy.Merge(overrides.Policy), } } @@ -132,16 +132,18 @@ func (c *Config) ShouldEnablePcap() bool { func getDefault() *Config { //exhaustruct:enforce return &Config{ - General: &GeneralOptions{ + App: &AppOptions{ LogLevel: lo.ToPtr(zerolog.InfoLevel), Silent: lo.ToPtr(false), SetNetworkConfig: lo.ToPtr(false), + Mode: lo.ToPtr(AppModeHTTP), + ListenAddr: nil, }, - Server: &ServerOptions{ - Mode: lo.ToPtr(ServerModeHTTP), - DefaultTTL: lo.ToPtr(uint8(64)), - ListenAddr: nil, - Timeout: lo.ToPtr(time.Duration(0)), + Conn: &ConnOptions{ + DefaultFakeTTL: lo.ToPtr(uint8(8)), + DNSTimeout: lo.ToPtr(time.Duration(5000) * time.Millisecond), + TCPTimeout: lo.ToPtr(time.Duration(10000) * time.Millisecond), + UDPTimeout: lo.ToPtr(time.Duration(25000) * time.Millisecond), }, DNS: &DNSOptions{ Mode: lo.ToPtr(DNSModeUDP), @@ -162,7 +164,6 @@ func getDefault() *Config { //exhaustruct:enforce UDP: &UDPOptions{ FakeCount: lo.ToPtr(0), FakePacket: make([]byte, 64), - Timeout: lo.ToPtr(time.Duration(0)), }, Policy: &PolicyOptions{ Auto: lo.ToPtr(false), diff --git a/internal/config/config_test.go b/internal/config/config_test.go index dcaca3ef..f42c2720 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -18,7 +18,7 @@ func TestConfig_UnmarshalTOML(t *testing.T) { { name: "valid config", input: map[string]any{ - "server": map[string]any{ + "general": map[string]any{ "listen-addr": "127.0.0.1:9090", }, "dns": map[string]any{ @@ -40,7 +40,7 @@ func TestConfig_UnmarshalTOML(t *testing.T) { }, wantErr: false, assert: func(t *testing.T, c Config) { - assert.Equal(t, "127.0.0.1:9090", c.Server.ListenAddr.String()) + assert.Equal(t, "127.0.0.1:9090", c.App.ListenAddr.String()) assert.Equal(t, "1.1.1.1:53", c.DNS.Addr.String()) if assert.Len(t, c.Policy.Overrides, 1) { assert.Equal(t, "test", *c.Policy.Overrides[0].Name) @@ -55,7 +55,7 @@ func TestConfig_UnmarshalTOML(t *testing.T) { { name: "validation error", input: map[string]any{ - "server": map[string]any{ + "general": map[string]any{ "listen-addr": "invalid-addr", }, }, @@ -149,17 +149,17 @@ func TestConfig_Merge(t *testing.T) { { name: "keep toml if arg is nil", tomlCfg: &Config{ - Server: &ServerOptions{ + App: &AppOptions{ ListenAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}, }, }, argsCfg: &Config{ - Server: &ServerOptions{ + App: &AppOptions{ ListenAddr: nil, }, }, assert: func(t *testing.T, merged *Config) { - assert.Equal(t, "127.0.0.1:8080", merged.Server.ListenAddr.String()) + assert.Equal(t, "127.0.0.1:8080", merged.App.ListenAddr.String()) }, }, { @@ -199,17 +199,17 @@ func TestConfig_Clone(t *testing.T) { { name: "non-nil receiver", input: &Config{ - General: &GeneralOptions{}, - Server: &ServerOptions{}, - DNS: &DNSOptions{}, - HTTPS: &HTTPSOptions{}, - Policy: &PolicyOptions{}, + App: &AppOptions{}, + Conn: &ConnOptions{}, + DNS: &DNSOptions{}, + HTTPS: &HTTPSOptions{}, + Policy: &PolicyOptions{}, }, assert: func(t *testing.T, input *Config, output *Config) { assert.NotNil(t, output) assert.NotSame(t, input, output) - assert.NotSame(t, input.General, output.General) - assert.NotSame(t, input.Server, output.Server) + assert.NotSame(t, input.App, output.App) + assert.NotSame(t, input.Conn, output.Conn) assert.NotSame(t, input.DNS, output.DNS) assert.NotSame(t, input.HTTPS, output.HTTPS) assert.NotSame(t, input.Policy, output.Policy) diff --git a/internal/config/parse.go b/internal/config/parse.go index 5f3ccd79..23905376 100644 --- a/internal/config/parse.go +++ b/internal/config/parse.go @@ -108,14 +108,14 @@ func MustParseLogLevel(s string) zerolog.Level { return level } -func MustParseServerModeType(s string) ServerModeType { +func MustParseServerModeType(s string) AppModeType { switch s { case "http": - return ServerModeHTTP + return AppModeHTTP case "socks5": - return ServerModeSOCKS5 + return AppModeSOCKS5 case "tun": - return ServerModeTUN + return AppModeTUN default: panic(fmt.Sprintf("cannot parse %q to ServerModeType", s)) } diff --git a/internal/config/toml_test.go b/internal/config/toml_test.go index 4e47de25..6757a2c0 100644 --- a/internal/config/toml_test.go +++ b/internal/config/toml_test.go @@ -418,10 +418,13 @@ func TestFromTomlFile(t *testing.T) { log-level = "debug" silent = true network-config = true - [server] - listen-addr = "127.0.0.1:8080" - timeout = 1000 - default-ttl = 100 + mode = "socks5" + listen-addr = "127.0.0.1:8080" + [connection] + dns-timeout = 1000 + tcp-timeout = 1000 + udp-timeout = 1000 + default-fake-ttl = 100 [dns] addr = "8.8.8.8:53" @@ -481,17 +484,20 @@ func TestFromTomlFile(t *testing.T) { return } - assert.Equal(t, "127.0.0.1:8080", cfg.Server.ListenAddr.String()) - assert.Equal(t, time.Duration(1000*time.Millisecond), *cfg.Server.Timeout) - assert.Equal(t, zerolog.DebugLevel, *cfg.General.LogLevel) - assert.True(t, *cfg.General.Silent) - assert.True(t, *cfg.General.SetNetworkConfig) + assert.Equal(t, "127.0.0.1:8080", cfg.App.ListenAddr.String()) + assert.Equal(t, time.Duration(1000*time.Millisecond), *cfg.Conn.DNSTimeout) + assert.Equal(t, time.Duration(1000*time.Millisecond), *cfg.Conn.TCPTimeout) + assert.Equal(t, time.Duration(1000*time.Millisecond), *cfg.Conn.UDPTimeout) + assert.Equal(t, zerolog.DebugLevel, *cfg.App.LogLevel) + assert.True(t, *cfg.App.Silent) + assert.True(t, *cfg.App.SetNetworkConfig) + assert.Equal(t, AppModeSOCKS5, *cfg.App.Mode) assert.Equal(t, "8.8.8.8:53", cfg.DNS.Addr.String()) assert.True(t, *cfg.DNS.Cache) assert.Equal(t, DNSModeHTTPS, *cfg.DNS.Mode) assert.Equal(t, "https://1.1.1.1/dns-query", *cfg.DNS.HTTPSURL) assert.Equal(t, DNSQueryIPv4, *cfg.DNS.QType) - assert.Equal(t, uint8(100), *cfg.Server.DefaultTTL) + assert.Equal(t, uint8(100), *cfg.Conn.DefaultFakeTTL) assert.True(t, *cfg.Policy.Auto) assert.True(t, *cfg.HTTPS.Disorder) assert.Equal(t, uint8(5), *cfg.HTTPS.FakeCount) diff --git a/internal/config/types.go b/internal/config/types.go index 90dd1772..0f6f9c5a 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -32,17 +32,19 @@ func clonePrimitive[T primitive](x *T) *T { // ┌─────────────────┐ // │ GENERAL OPTIONS │ // └─────────────────┘ -var _ merger[*GeneralOptions] = (*GeneralOptions)(nil) +var _ merger[*AppOptions] = (*AppOptions)(nil) var availableLogLevelValues = []string{"info", "warn", "trace", "error", "debug"} -type GeneralOptions struct { +type AppOptions struct { LogLevel *zerolog.Level `toml:"log-level"` Silent *bool `toml:"silent"` SetNetworkConfig *bool `toml:"network-config"` + Mode *AppModeType `toml:"mode"` + ListenAddr *net.TCPAddr `toml:"listen-addr"` } -func (o *GeneralOptions) UnmarshalTOML(data any) (err error) { +func (o *AppOptions) UnmarshalTOML(data any) (err error) { m, ok := data.(map[string]any) if !ok { return fmt.Errorf("non-table type general config") @@ -53,11 +55,17 @@ func (o *GeneralOptions) UnmarshalTOML(data any) (err error) { if p := findFrom(m, "log-level", parseStringFn(checkLogLevel), &err); isOk(p, err) { o.LogLevel = lo.ToPtr(MustParseLogLevel(*p)) } + if p := findFrom(m, "mode", parseStringFn(checkAppMode), &err); isOk(p, err) { + o.Mode = lo.ToPtr(MustParseServerModeType(*p)) + } + if p := findFrom(m, "listen-addr", parseStringFn(checkHostPort), &err); isOk(p, err) { + o.ListenAddr = lo.ToPtr(MustParseTCPAddr(*p)) + } return err } -func (o *GeneralOptions) Clone() *GeneralOptions { +func (o *AppOptions) Clone() *AppOptions { if o == nil { return nil } @@ -67,14 +75,25 @@ func (o *GeneralOptions) Clone() *GeneralOptions { newLevel = lo.ToPtr(MustParseLogLevel(strings.ToLower(o.LogLevel.String()))) } - return &GeneralOptions{ + var newAddr *net.TCPAddr + if o.ListenAddr != nil { + newAddr = &net.TCPAddr{ + IP: append(net.IP(nil), o.ListenAddr.IP...), + Port: o.ListenAddr.Port, + Zone: o.ListenAddr.Zone, + } + } + + return &AppOptions{ LogLevel: newLevel, Silent: clonePrimitive(o.Silent), SetNetworkConfig: clonePrimitive(o.SetNetworkConfig), + Mode: clonePrimitive(o.Mode), + ListenAddr: newAddr, } } -func (origin *GeneralOptions) Merge(overrides *GeneralOptions) *GeneralOptions { +func (origin *AppOptions) Merge(overrides *AppOptions) *AppOptions { if overrides == nil { return origin.Clone() } @@ -83,88 +102,108 @@ func (origin *GeneralOptions) Merge(overrides *GeneralOptions) *GeneralOptions { return overrides.Clone() } - return &GeneralOptions{ + return &AppOptions{ LogLevel: lo.CoalesceOrEmpty(overrides.LogLevel, origin.LogLevel), Silent: lo.CoalesceOrEmpty(overrides.Silent, origin.Silent), SetNetworkConfig: lo.CoalesceOrEmpty( overrides.SetNetworkConfig, origin.SetNetworkConfig, ), + Mode: lo.CoalesceOrEmpty(overrides.Mode, origin.Mode), + ListenAddr: lo.CoalesceOrEmpty(overrides.ListenAddr, origin.ListenAddr), } } -// ┌────────────────┐ -// │ SERVER OPTIONS │ -// └────────────────┘ -var _ merger[*ServerOptions] = (*ServerOptions)(nil) +// ┌──────────────────────┐ +// │ CONNECTION OPTIONS │ +// └──────────────────────┘ +var _ merger[*ConnOptions] = (*ConnOptions)(nil) -type ServerModeType int +type AppModeType int const ( - ServerModeHTTP ServerModeType = iota - ServerModeSOCKS5 - ServerModeTUN + AppModeHTTP AppModeType = iota + AppModeSOCKS5 + AppModeTUN ) -var availableServerModeValues = []string{"http", "socks5", "tun"} +var availableAppModeValues = []string{"http", "socks5", "tun"} -func (t ServerModeType) String() string { - return availableServerModeValues[t] +func (t AppModeType) String() string { + return availableAppModeValues[t] } -type ServerOptions struct { - Mode *ServerModeType `toml:"mode"` - DefaultTTL *uint8 `toml:"default-ttl"` - ListenAddr *net.TCPAddr `toml:"listen-addr"` - Timeout *time.Duration `toml:"timeout"` +type ConnOptions struct { + DefaultFakeTTL *uint8 `toml:"default-fake-ttl"` + DNSTimeout *time.Duration `toml:"dns-timeout"` + TCPTimeout *time.Duration `toml:"tcp-timeout"` + UDPTimeout *time.Duration `toml:"udp-timeout"` } -func (o *ServerOptions) UnmarshalTOML(data any) (err error) { +func (o *ConnOptions) UnmarshalTOML(data any) (err error) { v, ok := data.(map[string]any) if !ok { - return fmt.Errorf("non-table type server config") + return fmt.Errorf("non-table type connection config") } - if p := findFrom(v, "mode", parseStringFn(checkServerMode), &err); isOk(p, err) { - o.Mode = lo.ToPtr(MustParseServerModeType(*p)) - } - - o.DefaultTTL = findFrom(v, "default-ttl", parseIntFn[uint8](checkUint8NonZero), &err) - - if p := findFrom(v, "listen-addr", parseStringFn(checkHostPort), &err); isOk(p, err) { - o.ListenAddr = lo.ToPtr(MustParseTCPAddr(*p)) - } + o.DefaultFakeTTL = findFrom( + v, + "default-fake-ttl", + parseIntFn[uint8](checkUint8NonZero), + &err, + ) - if p := findFrom(v, "timeout", parseIntFn[uint16](checkUint16), &err); isOk(p, err) { - o.Timeout = lo.ToPtr(time.Duration(*p) * time.Millisecond) + if p := findFrom( + v, + "dns-timeout", + parseIntFn[uint16](checkUint16), + &err, + ); isOk( + p, + err, + ) { + o.DNSTimeout = lo.ToPtr(time.Duration(*p) * time.Millisecond) + } + if p := findFrom( + v, + "tcp-timeout", + parseIntFn[uint16](checkUint16), + &err, + ); isOk( + p, + err, + ) { + o.TCPTimeout = lo.ToPtr(time.Duration(*p) * time.Millisecond) + } + if p := findFrom( + v, + "udp-timeout", + parseIntFn[uint16](checkUint16), + &err, + ); isOk( + p, + err, + ) { + o.UDPTimeout = lo.ToPtr(time.Duration(*p) * time.Millisecond) } return err } -func (o *ServerOptions) Clone() *ServerOptions { +func (o *ConnOptions) Clone() *ConnOptions { if o == nil { return nil } - var newAddr *net.TCPAddr - if o.ListenAddr != nil { - newAddr = &net.TCPAddr{ - IP: append(net.IP(nil), o.ListenAddr.IP...), - Port: o.ListenAddr.Port, - Zone: o.ListenAddr.Zone, - } - } - - return &ServerOptions{ - Mode: clonePrimitive(o.Mode), - DefaultTTL: clonePrimitive(o.DefaultTTL), - ListenAddr: newAddr, - Timeout: clonePrimitive(o.Timeout), + return &ConnOptions{ + DefaultFakeTTL: clonePrimitive(o.DefaultFakeTTL), + DNSTimeout: clonePrimitive(o.DNSTimeout), + TCPTimeout: clonePrimitive(o.TCPTimeout), + UDPTimeout: clonePrimitive(o.UDPTimeout), } } -func (origin *ServerOptions) Merge(overrides *ServerOptions) *ServerOptions { +func (origin *ConnOptions) Merge(overrides *ConnOptions) *ConnOptions { if overrides == nil { return origin.Clone() } @@ -173,11 +212,11 @@ func (origin *ServerOptions) Merge(overrides *ServerOptions) *ServerOptions { return overrides.Clone() } - return &ServerOptions{ - Mode: lo.CoalesceOrEmpty(overrides.Mode, origin.Mode), - DefaultTTL: lo.CoalesceOrEmpty(overrides.DefaultTTL, origin.DefaultTTL), - ListenAddr: lo.CoalesceOrEmpty(overrides.ListenAddr, origin.ListenAddr), - Timeout: lo.CoalesceOrEmpty(overrides.Timeout, origin.Timeout), + return &ConnOptions{ + DefaultFakeTTL: lo.CoalesceOrEmpty(overrides.DefaultFakeTTL, origin.DefaultFakeTTL), + DNSTimeout: lo.CoalesceOrEmpty(overrides.DNSTimeout, origin.DNSTimeout), + TCPTimeout: lo.CoalesceOrEmpty(overrides.TCPTimeout, origin.TCPTimeout), + UDPTimeout: lo.CoalesceOrEmpty(overrides.UDPTimeout, origin.UDPTimeout), } } @@ -524,9 +563,8 @@ func (origin *HTTPSOptions) Merge(overrides *HTTPSOptions) *HTTPSOptions { var _ merger[*UDPOptions] = (*UDPOptions)(nil) type UDPOptions struct { - FakeCount *int `toml:"fake-count" json:"fc,omitempty"` - FakePacket []byte `toml:"fake-packet" json:"fp,omitempty"` - Timeout *time.Duration `toml:"timeout" json:"to,omitempty"` + FakeCount *int `toml:"fake-count" json:"fc,omitempty"` + FakePacket []byte `toml:"fake-packet" json:"fp,omitempty"` } func (o *UDPOptions) UnmarshalTOML(data any) (err error) { @@ -540,10 +578,6 @@ func (o *UDPOptions) UnmarshalTOML(data any) (err error) { ) o.FakePacket = findSliceFrom(m, "fake-packet", parseByteFn(nil), &err) - if p := findFrom(m, "timeout", parseIntFn[uint16](checkUint16), &err); isOk(p, err) { - o.Timeout = lo.ToPtr(time.Duration(*p) * time.Millisecond) - } - return err } @@ -555,7 +589,6 @@ func (o *UDPOptions) Clone() *UDPOptions { return &UDPOptions{ FakeCount: clonePrimitive(o.FakeCount), FakePacket: append([]byte(nil), o.FakePacket...), - Timeout: clonePrimitive(o.Timeout), } } @@ -576,7 +609,6 @@ func (origin *UDPOptions) Merge(overrides *UDPOptions) *UDPOptions { return &UDPOptions{ FakeCount: lo.CoalesceOrEmpty(overrides.FakeCount, origin.FakeCount), FakePacket: fakePacket, - Timeout: lo.CoalesceOrEmpty(overrides.Timeout, origin.Timeout), } } @@ -730,6 +762,7 @@ type Rule struct { DNS *DNSOptions `toml:"dns-override" json:"D,omitempty"` HTTPS *HTTPSOptions `toml:"https-override" json:"H,omitempty"` UDP *UDPOptions `toml:"udp-override" json:"U,omitempty"` + Conn *ConnOptions `toml:"conn-override" json:"C,omitempty"` } func (r *Rule) UnmarshalTOML(data any) (err error) { @@ -745,6 +778,7 @@ func (r *Rule) UnmarshalTOML(data any) (err error) { r.DNS = findStructFrom[DNSOptions](m, "dns", &err) r.HTTPS = findStructFrom[HTTPSOptions](m, "https", &err) r.UDP = findStructFrom[UDPOptions](m, "udp", &err) + r.Conn = findStructFrom[ConnOptions](m, "connection", &err) // if err == nil { // err = checkRule(*r) @@ -765,6 +799,7 @@ func (r *Rule) Clone() *Rule { DNS: r.DNS.Clone(), HTTPS: r.HTTPS.Clone(), UDP: r.UDP.Clone(), + Conn: r.Conn.Clone(), } } diff --git a/internal/config/types_test.go b/internal/config/types_test.go index 6848f967..b806114e 100644 --- a/internal/config/types_test.go +++ b/internal/config/types_test.go @@ -15,12 +15,12 @@ import ( // ┌─────────────────┐ // │ GENERAL OPTIONS │ // └─────────────────┘ -func TestGeneralOptions_UnmarshalTOML(t *testing.T) { +func TestAppOptions_UnmarshalTOML(t *testing.T) { tcs := []struct { name string input any wantErr bool - assert func(t *testing.T, o GeneralOptions) + assert func(t *testing.T, o AppOptions) }{ { name: "valid general options", @@ -28,12 +28,14 @@ func TestGeneralOptions_UnmarshalTOML(t *testing.T) { "log-level": "debug", "silent": true, "network-config": true, + "mode": "socks5", }, wantErr: false, - assert: func(t *testing.T, o GeneralOptions) { + assert: func(t *testing.T, o AppOptions) { assert.Equal(t, zerolog.DebugLevel, *o.LogLevel) assert.True(t, *o.Silent) assert.True(t, *o.SetNetworkConfig) + assert.Equal(t, AppModeSOCKS5, *o.Mode) }, }, { @@ -45,7 +47,7 @@ func TestGeneralOptions_UnmarshalTOML(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - var o GeneralOptions + var o AppOptions err := o.UnmarshalTOML(tc.input) if tc.wantErr { assert.Error(t, err) @@ -59,26 +61,26 @@ func TestGeneralOptions_UnmarshalTOML(t *testing.T) { } } -func TestGeneralOptions_Clone(t *testing.T) { +func TestAppOptions_Clone(t *testing.T) { tcs := []struct { name string - input *GeneralOptions - assert func(t *testing.T, input *GeneralOptions, output *GeneralOptions) + input *AppOptions + assert func(t *testing.T, input *AppOptions, output *AppOptions) }{ { name: "nil receiver", input: nil, - assert: func(t *testing.T, input *GeneralOptions, output *GeneralOptions) { + assert: func(t *testing.T, input *AppOptions, output *AppOptions) { assert.Nil(t, output) }, }, { name: "non-nil receiver", - input: &GeneralOptions{ + input: &AppOptions{ LogLevel: lo.ToPtr(zerolog.DebugLevel), Silent: lo.ToPtr(true), }, - assert: func(t *testing.T, input *GeneralOptions, output *GeneralOptions) { + assert: func(t *testing.T, input *AppOptions, output *AppOptions) { assert.NotNil(t, output) assert.Equal(t, zerolog.DebugLevel, *output.LogLevel) assert.True(t, *output.Silent) @@ -95,39 +97,39 @@ func TestGeneralOptions_Clone(t *testing.T) { } } -func TestGeneralOptions_Merge(t *testing.T) { +func TestAppOptions_Merge(t *testing.T) { tcs := []struct { name string - base *GeneralOptions - override *GeneralOptions - assert func(t *testing.T, output *GeneralOptions) + base *AppOptions + override *AppOptions + assert func(t *testing.T, output *AppOptions) }{ { name: "nil receiver", base: nil, - override: &GeneralOptions{Silent: lo.ToPtr(true)}, - assert: func(t *testing.T, output *GeneralOptions) { + override: &AppOptions{Silent: lo.ToPtr(true)}, + assert: func(t *testing.T, output *AppOptions) { assert.True(t, *output.Silent) }, }, { name: "nil override", - base: &GeneralOptions{Silent: lo.ToPtr(false)}, + base: &AppOptions{Silent: lo.ToPtr(false)}, override: nil, - assert: func(t *testing.T, output *GeneralOptions) { + assert: func(t *testing.T, output *AppOptions) { assert.False(t, *output.Silent) }, }, { name: "merge values", - base: &GeneralOptions{ + base: &AppOptions{ Silent: lo.ToPtr(false), LogLevel: lo.ToPtr(zerolog.InfoLevel), }, - override: &GeneralOptions{ + override: &AppOptions{ Silent: lo.ToPtr(true), }, - assert: func(t *testing.T, output *GeneralOptions) { + assert: func(t *testing.T, output *AppOptions) { assert.True(t, *output.Silent) assert.Equal(t, zerolog.InfoLevel, *output.LogLevel) }, @@ -145,25 +147,27 @@ func TestGeneralOptions_Merge(t *testing.T) { // ┌────────────────┐ // │ SERVER OPTIONS │ // └────────────────┘ -func TestServerOptions_UnmarshalTOML(t *testing.T) { +func TestConnOptions_UnmarshalTOML(t *testing.T) { tcs := []struct { name string input any wantErr bool - assert func(t *testing.T, o ServerOptions) + assert func(t *testing.T, o ConnOptions) }{ { name: "valid server options", input: map[string]any{ - "default-ttl": int64(64), - "listen-addr": "127.0.0.1:8080", - "timeout": int64(1000), + "default-fake-ttl": int64(64), + "dns-timeout": int64(1000), + "tcp-timeout": int64(1000), + "udp-timeout": int64(1000), }, wantErr: false, - assert: func(t *testing.T, o ServerOptions) { - assert.Equal(t, uint8(64), *o.DefaultTTL) - assert.Equal(t, "127.0.0.1:8080", o.ListenAddr.String()) - assert.Equal(t, 1000*time.Millisecond, *o.Timeout) + assert: func(t *testing.T, o ConnOptions) { + assert.Equal(t, uint8(64), *o.DefaultFakeTTL) + assert.Equal(t, 1000*time.Millisecond, *o.DNSTimeout) + assert.Equal(t, 1000*time.Millisecond, *o.TCPTimeout) + assert.Equal(t, 1000*time.Millisecond, *o.UDPTimeout) }, }, { @@ -175,7 +179,7 @@ func TestServerOptions_UnmarshalTOML(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - var o ServerOptions + var o ConnOptions err := o.UnmarshalTOML(tc.input) if tc.wantErr { assert.Error(t, err) @@ -189,33 +193,34 @@ func TestServerOptions_UnmarshalTOML(t *testing.T) { } } -func TestServerOptions_Clone(t *testing.T) { +func TestConnOptions_Clone(t *testing.T) { tcs := []struct { name string - input *ServerOptions - assert func(t *testing.T, input *ServerOptions, output *ServerOptions) + input *ConnOptions + assert func(t *testing.T, input *ConnOptions, output *ConnOptions) }{ { name: "nil receiver", input: nil, - assert: func(t *testing.T, input *ServerOptions, output *ServerOptions) { + assert: func(t *testing.T, input *ConnOptions, output *ConnOptions) { assert.Nil(t, output) }, }, { name: "non-nil receiver", - input: &ServerOptions{ - DefaultTTL: lo.ToPtr(uint8(64)), - ListenAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}, + input: &ConnOptions{ + DefaultFakeTTL: lo.ToPtr(uint8(64)), + DNSTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), + TCPTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), + UDPTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), }, - assert: func(t *testing.T, input *ServerOptions, output *ServerOptions) { + assert: func(t *testing.T, input *ConnOptions, output *ConnOptions) { assert.NotNil(t, output) - assert.Equal(t, uint8(64), *output.DefaultTTL) - assert.Equal(t, "127.0.0.1:8080", output.ListenAddr.String()) + assert.Equal(t, uint8(64), *output.DefaultFakeTTL) + assert.Equal(t, 1000*time.Millisecond, *output.DNSTimeout) + assert.Equal(t, 1000*time.Millisecond, *output.TCPTimeout) + assert.Equal(t, 1000*time.Millisecond, *output.UDPTimeout) assert.NotSame(t, input, output) - if output.ListenAddr != nil { - assert.NotSame(t, input.ListenAddr, output.ListenAddr) - } }, }, } @@ -228,41 +233,45 @@ func TestServerOptions_Clone(t *testing.T) { } } -func TestServerOptions_Merge(t *testing.T) { +func TestConnOptions_Merge(t *testing.T) { tcs := []struct { name string - base *ServerOptions - override *ServerOptions - assert func(t *testing.T, output *ServerOptions) + base *ConnOptions + override *ConnOptions + assert func(t *testing.T, output *ConnOptions) }{ { name: "nil receiver", base: nil, - override: &ServerOptions{DefaultTTL: lo.ToPtr(uint8(64))}, - assert: func(t *testing.T, output *ServerOptions) { - assert.Equal(t, uint8(64), *output.DefaultTTL) + override: &ConnOptions{DefaultFakeTTL: lo.ToPtr(uint8(64))}, + assert: func(t *testing.T, output *ConnOptions) { + assert.Equal(t, uint8(64), *output.DefaultFakeTTL) }, }, { name: "nil override", - base: &ServerOptions{DefaultTTL: lo.ToPtr(uint8(128))}, + base: &ConnOptions{DefaultFakeTTL: lo.ToPtr(uint8(128))}, override: nil, - assert: func(t *testing.T, output *ServerOptions) { - assert.Equal(t, uint8(128), *output.DefaultTTL) + assert: func(t *testing.T, output *ConnOptions) { + assert.Equal(t, uint8(128), *output.DefaultFakeTTL) }, }, { name: "merge values", - base: &ServerOptions{ - DefaultTTL: lo.ToPtr(uint8(64)), - ListenAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}, + base: &ConnOptions{ + DefaultFakeTTL: lo.ToPtr(uint8(64)), + DNSTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), + TCPTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), + UDPTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), }, - override: &ServerOptions{ - DefaultTTL: lo.ToPtr(uint8(128)), + override: &ConnOptions{ + DefaultFakeTTL: lo.ToPtr(uint8(128)), }, - assert: func(t *testing.T, output *ServerOptions) { - assert.Equal(t, uint8(128), *output.DefaultTTL) - assert.Equal(t, "127.0.0.1:8080", output.ListenAddr.String()) + assert: func(t *testing.T, output *ConnOptions) { + assert.Equal(t, uint8(128), *output.DefaultFakeTTL) + assert.Equal(t, 1000*time.Millisecond, *output.DNSTimeout) + assert.Equal(t, 1000*time.Millisecond, *output.TCPTimeout) + assert.Equal(t, 1000*time.Millisecond, *output.UDPTimeout) }, }, } @@ -846,6 +855,20 @@ func TestRule_UnmarshalTOML(t *testing.T) { assert.True(t, *r.Block) }, }, + { + name: "valid rule with connection options", + input: map[string]any{ + "name": "rule2", + "connection": map[string]any{ + "tcp-timeout": int64(500), + }, + }, + wantErr: false, + assert: func(t *testing.T, r Rule) { + assert.Equal(t, "rule2", *r.Name) + assert.Equal(t, time.Duration(500*time.Millisecond), *r.Conn.TCPTimeout) + }, + }, { name: "invalid type", input: "invalid", diff --git a/internal/config/validate.go b/internal/config/validate.go index ba24e50b..43643a0e 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -51,7 +51,7 @@ var ( checkUint8 = int64Range(0, math.MaxUint8) checkUint16 = int64Range(0, math.MaxUint16) checkUint8NonZero = int64Range(1, math.MaxUint8) - checkServerMode = checkOneOf(availableServerModeValues...) + checkAppMode = checkOneOf(availableAppModeValues...) checkDNSMode = checkOneOf(availableDNSModeValues...) checkDNSQueryType = checkOneOf(availableDNSQueryValues...) checkHTTPSSplitMode = checkOneOf(availableHTTPSModeValues...) diff --git a/internal/desync/tls.go b/internal/desync/tls.go index 7d411d05..04299530 100644 --- a/internal/desync/tls.go +++ b/internal/desync/tls.go @@ -91,7 +91,7 @@ func (d *TLSDesyncer) sendSegments( total := 0 for _, chunk := range segments { if !ttlErrored && chunk.Lazy { - setTTLWrap(0) + setTTLWrap(1) } n, err := conn.Write(chunk.Packet) diff --git a/internal/dns/https.go b/internal/dns/https.go index 9bbce947..ab97bf08 100644 --- a/internal/dns/https.go +++ b/internal/dns/https.go @@ -21,21 +21,23 @@ import ( var _ Resolver = (*HTTPSResolver)(nil) type HTTPSResolver struct { - logger zerolog.Logger - client *http.Client - dnsOpts *config.DNSOptions + logger zerolog.Logger + client *http.Client + defaultDNSOpts *config.DNSOptions + defaultConnOpts *config.ConnOptions } func NewHTTPSResolver( logger zerolog.Logger, - dnsOpts *config.DNSOptions, + defaultDNSOpts *config.DNSOptions, + defaultConnOpts *config.ConnOptions, ) *HTTPSResolver { tr := &http.Transport{ TLSClientConfig: &tls.Config{ NextProtos: []string{"h2", "http/1.1"}, }, DialContext: (&net.Dialer{ - Timeout: 7 * time.Second, + Timeout: *defaultConnOpts.DNSTimeout, KeepAlive: 30 * time.Second, }).DialContext, TLSHandshakeTimeout: 9 * time.Second, @@ -54,9 +56,10 @@ func NewHTTPSResolver( logger: logger, client: &http.Client{ Transport: tr, - Timeout: 10 * time.Second, + Timeout: *defaultConnOpts.DNSTimeout, }, - dnsOpts: dnsOpts, + defaultDNSOpts: defaultDNSOpts, + defaultConnOpts: defaultConnOpts, } } @@ -64,7 +67,7 @@ func (dr *HTTPSResolver) Info() []ResolverInfo { return []ResolverInfo{ { Name: "https", - Dst: *dr.dnsOpts.HTTPSURL, + Dst: *dr.defaultDNSOpts.HTTPSURL, }, } } @@ -75,7 +78,7 @@ func (dr *HTTPSResolver) Resolve( fallback Resolver, rule *config.Rule, ) (*RecordSet, error) { - opts := dr.dnsOpts.Clone() + opts := dr.defaultDNSOpts.Clone() if rule != nil { opts = opts.Merge(rule.DNS) } diff --git a/internal/dns/route.go b/internal/dns/route.go index dfc97928..5823c0a4 100644 --- a/internal/dns/route.go +++ b/internal/dns/route.go @@ -14,12 +14,12 @@ import ( ) type RouteResolver struct { - logger zerolog.Logger - https Resolver - udp Resolver - system Resolver - cache Resolver - dnsOpts *config.DNSOptions + logger zerolog.Logger + https Resolver + udp Resolver + system Resolver + cache Resolver + defaultDNSOpts *config.DNSOptions } func NewRouteResolver( @@ -28,15 +28,15 @@ func NewRouteResolver( udp Resolver, sys Resolver, cache Resolver, - dnsOpts *config.DNSOptions, + defaultDNSOpts *config.DNSOptions, ) *RouteResolver { return &RouteResolver{ - logger: logger, - https: doh, - udp: udp, - system: sys, - cache: cache, - dnsOpts: dnsOpts, + logger: logger, + https: doh, + udp: udp, + system: sys, + cache: cache, + defaultDNSOpts: defaultDNSOpts, } } @@ -55,7 +55,7 @@ func (rr *RouteResolver) Resolve( fallback Resolver, rule *config.Rule, ) (*RecordSet, error) { - opts := rr.dnsOpts.Clone() + opts := rr.defaultDNSOpts.Clone() if rule != nil { opts = opts.Merge(rule.DNS) } diff --git a/internal/dns/system.go b/internal/dns/system.go index 6769fbd1..d6c318f0 100644 --- a/internal/dns/system.go +++ b/internal/dns/system.go @@ -15,17 +15,17 @@ type SystemResolver struct { logger zerolog.Logger *net.Resolver - dnsOpts *config.DNSOptions + defaultDNSOpts *config.DNSOptions } func NewSystemResolver( logger zerolog.Logger, - dnsOps *config.DNSOptions, + defaultDNSOpts *config.DNSOptions, ) *SystemResolver { return &SystemResolver{ - logger: logger, - Resolver: &net.Resolver{PreferGo: true}, - dnsOpts: dnsOps, + logger: logger, + Resolver: &net.Resolver{PreferGo: true}, + defaultDNSOpts: defaultDNSOpts, } } @@ -44,7 +44,7 @@ func (sr *SystemResolver) Resolve( fallback Resolver, rule *config.Rule, ) (*RecordSet, error) { - opts := sr.dnsOpts.Clone() + opts := sr.defaultDNSOpts.Clone() if rule != nil { opts = opts.Merge(rule.DNS) } diff --git a/internal/dns/udp.go b/internal/dns/udp.go index 2fba04b1..3c7b1862 100644 --- a/internal/dns/udp.go +++ b/internal/dns/udp.go @@ -14,18 +14,23 @@ var _ Resolver = (*UDPResolver)(nil) type UDPResolver struct { logger zerolog.Logger - client *dns.Client - dnsOpts *config.DNSOptions + client *dns.Client + defaultDNSOpts *config.DNSOptions + defaultConnOpts *config.ConnOptions } func NewUDPResolver( logger zerolog.Logger, - dnsOpts *config.DNSOptions, + defaultDNSOpts *config.DNSOptions, + defaultConnOpts *config.ConnOptions, ) *UDPResolver { return &UDPResolver{ - client: &dns.Client{}, - dnsOpts: dnsOpts, - logger: logger, + client: &dns.Client{ + Timeout: *defaultConnOpts.DNSTimeout, + }, + defaultDNSOpts: defaultDNSOpts, + defaultConnOpts: defaultConnOpts, + logger: logger, } } @@ -33,7 +38,7 @@ func (ur *UDPResolver) Info() []ResolverInfo { return []ResolverInfo{ { Name: "udp", - Dst: ur.dnsOpts.Addr.String(), + Dst: ur.defaultDNSOpts.Addr.String(), }, } } @@ -44,7 +49,7 @@ func (ur *UDPResolver) Resolve( fallback Resolver, rule *config.Rule, ) (*RecordSet, error) { - opts := ur.dnsOpts.Clone() + opts := ur.defaultDNSOpts.Clone() if rule != nil { opts = opts.Merge(rule.DNS) } diff --git a/internal/netutil/addr.go b/internal/netutil/addr.go index 24b85548..0b51194a 100644 --- a/internal/netutil/addr.go +++ b/internal/netutil/addr.go @@ -3,8 +3,6 @@ package netutil import ( "fmt" "net" - "os/exec" - "regexp" "strconv" "time" ) @@ -145,25 +143,6 @@ func GetDefaultInterfaceAndGateway() (string, string, error) { return ifaceName, gateway, nil } -// getDefaultGateway parses the system route table to find the default gateway -func getDefaultGateway() (string, error) { - // Use netstat to get the default route on macOS - cmd := exec.Command("route", "-n", "get", "default") - out, err := cmd.Output() - if err != nil { - return "", err - } - - // Parse output to find gateway line - re := regexp.MustCompile(`gateway:\s+(\d+\.\d+\.\d+\.\d+)`) - matches := re.FindSubmatch(out) - if len(matches) < 2 { - return "", fmt.Errorf("could not parse gateway from route output") - } - - return string(matches[1]), nil -} - // GetDefaultInterface returns the name of the default network interface func GetDefaultInterface() (string, error) { ifaceName, _, err := GetDefaultInterfaceAndGateway() diff --git a/internal/netutil/bind.go b/internal/netutil/bind.go new file mode 100644 index 00000000..c406894a --- /dev/null +++ b/internal/netutil/bind.go @@ -0,0 +1,12 @@ +//go:build !linux && !darwin && !freebsd + +package netutil + +import ( + "net" +) + +// bindToInterface is a no-op on unsupported platforms. +func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) error { + return nil +} diff --git a/internal/netutil/dial_darwin.go b/internal/netutil/bind_bsd.go similarity index 65% rename from internal/netutil/dial_darwin.go rename to internal/netutil/bind_bsd.go index 35d26553..287fafb5 100644 --- a/internal/netutil/dial_darwin.go +++ b/internal/netutil/bind_bsd.go @@ -1,8 +1,9 @@ -//go:build darwin +//go:build darwin || freebsd package netutil import ( + "fmt" "net" "syscall" @@ -10,9 +11,17 @@ import ( ) // bindToInterface sets the dialer's Control function to bind the socket -// to a specific network interface using IP_BOUND_IF on Darwin. -func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) { - addrs, _ := iface.Addrs() +// to a specific network interface using IP_BOUND_IF on BSD systems. +func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) error { + if iface == nil { + return nil + } + + addrs, err := iface.Addrs() + if err != nil { + return fmt.Errorf("failed to get interface addresses: %w", err) + } + for _, addr := range addrs { if ipnet, ok := addr.(*net.IPNet); ok { if targetIP.To4() != nil && ipnet.IP.To4() != nil && !ipnet.IP.IsLoopback() { @@ -33,8 +42,10 @@ func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) return setsockoptErr } - break + return nil } } } + + return fmt.Errorf("no suitable IP address found on interface %s for target %s", iface.Name, targetIP) } diff --git a/internal/netutil/bind_linux.go b/internal/netutil/bind_linux.go new file mode 100644 index 00000000..8467bd3e --- /dev/null +++ b/internal/netutil/bind_linux.go @@ -0,0 +1,38 @@ +//go:build linux + +package netutil + +import ( + "fmt" + "net" +) + +// bindToInterface sets the dialer's LocalAddr to use the interface's IP as the source address. +// On Linux, we only set LocalAddr because SO_BINDTODEVICE can cause issues with +// socket lookup for incoming packets. +func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) error { + if iface == nil { + return nil + } + + // Find the interface's IP address to use as source + addrs, err := iface.Addrs() + if err != nil { + return fmt.Errorf("failed to get interface addresses: %w", err) + } + + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok { + // Match IP version: use IPv4 source for IPv4 target, IPv6 for IPv6 + if targetIP.To4() != nil && ipnet.IP.To4() != nil && !ipnet.IP.IsLoopback() { + dialer.LocalAddr = &net.TCPAddr{IP: ipnet.IP} + return nil + } else if targetIP.To4() == nil && ipnet.IP.To4() == nil && !ipnet.IP.IsLoopback() { + dialer.LocalAddr = &net.TCPAddr{IP: ipnet.IP} + return nil + } + } + } + + return fmt.Errorf("no suitable IP address found on interface %s for target %s", iface.Name, targetIP) +} diff --git a/internal/netutil/conn.go b/internal/netutil/conn.go index 62754f91..51861f71 100644 --- a/internal/netutil/conn.go +++ b/internal/netutil/conn.go @@ -157,6 +157,7 @@ func CloseConns(closers ...io.Closer) { } // SetTTL configures the TTL or Hop Limit depending on the IP version. +// The isIPv4 parameter is determined by examining the remote address of the connection. func SetTTL(conn net.Conn, isIPv4 bool, ttl uint8) error { tcpConn, ok := conn.(*net.TCPConn) if !ok { @@ -168,8 +169,19 @@ func SetTTL(conn net.Conn, isIPv4 bool, ttl uint8) error { return err } + // Re-check IP version using remote address to handle IPv4-mapped IPv6 addresses + // On Linux, when using dual-stack sockets, the local address might be IPv6 + // but the actual connection could be IPv4-mapped (::ffff:x.x.x.x) + actualIPv4 := isIPv4 + if tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr); ok { + // If the IP is IPv4 or IPv4-mapped IPv6, we should use IPv4 options + if ip4 := tcpAddr.IP.To4(); ip4 != nil { + actualIPv4 = true + } + } + var level, opt int - if isIPv4 { + if actualIPv4 { level = syscall.IPPROTO_IP opt = syscall.IP_TTL } else { diff --git a/internal/netutil/dial.go b/internal/netutil/dial.go index dd6293bb..6467913f 100644 --- a/internal/netutil/dial.go +++ b/internal/netutil/dial.go @@ -53,7 +53,13 @@ func DialFastest( // If Iface is specified, bind to the interface if dst.Iface != nil { - bindToInterface(dialer, dst.Iface, ip) + if err := bindToInterface(dialer, dst.Iface, ip); err != nil { + select { + case results <- dialResult{conn: nil, err: err}: + case <-ctx.Done(): + } + return + } } conn, err := dialer.DialContext(ctx, network, targetAddr) diff --git a/internal/netutil/dial_other.go b/internal/netutil/dial_other.go deleted file mode 100644 index 0d0d0fa9..00000000 --- a/internal/netutil/dial_other.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build !darwin - -package netutil - -import "net" - -// bindToInterface is a no-op on non-Darwin systems. -// Interface binding is handled differently or not supported on other platforms. -func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) { - // No-op: interface binding via IP_BOUND_IF is Darwin-specific -} diff --git a/internal/netutil/gateway.go b/internal/netutil/gateway.go new file mode 100644 index 00000000..2acfae57 --- /dev/null +++ b/internal/netutil/gateway.go @@ -0,0 +1,8 @@ +//go:build !darwin && !freebsd && !linux + +package netutil + +// getDefaultGateway parses the system route table to find the default gateway on Linux +func getDefaultGateway() (string, error) { + return "", nil +} diff --git a/internal/netutil/gateway_bsd.go b/internal/netutil/gateway_bsd.go new file mode 100644 index 00000000..65ced920 --- /dev/null +++ b/internal/netutil/gateway_bsd.go @@ -0,0 +1,28 @@ +//go:build darwin || freebsd + +package netutil + +import ( + "fmt" + "os/exec" + "regexp" +) + +// getDefaultGateway parses the system route table to find the default gateway on macOS/BSD +func getDefaultGateway() (string, error) { + // Use route to get the default route on macOS + cmd := exec.Command("route", "-n", "get", "default") + out, err := cmd.Output() + if err != nil { + return "", err + } + + // Parse output to find gateway line + re := regexp.MustCompile(`gateway:\s+(\d+\.\d+\.\d+\.\d+)`) + matches := re.FindSubmatch(out) + if len(matches) < 2 { + return "", fmt.Errorf("could not parse gateway from route output") + } + + return string(matches[1]), nil +} diff --git a/internal/netutil/gateway_linux.go b/internal/netutil/gateway_linux.go new file mode 100644 index 00000000..10b9e738 --- /dev/null +++ b/internal/netutil/gateway_linux.go @@ -0,0 +1,29 @@ +//go:build linux + +package netutil + +import ( + "fmt" + "os/exec" + "strings" +) + +// getDefaultGateway parses the system route table to find the default gateway on Linux +func getDefaultGateway() (string, error) { + // Use ip route to get the default route on Linux + cmd := exec.Command("ip", "route", "show", "default") + out, err := cmd.Output() + if err != nil { + return "", err + } + + // Parse output: "default via 192.168.0.1 dev enp12s0 ..." + fields := strings.Fields(string(out)) + for i, field := range fields { + if field == "via" && i+1 < len(fields) { + return fields[i+1], nil + } + } + + return "", fmt.Errorf("could not parse gateway from ip route output: %s", string(out)) +} diff --git a/internal/packet/tcp_sniffer.go b/internal/packet/tcp_sniffer.go index 3720002d..52f5b474 100644 --- a/internal/packet/tcp_sniffer.go +++ b/internal/packet/tcp_sniffer.go @@ -2,7 +2,6 @@ package packet import ( "context" - "fmt" "net" "github.com/google/gopacket" @@ -105,9 +104,9 @@ func (ts *TCPSniffer) processPacket(ctx context.Context, p gopacket.Packet) { ip, _ := ipLayer.(*layers.IPv4) // Skip packets from local/private IPs (outgoing packets) if isLocalIP(ip.SrcIP) { - fmt.Println(ip.SrcIP) return } + srcIP = ip.SrcIP.String() ttlLeft = ip.TTL } else if ipLayer := p.Layer(layers.LayerTypeIPv6); ipLayer != nil { @@ -124,13 +123,17 @@ func (ts *TCPSniffer) processPacket(ctx context.Context, p gopacket.Packet) { key := srcIP // Calculate hop count from the TTL nhops := estimateHops(ttlLeft) - ok := ts.nhopCache.Set(key, nhops, cache.Options().WithUpdateExistingOnly(true)) - if ok { - logger.Trace(). - Str("from", key). - Uint8("nhops", nhops). - Uint8("ttlLeft", ttlLeft). - Msgf("ttl(tcp) update") + + stored, exists := ts.nhopCache.Get(key) + + if ts.nhopCache.Set(key, nhops, cache.Options().WithUpdateExistingOnly(true)) { + if !exists || stored != nhops { + logger.Trace(). + Str("from", key). + Uint8("nhops", nhops). + Uint8("ttlLeft", ttlLeft). + Msgf("ttl(tcp) update") + } } } @@ -172,10 +175,16 @@ func generateSynAckFilter(linkType layers.LinkType) []BPFInstruction { } else { // Check IP Version == 4 at the base offset // Load byte at baseOffset, mask 0xF0, check if 0x40 - instructions = append(instructions, + instructions = append( + instructions, // BPFInstruction{Op: 0x30, Jt: 0, Jf: 0, K: baseOffset}, // Ldb [baseOffset] BPFInstruction{Op: 0x54, Jt: 0, Jf: 0, K: 0xf0}, // And 0xf0 - BPFInstruction{Op: 0x15, Jt: 0, Jf: 8, K: 0x40}, // Jeq 0x40, True, False(Skip to End) + BPFInstruction{ + Op: 0x15, + Jt: 0, + Jf: 8, + K: 0x40, + }, // Jeq 0x40, True, False(Skip to End) ) } diff --git a/internal/server/http/https.go b/internal/server/http/https.go index 8b47a54b..7bc35a6f 100644 --- a/internal/server/http/https.go +++ b/internal/server/http/https.go @@ -20,23 +20,26 @@ import ( ) type HTTPSHandler struct { - logger zerolog.Logger - desyncer *desync.TLSDesyncer - sniffer packet.Sniffer - httpsOpts *config.HTTPSOptions + logger zerolog.Logger + desyncer *desync.TLSDesyncer + sniffer packet.Sniffer + defaultHTTPSOpts *config.HTTPSOptions + defaultConnOpts *config.ConnOptions } func NewHTTPSHandler( logger zerolog.Logger, desyncer *desync.TLSDesyncer, sniffer packet.Sniffer, - httpsOpts *config.HTTPSOptions, + defaultHTTPSOpts *config.HTTPSOptions, + defaultConnOpts *config.ConnOptions, ) *HTTPSHandler { return &HTTPSHandler{ - logger: logger, - desyncer: desyncer, - sniffer: sniffer, - httpsOpts: httpsOpts, + logger: logger, + desyncer: desyncer, + sniffer: sniffer, + defaultHTTPSOpts: defaultHTTPSOpts, + defaultConnOpts: defaultConnOpts, } } @@ -46,9 +49,11 @@ func (h *HTTPSHandler) HandleRequest( dst *netutil.Destination, rule *config.Rule, ) error { - opts := h.httpsOpts.Clone() + httpsOpts := h.defaultHTTPSOpts.Clone() + connOpts := h.defaultConnOpts.Clone() if rule != nil { - opts = opts.Merge(rule.HTTPS) + httpsOpts = httpsOpts.Merge(rule.HTTPS) + connOpts = connOpts.Merge(rule.Conn) } logger := logging.WithLocalScope(ctx, h.logger, "handshake") @@ -64,21 +69,23 @@ func (h *HTTPSHandler) HandleRequest( logger.Trace().Msgf("sent 200 connection established -> %s", lConn.RemoteAddr()) // 2. Tunnel - return h.tunnel(ctx, lConn, dst, opts) + return h.tunnel(ctx, lConn, dst, httpsOpts, connOpts) } func (h *HTTPSHandler) tunnel( ctx context.Context, lConn net.Conn, dst *netutil.Destination, - opts *config.HTTPSOptions, + httpsOpts *config.HTTPSOptions, + connOpts *config.ConnOptions, ) error { - if h.sniffer != nil && lo.FromPtr(opts.FakeCount) > 0 { + if h.sniffer != nil && lo.FromPtr(httpsOpts.FakeCount) > 0 { h.sniffer.RegisterUntracked(dst.Addrs) } logger := logging.WithLocalScope(ctx, h.logger, "https") + dst.Timeout = *connOpts.TCPTimeout rConn, err := netutil.DialFastest(ctx, "tcp", dst) if err != nil { return err @@ -107,7 +114,7 @@ func (h *HTTPSHandler) tunnel( } // Send ClientHello to the remote server (with desync if configured) - n, err := h.sendClientHello(ctx, rConn, tlsMsg, opts) + n, err := h.sendClientHello(ctx, rConn, tlsMsg, httpsOpts) if err != nil { return fmt.Errorf("failed to send client hello: %w", err) } diff --git a/internal/server/http/network.go b/internal/server/http/network.go index d2d78955..13e47a26 100644 --- a/internal/server/http/network.go +++ b/internal/server/http/network.go @@ -1,8 +1,10 @@ -//go:build !darwin && !linux +//go:build !darwin package http -import "github.com/rs/zerolog" +import ( + "github.com/rs/zerolog" +) func SetSystemProxy(logger zerolog.Logger, port uint16) error { return nil diff --git a/internal/server/http/network_linux.go b/internal/server/http/network_linux.go deleted file mode 100644 index 918c0c91..00000000 --- a/internal/server/http/network_linux.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build linux - -package http - -import "github.com/rs/zerolog" - -func SetSystemProxy(logger zerolog.Logger, port uint16) error { - // Not implemented for Linux yet - return nil -} - -func UnsetSystemProxy(logger zerolog.Logger) error { - return nil -} diff --git a/internal/server/http/proxy.go b/internal/server/http/server.go similarity index 91% rename from internal/server/http/proxy.go rename to internal/server/http/server.go index b4e71b43..12f8837e 100644 --- a/internal/server/http/proxy.go +++ b/internal/server/http/server.go @@ -26,7 +26,8 @@ type HTTPProxy struct { httpHandler *HTTPHandler httpsHandler *HTTPSHandler ruleMatcher matcher.RuleMatcher - serverOpts *config.ServerOptions + appOpts *config.AppOptions + connOpts *config.ConnOptions policyOpts *config.PolicyOptions listener net.Listener @@ -38,7 +39,8 @@ func NewHTTPProxy( httpHandler *HTTPHandler, httpsHandler *HTTPSHandler, ruleMatcher matcher.RuleMatcher, - serverOpts *config.ServerOptions, + appOpts *config.AppOptions, + connOpts *config.ConnOptions, policyOpts *config.PolicyOptions, ) server.Server { return &HTTPProxy{ @@ -47,17 +49,18 @@ func NewHTTPProxy( httpHandler: httpHandler, httpsHandler: httpsHandler, ruleMatcher: ruleMatcher, - serverOpts: serverOpts, + appOpts: appOpts, + connOpts: connOpts, policyOpts: policyOpts, } } func (p *HTTPProxy) Start(ctx context.Context, ready chan<- struct{}) error { - listener, err := net.ListenTCP("tcp", p.serverOpts.ListenAddr) + listener, err := net.ListenTCP("tcp", p.appOpts.ListenAddr) if err != nil { return fmt.Errorf( "error creating listener on %s: %w", - p.serverOpts.ListenAddr.String(), + p.appOpts.ListenAddr.String(), err, ) } @@ -92,7 +95,7 @@ func (p *HTTPProxy) Stop() error { } func (p *HTTPProxy) SetNetworkConfig() error { - return SetSystemProxy(p.logger, uint16(p.serverOpts.ListenAddr.Port)) + return SetSystemProxy(p.logger, uint16(p.appOpts.ListenAddr.Port)) } func (p *HTTPProxy) UnsetNetworkConfig() error { @@ -100,7 +103,7 @@ func (p *HTTPProxy) UnsetNetworkConfig() error { } func (p *HTTPProxy) Addr() string { - return p.serverOpts.ListenAddr.String() + return p.appOpts.ListenAddr.String() } func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { @@ -166,7 +169,7 @@ func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { } // Avoid recursively querying self. - ok, err := netutil.ValidateDestination(addrs, dstPort, p.serverOpts.ListenAddr) + ok, err := netutil.ValidateDestination(addrs, dstPort, p.appOpts.ListenAddr) if err != nil { logger.Debug().Err(err).Msg("error validating dst addrs") if !ok { @@ -199,7 +202,7 @@ func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { Domain: host, // Updated from Domain to Host Addrs: addrs, Port: dstPort, - Timeout: *p.serverOpts.Timeout, + Timeout: *p.connOpts.TCPTimeout, } var handleErr error diff --git a/internal/server/socks5/connect.go b/internal/server/socks5/connect.go index ac3b09c4..fbe0f61d 100644 --- a/internal/server/socks5/connect.go +++ b/internal/server/socks5/connect.go @@ -18,26 +18,29 @@ import ( ) type ConnectHandler struct { - logger zerolog.Logger - desyncer *desync.TLSDesyncer - sniffer packet.Sniffer - serverOpts *config.ServerOptions - httpsOpts *config.HTTPSOptions + logger zerolog.Logger + desyncer *desync.TLSDesyncer + sniffer packet.Sniffer + appOpts *config.AppOptions + defaultConnOpts *config.ConnOptions + defaultHTTPSOpts *config.HTTPSOptions } func NewConnectHandler( logger zerolog.Logger, desyncer *desync.TLSDesyncer, sniffer packet.Sniffer, - serverOpts *config.ServerOptions, - httpsOpts *config.HTTPSOptions, + appOpts *config.AppOptions, + defaultConnOpts *config.ConnOptions, + defaultHTTPSOpts *config.HTTPSOptions, ) *ConnectHandler { return &ConnectHandler{ - logger: logger, - desyncer: desyncer, - sniffer: sniffer, - serverOpts: serverOpts, - httpsOpts: httpsOpts, + logger: logger, + desyncer: desyncer, + sniffer: sniffer, + appOpts: appOpts, + defaultConnOpts: defaultConnOpts, + defaultHTTPSOpts: defaultHTTPSOpts, } } @@ -48,15 +51,17 @@ func (h *ConnectHandler) Handle( dst *netutil.Destination, rule *config.Rule, ) error { - opts := h.httpsOpts.Clone() + httpsOpts := h.defaultHTTPSOpts.Clone() + connOpts := h.defaultConnOpts.Clone() if rule != nil { - opts = opts.Merge(rule.HTTPS) + httpsOpts = httpsOpts.Merge(rule.HTTPS) + connOpts = connOpts.Merge(rule.Conn) } logger := logging.WithLocalScope(ctx, h.logger, "connect") // 1. Validate Destination - ok, err := netutil.ValidateDestination(dst.Addrs, dst.Port, h.serverOpts.ListenAddr) + ok, err := netutil.ValidateDestination(dst.Addrs, dst.Port, h.appOpts.ListenAddr) if err != nil { logger.Debug().Err(err).Msg("error determining if valid destination") if !ok { @@ -79,6 +84,7 @@ func (h *ConnectHandler) Handle( } // logger := logging.WithLocalScope(ctx, h.logger, "connect(tcp)") + dst.Timeout = *connOpts.TCPTimeout rConn, err := netutil.DialFastest(ctx, "tcp", dst) if err != nil { @@ -96,11 +102,11 @@ func (h *ConnectHandler) Handle( b, err := bufConn.Peek(1) if err == nil && b[0] == byte(proto.TLSHandshake) { // 0x16 - if h.sniffer != nil && lo.FromPtr(opts.FakeCount) > 0 { + if h.sniffer != nil && lo.FromPtr(httpsOpts.FakeCount) > 0 { h.sniffer.RegisterUntracked(dst.Addrs) } - return h.handleHTTPS(ctx, bufConn, rConn, opts) + return h.handleHTTPS(ctx, bufConn, rConn, httpsOpts) } // If not TLS, fall back to pure TCP tunnel diff --git a/internal/server/socks5/network.go b/internal/server/socks5/network.go index 2bba7628..c7100f56 100644 --- a/internal/server/socks5/network.go +++ b/internal/server/socks5/network.go @@ -1,8 +1,10 @@ -//go:build !darwin && !linux +//go:build !darwin package socks5 -import "github.com/rs/zerolog" +import ( + "github.com/rs/zerolog" +) func SetSystemProxy(logger zerolog.Logger, port uint16) error { return nil diff --git a/internal/server/socks5/network_linux.go b/internal/server/socks5/network_linux.go deleted file mode 100644 index 1350894c..00000000 --- a/internal/server/socks5/network_linux.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build linux - -package socks5 - -import "github.com/rs/zerolog" - -func SetSystemProxy(logger zerolog.Logger, port uint16) error { - // Not implemented for Linux yet - return nil -} - -func UnsetSystemProxy(logger zerolog.Logger) error { - return nil -} diff --git a/internal/server/socks5/server.go b/internal/server/socks5/server.go index de4c6e2e..7bc1748f 100644 --- a/internal/server/socks5/server.go +++ b/internal/server/socks5/server.go @@ -31,7 +31,8 @@ type SOCKS5Proxy struct { bindHandler *BindHandler udpAssociateHandler *UdpAssociateHandler - serverOpts *config.ServerOptions + appOpts *config.AppOptions + connOpts *config.ConnOptions policyOpts *config.PolicyOptions listener net.Listener @@ -44,7 +45,8 @@ func NewSOCKS5Proxy( connectHandler *ConnectHandler, bindHandler *BindHandler, udpAssociateHandler *UdpAssociateHandler, - serverOpts *config.ServerOptions, + appOpts *config.AppOptions, + connOpts *config.ConnOptions, policyOpts *config.PolicyOptions, ) server.Server { return &SOCKS5Proxy{ @@ -54,17 +56,18 @@ func NewSOCKS5Proxy( connectHandler: connectHandler, bindHandler: bindHandler, udpAssociateHandler: udpAssociateHandler, - serverOpts: serverOpts, + appOpts: appOpts, + connOpts: connOpts, policyOpts: policyOpts, } } func (p *SOCKS5Proxy) Start(ctx context.Context, ready chan<- struct{}) error { - listener, err := net.ListenTCP("tcp", p.serverOpts.ListenAddr) + listener, err := net.ListenTCP("tcp", p.appOpts.ListenAddr) if err != nil { return fmt.Errorf( "error creating listener on %s: %w", - p.serverOpts.ListenAddr.String(), + p.appOpts.ListenAddr.String(), err, ) } @@ -98,7 +101,7 @@ func (p *SOCKS5Proxy) Stop() error { } func (p *SOCKS5Proxy) SetNetworkConfig() error { - return SetSystemProxy(p.logger, uint16(p.serverOpts.ListenAddr.Port)) + return SetSystemProxy(p.logger, uint16(p.appOpts.ListenAddr.Port)) } func (p *SOCKS5Proxy) UnsetNetworkConfig() error { @@ -106,7 +109,7 @@ func (p *SOCKS5Proxy) UnsetNetworkConfig() error { } func (p *SOCKS5Proxy) Addr() string { - return p.serverOpts.ListenAddr.String() + return p.appOpts.ListenAddr.String() } func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { @@ -188,7 +191,7 @@ func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { Domain: req.FQDN, Addrs: addrs, Port: req.Port, - Timeout: *p.serverOpts.Timeout, + Timeout: *p.connOpts.TCPTimeout, } if err = p.connectHandler.Handle(ctx, conn, req, dst, bestMatch); err != nil { return // Handler logs error diff --git a/internal/server/tun/network.go b/internal/server/tun/network.go index da6e6488..1cbfb5c4 100644 --- a/internal/server/tun/network.go +++ b/internal/server/tun/network.go @@ -1,15 +1,23 @@ -//go:build !darwin +//go:build !darwin && !linux && !freebsd package tun -func SetRouting(iface string, subnets []string) error { +func SetRoute(iface string, subnets []string) error { return nil } -func UnsetRouting(iface string, subnets []string) error { +func UnsetRoute(iface string, subnets []string) error { return nil } func SetInterfaceAddress(iface string, local string, remote string) error { return nil } + +func UnsetGatewayRoute(gateway, iface string) error { + return nil +} + +func SetGatewayRoute(gateway, iface string) error { + return nil +} diff --git a/internal/server/tun/network_darwin.go b/internal/server/tun/network_bsd.go similarity index 99% rename from internal/server/tun/network_darwin.go rename to internal/server/tun/network_bsd.go index 41e2a35a..304e646b 100644 --- a/internal/server/tun/network_darwin.go +++ b/internal/server/tun/network_bsd.go @@ -1,3 +1,5 @@ +//go:build darwin && freebsd + package tun import ( diff --git a/internal/server/tun/network_linux.go b/internal/server/tun/network_linux.go new file mode 100644 index 00000000..50ce5a44 --- /dev/null +++ b/internal/server/tun/network_linux.go @@ -0,0 +1,260 @@ +//go:build linux + +package tun + +import ( + "fmt" + "os/exec" + "regexp" + "strconv" + "strings" + "sync" +) + +var ( + allocatedTableID int + allocatedTableOnce sync.Once +) + +// getOrAllocateTableID returns a routing table ID, allocating one if needed. +// The ID is cached after first allocation. +func getOrAllocateTableID() (int, error) { + var initErr error + allocatedTableOnce.Do(func() { + allocatedTableID, initErr = findAvailableTableID() + }) + if allocatedTableID == 0 { + return 0, initErr + } + return allocatedTableID, nil +} + +// findAvailableTableID finds an unused routing table ID in the range 100-252. +func findAvailableTableID() (int, error) { + usedTables := make(map[int]bool) + + // Parse "ip rule show" output to find used table IDs + cmd := exec.Command("ip", "rule", "show") + if out, err := cmd.Output(); err == nil { + re := regexp.MustCompile(`lookup\s+(\d+)`) + matches := re.FindAllStringSubmatch(string(out), -1) + for _, match := range matches { + if len(match) >= 2 { + if id, err := strconv.Atoi(match[1]); err == nil { + usedTables[id] = true + } + } + } + } + + // Also check /etc/iproute2/rt_tables for reserved tables + rtTablesCmd := exec.Command("cat", "/etc/iproute2/rt_tables") + if rtOut, err := rtTablesCmd.Output(); err == nil { + lines := strings.Split(string(rtOut), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + fields := strings.Fields(line) + if len(fields) >= 1 { + if id, err := strconv.Atoi(fields[0]); err == nil { + usedTables[id] = true + } + } + } + } + + // Find first available ID in range 100-252 (253-255 are reserved) + for id := 100; id <= 252; id++ { + if !usedTables[id] { + return id, nil + } + } + + return 0, fmt.Errorf("no available routing table ID in range 100-252") +} + +// SetRoute configures network routes for specified subnets using ip route +func SetRoute(iface string, subnets []string) error { + for _, subnet := range subnets { + targets := []string{subnet} + + /* Expand default route into two /1 subnets to override the default gateway + without removing the existing 0.0.0.0/0 entry. + */ + if subnet == "0.0.0.0/0" { + targets = []string{"0.0.0.0/1", "128.0.0.0/1"} + } + + for _, target := range targets { + cmd := exec.Command("ip", "route", "add", target, "dev", iface) + out, err := cmd.CombinedOutput() + if err != nil { + // Check if it's a permission error + if strings.Contains(string(out), "Operation not permitted") || + strings.Contains(string(out), "RTNETLINK answers: Operation not permitted") { + return fmt.Errorf( + "permission denied: must run as root to modify routing table (sudo required)", + ) + } + // Ignore "File exists" error - route already exists + if strings.Contains(string(out), "File exists") { + continue + } + return fmt.Errorf( + "failed to add route for %s on %s: %s: %w", + target, + iface, + string(out), + err, + ) + } + } + } + return nil +} + +// UnsetRoute removes previously configured network routes using ip route +func UnsetRoute(iface string, subnets []string) error { + for _, subnet := range subnets { + targets := []string{subnet} + + if subnet == "0.0.0.0/0" { + targets = []string{"0.0.0.0/1", "128.0.0.0/1"} + } + + for _, target := range targets { + /* Delete specific routes to revert traffic flow to the original gateway. */ + cmd := exec.Command("ip", "route", "del", target, "dev", iface) + if out, err := cmd.CombinedOutput(); err != nil { + _ = out + } + } + } + return nil +} + +// SetInterfaceAddress configures the TUN interface with local and remote endpoints using ip addr +func SetInterfaceAddress(iface string, local string, remote string) error { + // Add the IP address to the interface + cmd := exec.Command("ip", "addr", "add", local, "peer", remote, "dev", iface) + out, err := cmd.CombinedOutput() + if err != nil { + // Ignore if address already exists + if !strings.Contains(string(out), "File exists") { + return fmt.Errorf("failed to set interface address: %s: %w", string(out), err) + } + } + + // Bring the interface up + cmd = exec.Command("ip", "link", "set", "dev", iface, "up") + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to bring interface up: %s: %w", string(out), err) + } + + return nil +} + +// UnsetGatewayRoute removes the gateway route and policy routing rules +func UnsetGatewayRoute(gateway, iface string) error { + tableID, err := getOrAllocateTableID() + if err != nil { + return err + } + tableIDStr := strconv.Itoa(tableID) + + // Get the interface's IP address for policy routing cleanup + ifaceIP := getInterfaceIP(iface) + if ifaceIP != "" { + // Remove the policy rule + cmd := exec.Command("ip", "rule", "del", "from", ifaceIP, "lookup", tableIDStr) + if out, err := cmd.CombinedOutput(); err != nil { + _ = out + } + } + + // Remove the route from the allocated table + cmd := exec.Command("ip", "route", "del", "default", "table", tableIDStr) + if out, err := cmd.CombinedOutput(); err != nil { + _ = out + } + + // Remove the direct route to the gateway + cmd = exec.Command("ip", "route", "del", gateway, "dev", iface) + if out, err := cmd.CombinedOutput(); err != nil { + _ = out + } + + return nil +} + +// SetGatewayRoute configures policy routing so that packets from the physical interface +// are routed via the gateway, while other packets go through TUN +func SetGatewayRoute(gateway, iface string) error { + tableID, err := getOrAllocateTableID() + if err != nil { + return fmt.Errorf("failed to allocate routing table ID: %w", err) + } + tableIDStr := strconv.Itoa(tableID) + + // First, add a direct route to the gateway via the interface + cmd := exec.Command("ip", "route", "add", gateway, "dev", iface) + out, err := cmd.CombinedOutput() + if err != nil { + // Ignore "File exists" error - route already exists + if !strings.Contains(string(out), "File exists") { + return fmt.Errorf("failed to add gateway route: %s: %w", string(out), err) + } + } + + // Add default route to the allocated table via the gateway + cmd = exec.Command("ip", "route", "add", "default", "via", gateway, "dev", iface, "table", tableIDStr) + out, err = cmd.CombinedOutput() + if err != nil { + if !strings.Contains(string(out), "File exists") { + // Optional, don't fail + _ = out + } + } + + // Get the interface's IP address for policy routing + ifaceIP := getInterfaceIP(iface) + if ifaceIP == "" { + return fmt.Errorf("failed to get IP address for interface %s", iface) + } + + // Add policy rule: packets from this IP use the allocated table + cmd = exec.Command("ip", "rule", "add", "from", ifaceIP, "lookup", tableIDStr) + out, err = cmd.CombinedOutput() + if err != nil { + if !strings.Contains(string(out), "File exists") { + return fmt.Errorf("failed to add policy rule: %s: %w", string(out), err) + } + } + + return nil +} + +// getInterfaceIP returns the first IPv4 address of the given interface +func getInterfaceIP(ifaceName string) string { + cmd := exec.Command("ip", "-4", "addr", "show", ifaceName) + out, err := cmd.Output() + if err != nil { + return "" + } + + // Parse output to find inet line + lines := strings.Split(string(out), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "inet ") { + parts := strings.Fields(line) + if len(parts) >= 2 { + ip := strings.Split(parts[1], "/")[0] + return ip + } + } + } + return "" +} diff --git a/internal/server/tun/server.go b/internal/server/tun/server.go index 8dceefdc..31a5b198 100644 --- a/internal/server/tun/server.go +++ b/internal/server/tun/server.go @@ -115,7 +115,10 @@ func (s *TunServer) SetNetworkConfig() error { // Add route for the TUN interface subnet to ensure packets can return // This is crucial for the TUN interface to receive packets destined for its own subnet - if err := SetRoute(s.iface.Name(), []string{local + "/30"}); err != nil { + // Calculate the network address for /30 subnet (e.g., 10.0.0.1 -> 10.0.0.0/30) + localIP := net.ParseIP(local) + networkAddr := net.IPv4(localIP[12], localIP[13], localIP[14], localIP[15]&0xFC) // Mask with /30 + if err := SetRoute(s.iface.Name(), []string{networkAddr.String() + "/30"}); err != nil { return fmt.Errorf("failed to set local route: %w", err) } diff --git a/internal/server/tun/tcp.go b/internal/server/tun/tcp.go index 00f8cb06..fc2ec4d3 100644 --- a/internal/server/tun/tcp.go +++ b/internal/server/tun/tcp.go @@ -18,32 +18,35 @@ import ( ) type TCPHandler struct { - logger zerolog.Logger - domainMatcher matcher.RuleMatcher // For TLS domain matching only - httpsOpts *config.HTTPSOptions - desyncer *desync.TLSDesyncer - sniffer packet.Sniffer // For TTL tracking - iface string - gateway string + logger zerolog.Logger + domainMatcher matcher.RuleMatcher // For TLS domain matching only + defaultHTTPSOpts *config.HTTPSOptions + defaultConnOpts *config.ConnOptions + desyncer *desync.TLSDesyncer + sniffer packet.Sniffer // For TTL tracking + iface string + gateway string } func NewTCPHandler( logger zerolog.Logger, domainMatcher matcher.RuleMatcher, - httpsOpts *config.HTTPSOptions, + defaultHTTPSOpts *config.HTTPSOptions, + defaultConnOpts *config.ConnOptions, desyncer *desync.TLSDesyncer, sniffer packet.Sniffer, iface string, gateway string, ) *TCPHandler { return &TCPHandler{ - logger: logger, - domainMatcher: domainMatcher, - httpsOpts: httpsOpts, - desyncer: desyncer, - sniffer: sniffer, - iface: iface, - gateway: gateway, + logger: logger, + domainMatcher: domainMatcher, + defaultHTTPSOpts: defaultHTTPSOpts, + defaultConnOpts: defaultConnOpts, + desyncer: desyncer, + sniffer: sniffer, + iface: iface, + gateway: gateway, } } @@ -87,6 +90,9 @@ func (h *TCPHandler) Handle(ctx context.Context, lConn net.Conn, rule *config.Ru Iface: iface, Gateway: h.gateway, } + if h.defaultConnOpts != nil && h.defaultConnOpts.TCPTimeout != nil { + dst.Timeout = *h.defaultConnOpts.TCPTimeout + } if ip != nil { dst.Addrs = append(dst.Addrs, ip) } @@ -103,6 +109,7 @@ func (h *TCPHandler) Handle(ctx context.Context, lConn net.Conn, rule *config.Ru // Handle as plain TCP rConn, err := netutil.DialFastest(ctx, "tcp", dst) if err != nil { + logger.Error().Msgf("failed to dial %v", err) return } @@ -157,12 +164,14 @@ func (h *TCPHandler) handleTLS( logger.Trace().Str("value", dst.Domain).Msgf("extracted sni feild") // Match Rules - opts := h.httpsOpts.Clone() + httpsOpts := h.defaultHTTPSOpts.Clone() + connOpts := h.defaultConnOpts.Clone() // First, apply IP-based rule if matched in server.go if addrRule != nil { logger.Trace().RawJSON("summary", addrRule.JSON()).Msg("addr match") - opts = opts.Merge(addrRule.HTTPS) + httpsOpts = httpsOpts.Merge(addrRule.HTTPS) + connOpts = connOpts.Merge(addrRule.Conn) } // Then, try domain-based matching (TLS-specific) @@ -176,11 +185,14 @@ func (h *TCPHandler) handleTLS( // Domain rule takes priority if it has higher priority finalRule := matcher.GetHigherPriorityRule(addrRule, domainRule) if finalRule == domainRule { - opts = h.httpsOpts.Clone().Merge(domainRule.HTTPS) + httpsOpts = h.defaultHTTPSOpts.Clone().Merge(domainRule.HTTPS) + connOpts = h.defaultConnOpts.Clone().Merge(domainRule.Conn) } } } + dst.Timeout = *connOpts.TCPTimeout + // Dial Remote if h.sniffer != nil { h.sniffer.RegisterUntracked(dst.Addrs) @@ -195,7 +207,7 @@ func (h *TCPHandler) handleTLS( Msgf("new remote conn (%s -> %s)", lConn.RemoteAddr(), rConn.RemoteAddr()) // Send ClientHello with Desync - if _, err := h.desyncer.Desync(ctx, logger, rConn, tlsMsg, opts); err != nil { + if _, err := h.desyncer.Desync(ctx, logger, rConn, tlsMsg, httpsOpts); err != nil { return err } diff --git a/internal/server/tun/udp.go b/internal/server/tun/udp.go index 0a54def3..e417b197 100644 --- a/internal/server/tun/udp.go +++ b/internal/server/tun/udp.go @@ -14,25 +14,28 @@ import ( ) type UDPHandler struct { - logger zerolog.Logger - udpOpts *config.UDPOptions - desyncer *desync.UDPDesyncer - pool *netutil.ConnPool - iface string - gateway string + logger zerolog.Logger + defaultUDPOpts *config.UDPOptions + defaultConnOpts *config.ConnOptions + desyncer *desync.UDPDesyncer + pool *netutil.ConnPool + iface string + gateway string } func NewUDPHandler( logger zerolog.Logger, desyncer *desync.UDPDesyncer, - udpOpts *config.UDPOptions, + defaultUDPOpts *config.UDPOptions, + defaultConnOpts *config.ConnOptions, pool *netutil.ConnPool, ) *UDPHandler { return &UDPHandler{ - logger: logger, - desyncer: desyncer, - udpOpts: udpOpts, - pool: pool, + logger: logger, + desyncer: desyncer, + defaultUDPOpts: defaultUDPOpts, + defaultConnOpts: defaultConnOpts, + pool: pool, } } @@ -68,10 +71,12 @@ func (h *UDPHandler) Handle(ctx context.Context, lConn net.Conn, rule *config.Ru } // Apply rule if matched in server.go - opts := h.udpOpts.Clone() + udpOpts := h.defaultUDPOpts.Clone() + connOpts := h.defaultConnOpts.Clone() if rule != nil { logger.Trace().RawJSON("summary", rule.JSON()).Msg("match") - opts = opts.Merge(rule.UDP) + udpOpts = udpOpts.Merge(rule.UDP) + connOpts = connOpts.Merge(rule.Conn) } // Key for connection pool @@ -87,11 +92,11 @@ func (h *UDPHandler) Handle(ctx context.Context, lConn net.Conn, rule *config.Ru rConn := h.pool.Add(key, rawConn) // Wrap lConn with TimeoutConn as well - timeout := *opts.Timeout + timeout := *connOpts.UDPTimeout lConnWrapped := &netutil.TimeoutConn{Conn: lConn, Timeout: timeout} // Desync - _, _ = h.desyncer.Desync(ctx, lConnWrapped, rConn, opts) + _, _ = h.desyncer.Desync(ctx, lConnWrapped, rConn, udpOpts) logger.Debug(). Msgf("new remote conn (%s -> %s)", lConn.RemoteAddr(), rConn.RemoteAddr()) From 3fa51d634d6c0f1b6f0e2c180da7e83bc48bafcf Mon Sep 17 00:00:00 2001 From: xvzc Date: Fri, 16 Jan 2026 23:16:32 +0900 Subject: [PATCH 09/39] fix: build condition for tun mode --- internal/server/tun/network_bsd.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/server/tun/network_bsd.go b/internal/server/tun/network_bsd.go index 304e646b..3fe52daa 100644 --- a/internal/server/tun/network_bsd.go +++ b/internal/server/tun/network_bsd.go @@ -1,4 +1,4 @@ -//go:build darwin && freebsd +//go:build darwin || freebsd package tun From bfa53b08233e8928b1d347e72e7c8713d0079bcc Mon Sep 17 00:00:00 2001 From: xvzc Date: Sat, 17 Jan 2026 23:33:22 +0900 Subject: [PATCH 10/39] refactor(netutil): consolidate OS-specific code into netutil_*.go - Merge bind_*.go and gateway_*.go into unified netutil_*.go files - netutil_bsd.go: bindToInterface (IP_BOUND_IF) + getDefaultGateway for macOS/BSD - netutil_linux.go: bindToInterface (LocalAddr) + getDefaultGateway for Linux - netutil.go: no-op fallbacks for unsupported platforms - Simplify bindToInterface on BSD to only use iface.Index - Reduces file count from 6 to 3 --- internal/netutil/bind_bsd.go | 51 ---------------- internal/netutil/gateway.go | 8 --- internal/netutil/gateway_bsd.go | 28 --------- internal/netutil/gateway_linux.go | 29 ---------- internal/netutil/{bind.go => netutil.go} | 6 ++ internal/netutil/netutil_bsd.go | 58 +++++++++++++++++++ .../{bind_linux.go => netutil_linux.go} | 22 +++++++ 7 files changed, 86 insertions(+), 116 deletions(-) delete mode 100644 internal/netutil/bind_bsd.go delete mode 100644 internal/netutil/gateway.go delete mode 100644 internal/netutil/gateway_bsd.go delete mode 100644 internal/netutil/gateway_linux.go rename internal/netutil/{bind.go => netutil.go} (56%) create mode 100644 internal/netutil/netutil_bsd.go rename internal/netutil/{bind_linux.go => netutil_linux.go} (64%) diff --git a/internal/netutil/bind_bsd.go b/internal/netutil/bind_bsd.go deleted file mode 100644 index 287fafb5..00000000 --- a/internal/netutil/bind_bsd.go +++ /dev/null @@ -1,51 +0,0 @@ -//go:build darwin || freebsd - -package netutil - -import ( - "fmt" - "net" - "syscall" - - "golang.org/x/sys/unix" -) - -// bindToInterface sets the dialer's Control function to bind the socket -// to a specific network interface using IP_BOUND_IF on BSD systems. -func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) error { - if iface == nil { - return nil - } - - addrs, err := iface.Addrs() - if err != nil { - return fmt.Errorf("failed to get interface addresses: %w", err) - } - - for _, addr := range addrs { - if ipnet, ok := addr.(*net.IPNet); ok { - if targetIP.To4() != nil && ipnet.IP.To4() != nil && !ipnet.IP.IsLoopback() { - ifaceIndex := iface.Index - dialer.Control = func(network, address string, c syscall.RawConn) error { - var setsockoptErr error - err := c.Control(func(fd uintptr) { - setsockoptErr = unix.SetsockoptInt( - int(fd), - unix.IPPROTO_IP, - unix.IP_BOUND_IF, - ifaceIndex, - ) - }) - if err != nil { - return err - } - - return setsockoptErr - } - return nil - } - } - } - - return fmt.Errorf("no suitable IP address found on interface %s for target %s", iface.Name, targetIP) -} diff --git a/internal/netutil/gateway.go b/internal/netutil/gateway.go deleted file mode 100644 index 2acfae57..00000000 --- a/internal/netutil/gateway.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build !darwin && !freebsd && !linux - -package netutil - -// getDefaultGateway parses the system route table to find the default gateway on Linux -func getDefaultGateway() (string, error) { - return "", nil -} diff --git a/internal/netutil/gateway_bsd.go b/internal/netutil/gateway_bsd.go deleted file mode 100644 index 65ced920..00000000 --- a/internal/netutil/gateway_bsd.go +++ /dev/null @@ -1,28 +0,0 @@ -//go:build darwin || freebsd - -package netutil - -import ( - "fmt" - "os/exec" - "regexp" -) - -// getDefaultGateway parses the system route table to find the default gateway on macOS/BSD -func getDefaultGateway() (string, error) { - // Use route to get the default route on macOS - cmd := exec.Command("route", "-n", "get", "default") - out, err := cmd.Output() - if err != nil { - return "", err - } - - // Parse output to find gateway line - re := regexp.MustCompile(`gateway:\s+(\d+\.\d+\.\d+\.\d+)`) - matches := re.FindSubmatch(out) - if len(matches) < 2 { - return "", fmt.Errorf("could not parse gateway from route output") - } - - return string(matches[1]), nil -} diff --git a/internal/netutil/gateway_linux.go b/internal/netutil/gateway_linux.go deleted file mode 100644 index 10b9e738..00000000 --- a/internal/netutil/gateway_linux.go +++ /dev/null @@ -1,29 +0,0 @@ -//go:build linux - -package netutil - -import ( - "fmt" - "os/exec" - "strings" -) - -// getDefaultGateway parses the system route table to find the default gateway on Linux -func getDefaultGateway() (string, error) { - // Use ip route to get the default route on Linux - cmd := exec.Command("ip", "route", "show", "default") - out, err := cmd.Output() - if err != nil { - return "", err - } - - // Parse output: "default via 192.168.0.1 dev enp12s0 ..." - fields := strings.Fields(string(out)) - for i, field := range fields { - if field == "via" && i+1 < len(fields) { - return fields[i+1], nil - } - } - - return "", fmt.Errorf("could not parse gateway from ip route output: %s", string(out)) -} diff --git a/internal/netutil/bind.go b/internal/netutil/netutil.go similarity index 56% rename from internal/netutil/bind.go rename to internal/netutil/netutil.go index c406894a..f19e56ba 100644 --- a/internal/netutil/bind.go +++ b/internal/netutil/netutil.go @@ -3,6 +3,7 @@ package netutil import ( + "fmt" "net" ) @@ -10,3 +11,8 @@ import ( func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) error { return nil } + +// getDefaultGateway is not supported on this platform. +func getDefaultGateway() (string, error) { + return "", fmt.Errorf("getDefaultGateway not supported on this platform") +} diff --git a/internal/netutil/netutil_bsd.go b/internal/netutil/netutil_bsd.go new file mode 100644 index 00000000..ab5330b6 --- /dev/null +++ b/internal/netutil/netutil_bsd.go @@ -0,0 +1,58 @@ +//go:build darwin || freebsd + +package netutil + +import ( + "fmt" + "net" + "os/exec" + "regexp" + "syscall" + + "golang.org/x/sys/unix" +) + +// bindToInterface sets the dialer's Control function to bind the socket +// to a specific network interface using IP_BOUND_IF on BSD systems. +func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) error { + if iface == nil { + return nil + } + + ifaceIndex := iface.Index + dialer.Control = func(network, address string, c syscall.RawConn) error { + var setsockoptErr error + err := c.Control(func(fd uintptr) { + setsockoptErr = unix.SetsockoptInt( + int(fd), + unix.IPPROTO_IP, + unix.IP_BOUND_IF, + ifaceIndex, + ) + }) + if err != nil { + return err + } + return setsockoptErr + } + return nil +} + +// getDefaultGateway parses the system route table to find the default gateway on macOS/BSD +func getDefaultGateway() (string, error) { + // Use route to get the default route on macOS + cmd := exec.Command("route", "-n", "get", "default") + out, err := cmd.Output() + if err != nil { + return "", err + } + + // Parse output to find gateway line + re := regexp.MustCompile(`gateway:\s+(\d+\.\d+\.\d+\.\d+)`) + matches := re.FindSubmatch(out) + if len(matches) < 2 { + return "", fmt.Errorf("could not parse gateway from route output") + } + + return string(matches[1]), nil +} diff --git a/internal/netutil/bind_linux.go b/internal/netutil/netutil_linux.go similarity index 64% rename from internal/netutil/bind_linux.go rename to internal/netutil/netutil_linux.go index 8467bd3e..1e370b34 100644 --- a/internal/netutil/bind_linux.go +++ b/internal/netutil/netutil_linux.go @@ -5,6 +5,8 @@ package netutil import ( "fmt" "net" + "os/exec" + "strings" ) // bindToInterface sets the dialer's LocalAddr to use the interface's IP as the source address. @@ -36,3 +38,23 @@ func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) return fmt.Errorf("no suitable IP address found on interface %s for target %s", iface.Name, targetIP) } + +// getDefaultGateway parses the system route table to find the default gateway on Linux +func getDefaultGateway() (string, error) { + // Use ip route to get the default route on Linux + cmd := exec.Command("ip", "route", "show", "default") + out, err := cmd.Output() + if err != nil { + return "", err + } + + // Parse output: "default via 192.168.0.1 dev enp12s0 ..." + fields := strings.Fields(string(out)) + for i, field := range fields { + if field == "via" && i+1 < len(fields) { + return fields[i+1], nil + } + } + + return "", fmt.Errorf("could not parse gateway from ip route output: %s", string(out)) +} From cbdb7a54b074273a3c172f6d1f31b2205cae3caf Mon Sep 17 00:00:00 2001 From: xvzc Date: Sat, 17 Jan 2026 23:34:23 +0900 Subject: [PATCH 11/39] fix(tun): use -ifscope for proper interface-bound routing on macOS - Replace 0.0.0.0/32 hack with proper -ifscope default route - SetGatewayRoute now uses 'route add -ifscope default ' - This creates a scoped default route for IP_BOUND_IF sockets - UnsetGatewayRoute updated to clean up ifscope routes properly - More logical and standards-compliant routing approach --- internal/server/tun/network_bsd.go | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/internal/server/tun/network_bsd.go b/internal/server/tun/network_bsd.go index 3fe52daa..c4bc2f3d 100644 --- a/internal/server/tun/network_bsd.go +++ b/internal/server/tun/network_bsd.go @@ -72,16 +72,17 @@ func SetInterfaceAddress(iface string, local string, remote string) error { return nil } -// UnsetGatewayRoute removes the gateway route +// UnsetGatewayRoute removes the scoped gateway route func UnsetGatewayRoute(gateway, iface string) error { - // Remove the direct route to the gateway - cmd := exec.Command("route", "-n", "delete", "-host", gateway, "-interface", iface) + // Remove the scoped default route for the physical interface + // This undoes the SetGatewayRoute ifscope route + cmd := exec.Command("route", "delete", "-ifscope", iface, "default") if out, err := cmd.CombinedOutput(); err != nil { _ = out } - // Also try to remove the 0.0.0.0/2 route if it exists (cleanup from previous versions) - cmd = exec.Command("route", "-n", "delete", "-net", "0.0.0.0/32", gateway) + // Remove the direct host route to the gateway + cmd = exec.Command("route", "-n", "delete", "-host", gateway, "-interface", iface) if out, err := cmd.CombinedOutput(); err != nil { _ = out } @@ -89,23 +90,24 @@ func UnsetGatewayRoute(gateway, iface string) error { return nil } -// SetGatewayRoute adds a host route to the gateway via the specified interface -// This ensures traffic destined for the gateway goes through the physical interface +// SetGatewayRoute adds a scoped default route via the physical interface +// This ensures IP_BOUND_IF sockets on the physical interface can reach the gateway func SetGatewayRoute(gateway, iface string) error { - // First, get the gateway's subnet to add a direct route - cmd := exec.Command("route", "-n", "add", "-host", gateway, "-interface", iface) + // Add a scoped default route for the physical interface + // When a socket is bound to this interface via IP_BOUND_IF, + // this scoped route will be used instead of the TUN routes (0/1, 128/1) + cmd := exec.Command("route", "add", "-ifscope", iface, "default", gateway) out, err := cmd.CombinedOutput() if err != nil { // Ignore "File exists" error - route already exists if !strings.Contains(string(out), "File exists") { - return fmt.Errorf("failed to add gateway route: %s: %w", string(out), err) + return fmt.Errorf("failed to add scoped default route: %s: %w", string(out), err) } } - // Also add a less specific route that uses the gateway for 0.0.0.0/2 - // This provides a path for IP_BOUND_IF sockets on en0 to reach external hosts - // The 0/2 route is less specific than 0/1 but will be used when bound to en0 - cmd = exec.Command("route", "-n", "add", "-net", "0.0.0.0/32", gateway) + // Also add a host route to the gateway via the physical interface + // This ensures packets to the gateway itself go through the right interface + cmd = exec.Command("route", "-n", "add", "-host", gateway, "-interface", iface) out, err = cmd.CombinedOutput() if err != nil { // Ignore "File exists" error From 2acfb16b9cb8daf5ebfb7a0a229e12d7dfa29475 Mon Sep 17 00:00:00 2001 From: xvzc Date: Sat, 17 Jan 2026 23:35:02 +0900 Subject: [PATCH 12/39] refactor(netutil): rename TimeoutConn to IdleTimeoutConn - Better reflects the semantics: deadline extends on activity - Updated comments to clarify idle-based timeout behavior - Connection stays alive as long as there are Read/Write operations --- internal/netutil/conn.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/netutil/conn.go b/internal/netutil/conn.go index 51861f71..1fd377c8 100644 --- a/internal/netutil/conn.go +++ b/internal/netutil/conn.go @@ -223,28 +223,28 @@ func (b *BufferedConn) Peek(n int) ([]byte, error) { return b.r.Peek(n) } -// TimeoutConn wraps a net.Conn to update the read deadline on every Read call. -// This is useful for UDP sessions which do not have a natural EOF. -type TimeoutConn struct { +// IdleTimeoutConn wraps a net.Conn to extend the deadline on every Read/Write call. +// This is useful for sessions which should stay alive as long as there is activity. +type IdleTimeoutConn struct { net.Conn Timeout time.Duration LastActive time.Time ExpiredAt time.Time // Calculated expiration time for cleanup } -func (c *TimeoutConn) Read(b []byte) (int, error) { +func (c *IdleTimeoutConn) Read(b []byte) (int, error) { c.ExtendDeadline() return c.Conn.Read(b) } -func (c *TimeoutConn) Write(b []byte) (int, error) { +func (c *IdleTimeoutConn) Write(b []byte) (int, error) { c.ExtendDeadline() return c.Conn.Write(b) } // ExtendDeadline attempts to extend the connection's deadline. // Returns false if the connection was already expired. -func (c *TimeoutConn) ExtendDeadline() bool { +func (c *IdleTimeoutConn) ExtendDeadline() bool { now := time.Now() // Check if already expired From eaae9728f73686f4bcd9d630da9cee75e02f4cf1 Mon Sep 17 00:00:00 2001 From: xvzc Date: Sat, 17 Jan 2026 23:50:51 +0900 Subject: [PATCH 13/39] refactor(config): rename UDPTimeout to UDPIdleTimeout - ConnOptions.UDPTimeout -> ConnOptions.UDPIdleTimeout - CLI flag: --udp-timeout -> --udp-idle-timeout - TOML key: udp-timeout -> udp-idle-timeout - Updated all usages across codebase and tests - Clearer naming to indicate idle-based timeout behavior --- cmd/spoofdpi/main.go | 6 +++--- cmd/spoofdpi/main_test.go | 6 +++--- internal/config/cli.go | 8 ++++---- internal/config/cli_test.go | 12 ++++++------ internal/config/config.go | 2 +- internal/config/toml_test.go | 4 ++-- internal/config/types.go | 10 +++++----- internal/config/types_test.go | 12 ++++++------ internal/server/tun/udp.go | 6 +++--- 9 files changed, 33 insertions(+), 33 deletions(-) diff --git a/cmd/spoofdpi/main.go b/cmd/spoofdpi/main.go index 5836009d..52b52566 100644 --- a/cmd/spoofdpi/main.go +++ b/cmd/spoofdpi/main.go @@ -118,10 +118,10 @@ func runApp(ctx context.Context, configDir string, cfg *config.Config) { Str("value", fmt.Sprintf("%dms", cfg.Conn.TCPTimeout.Milliseconds())). Msgf("tcp connection timeout") } - if *cfg.Conn.UDPTimeout > 0 { + if *cfg.Conn.UDPIdleTimeout > 0 { logger.Info(). - Str("value", fmt.Sprintf("%dms", cfg.Conn.UDPTimeout.Milliseconds())). - Msgf("udp connection timeout") + Str("value", fmt.Sprintf("%dms", cfg.Conn.UDPIdleTimeout.Milliseconds())). + Msgf("udp idle timeout") } logger.Info().Msgf("app-mode; %s", cfg.App.Mode.String()) diff --git a/cmd/spoofdpi/main_test.go b/cmd/spoofdpi/main_test.go index 9334512e..20fd04f2 100644 --- a/cmd/spoofdpi/main_test.go +++ b/cmd/spoofdpi/main_test.go @@ -25,7 +25,7 @@ func TestCreateResolver(t *testing.T) { cfg.Conn = &config.ConnOptions{ DNSTimeout: lo.ToPtr(time.Duration(0)), TCPTimeout: lo.ToPtr(time.Duration(0)), - UDPTimeout: lo.ToPtr(time.Duration(0)), + UDPIdleTimeout: lo.ToPtr(time.Duration(0)), } logger := zerolog.Nop() @@ -49,7 +49,7 @@ func TestCreateProxy_NoPcap(t *testing.T) { DefaultFakeTTL: lo.ToPtr(uint8(64)), DNSTimeout: lo.ToPtr(time.Duration(0)), TCPTimeout: lo.ToPtr(time.Duration(0)), - UDPTimeout: lo.ToPtr(time.Duration(0)), + UDPIdleTimeout: lo.ToPtr(time.Duration(0)), } // HTTPS Config (Ensure FakeCount is 0 to disable PCAP) @@ -98,7 +98,7 @@ func TestCreateProxy_WithPolicy(t *testing.T) { DefaultFakeTTL: lo.ToPtr(uint8(64)), DNSTimeout: lo.ToPtr(time.Duration(0)), TCPTimeout: lo.ToPtr(time.Duration(0)), - UDPTimeout: lo.ToPtr(time.Duration(0)), + UDPIdleTimeout: lo.ToPtr(time.Duration(0)), } // HTTPS Config diff --git a/internal/config/cli.go b/internal/config/cli.go index bf035404..d84b7fc4 100644 --- a/internal/config/cli.go +++ b/internal/config/cli.go @@ -297,18 +297,18 @@ func CreateCommand( }, &cli.Int64Flag{ - Name: "udp-timeout", + Name: "udp-idle-timeout", Usage: fmt.Sprintf(` - Timeout for udp connection in milliseconds. + Idle timeout for udp connection in milliseconds. No effect when the value is 0 (default: %v, max: %v)`, - defaultCfg.Conn.UDPTimeout.Milliseconds(), + defaultCfg.Conn.UDPIdleTimeout.Milliseconds(), math.MaxUint16, ), Value: 0, OnlyOnce: true, Validator: checkUint16, Action: func(ctx context.Context, cmd *cli.Command, v int64) error { - argsCfg.Conn.UDPTimeout = lo.ToPtr( + argsCfg.Conn.UDPIdleTimeout = lo.ToPtr( time.Duration(v * int64(time.Millisecond)), ) return nil diff --git a/internal/config/cli_test.go b/internal/config/cli_test.go index 3c0ee8e7..52a9cd16 100644 --- a/internal/config/cli_test.go +++ b/internal/config/cli_test.go @@ -31,7 +31,7 @@ func TestCreateCommand_Flags(t *testing.T) { assert.Equal(t, uint8(8), *cfg.Conn.DefaultFakeTTL) assert.Equal(t, int64(5000), cfg.Conn.DNSTimeout.Milliseconds()) assert.Equal(t, int64(10000), cfg.Conn.TCPTimeout.Milliseconds()) - assert.Equal(t, int64(25000), cfg.Conn.UDPTimeout.Milliseconds()) + assert.Equal(t, int64(25000), cfg.Conn.UDPIdleTimeout.Milliseconds()) assert.Equal(t, "8.8.8.8:53", cfg.DNS.Addr.String()) assert.Equal(t, DNSModeUDP, *cfg.DNS.Mode) assert.Equal(t, "https://dns.google/dns-query", *cfg.DNS.HTTPSURL) @@ -59,7 +59,7 @@ func TestCreateCommand_Flags(t *testing.T) { "--default-fake-ttl", "128", "--dns-timeout", "5000", "--tcp-timeout", "5000", - "--udp-timeout", "5000", + "--udp-idle-timeout", "5000", "--dns-addr", "1.1.1.1:53", "--dns-mode", "https", "--dns-https-url", "https://cloudflare-dns.com/dns-query", @@ -86,7 +86,7 @@ func TestCreateCommand_Flags(t *testing.T) { assert.Equal(t, uint8(128), *cfg.Conn.DefaultFakeTTL) assert.Equal(t, 5000*time.Millisecond, *cfg.Conn.DNSTimeout) assert.Equal(t, 5000*time.Millisecond, *cfg.Conn.TCPTimeout) - assert.Equal(t, 5000*time.Millisecond, *cfg.Conn.UDPTimeout) + assert.Equal(t, 5000*time.Millisecond, *cfg.Conn.UDPIdleTimeout) // DNS assert.Equal(t, "1.1.1.1:53", cfg.DNS.Addr.String()) @@ -189,7 +189,7 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { listen-addr = "127.0.0.1:8080" dns-timeout = 1000 tcp-timeout = 1000 - udp-timeout = 1000 + udp-idle-timeout = 1000 default-fake-ttl = 100 [dns] @@ -257,7 +257,7 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { "--listen-addr", "127.0.0.1:9090", "--dns-timeout", "2000", "--tcp-timeout", "2000", - "--udp-timeout", "2000", + "--udp-idle-timeout", "2000", "--default-fake-ttl", "200", "--dns-addr", "1.1.1.1:53", "--dns-cache=false", @@ -289,7 +289,7 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { assert.Equal(t, "127.0.0.1:9090", capturedCfg.App.ListenAddr.String()) assert.Equal(t, 2000*time.Millisecond, *capturedCfg.Conn.DNSTimeout) assert.Equal(t, 2000*time.Millisecond, *capturedCfg.Conn.TCPTimeout) - assert.Equal(t, 2000*time.Millisecond, *capturedCfg.Conn.UDPTimeout) + assert.Equal(t, 2000*time.Millisecond, *capturedCfg.Conn.UDPIdleTimeout) assert.Equal(t, uint8(200), *capturedCfg.Conn.DefaultFakeTTL) // DNS diff --git a/internal/config/config.go b/internal/config/config.go index 06fdc77c..b210f519 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -143,7 +143,7 @@ func getDefault() *Config { //exhaustruct:enforce DefaultFakeTTL: lo.ToPtr(uint8(8)), DNSTimeout: lo.ToPtr(time.Duration(5000) * time.Millisecond), TCPTimeout: lo.ToPtr(time.Duration(10000) * time.Millisecond), - UDPTimeout: lo.ToPtr(time.Duration(25000) * time.Millisecond), + UDPIdleTimeout: lo.ToPtr(time.Duration(25000) * time.Millisecond), }, DNS: &DNSOptions{ Mode: lo.ToPtr(DNSModeUDP), diff --git a/internal/config/toml_test.go b/internal/config/toml_test.go index 6757a2c0..a9eb565a 100644 --- a/internal/config/toml_test.go +++ b/internal/config/toml_test.go @@ -423,7 +423,7 @@ func TestFromTomlFile(t *testing.T) { [connection] dns-timeout = 1000 tcp-timeout = 1000 - udp-timeout = 1000 + udp-idle-timeout = 1000 default-fake-ttl = 100 [dns] @@ -487,7 +487,7 @@ func TestFromTomlFile(t *testing.T) { assert.Equal(t, "127.0.0.1:8080", cfg.App.ListenAddr.String()) assert.Equal(t, time.Duration(1000*time.Millisecond), *cfg.Conn.DNSTimeout) assert.Equal(t, time.Duration(1000*time.Millisecond), *cfg.Conn.TCPTimeout) - assert.Equal(t, time.Duration(1000*time.Millisecond), *cfg.Conn.UDPTimeout) + assert.Equal(t, time.Duration(1000*time.Millisecond), *cfg.Conn.UDPIdleTimeout) assert.Equal(t, zerolog.DebugLevel, *cfg.App.LogLevel) assert.True(t, *cfg.App.Silent) assert.True(t, *cfg.App.SetNetworkConfig) diff --git a/internal/config/types.go b/internal/config/types.go index 0f6f9c5a..8286c601 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -137,7 +137,7 @@ type ConnOptions struct { DefaultFakeTTL *uint8 `toml:"default-fake-ttl"` DNSTimeout *time.Duration `toml:"dns-timeout"` TCPTimeout *time.Duration `toml:"tcp-timeout"` - UDPTimeout *time.Duration `toml:"udp-timeout"` + UDPIdleTimeout *time.Duration `toml:"udp-idle-timeout"` } func (o *ConnOptions) UnmarshalTOML(data any) (err error) { @@ -177,14 +177,14 @@ func (o *ConnOptions) UnmarshalTOML(data any) (err error) { } if p := findFrom( v, - "udp-timeout", + "udp-idle-timeout", parseIntFn[uint16](checkUint16), &err, ); isOk( p, err, ) { - o.UDPTimeout = lo.ToPtr(time.Duration(*p) * time.Millisecond) + o.UDPIdleTimeout = lo.ToPtr(time.Duration(*p) * time.Millisecond) } return err @@ -199,7 +199,7 @@ func (o *ConnOptions) Clone() *ConnOptions { DefaultFakeTTL: clonePrimitive(o.DefaultFakeTTL), DNSTimeout: clonePrimitive(o.DNSTimeout), TCPTimeout: clonePrimitive(o.TCPTimeout), - UDPTimeout: clonePrimitive(o.UDPTimeout), + UDPIdleTimeout: clonePrimitive(o.UDPIdleTimeout), } } @@ -216,7 +216,7 @@ func (origin *ConnOptions) Merge(overrides *ConnOptions) *ConnOptions { DefaultFakeTTL: lo.CoalesceOrEmpty(overrides.DefaultFakeTTL, origin.DefaultFakeTTL), DNSTimeout: lo.CoalesceOrEmpty(overrides.DNSTimeout, origin.DNSTimeout), TCPTimeout: lo.CoalesceOrEmpty(overrides.TCPTimeout, origin.TCPTimeout), - UDPTimeout: lo.CoalesceOrEmpty(overrides.UDPTimeout, origin.UDPTimeout), + UDPIdleTimeout: lo.CoalesceOrEmpty(overrides.UDPIdleTimeout, origin.UDPIdleTimeout), } } diff --git a/internal/config/types_test.go b/internal/config/types_test.go index b806114e..8de5df49 100644 --- a/internal/config/types_test.go +++ b/internal/config/types_test.go @@ -160,14 +160,14 @@ func TestConnOptions_UnmarshalTOML(t *testing.T) { "default-fake-ttl": int64(64), "dns-timeout": int64(1000), "tcp-timeout": int64(1000), - "udp-timeout": int64(1000), + "udp-idle-timeout": int64(1000), }, wantErr: false, assert: func(t *testing.T, o ConnOptions) { assert.Equal(t, uint8(64), *o.DefaultFakeTTL) assert.Equal(t, 1000*time.Millisecond, *o.DNSTimeout) assert.Equal(t, 1000*time.Millisecond, *o.TCPTimeout) - assert.Equal(t, 1000*time.Millisecond, *o.UDPTimeout) + assert.Equal(t, 1000*time.Millisecond, *o.UDPIdleTimeout) }, }, { @@ -212,14 +212,14 @@ func TestConnOptions_Clone(t *testing.T) { DefaultFakeTTL: lo.ToPtr(uint8(64)), DNSTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), TCPTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), - UDPTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), + UDPIdleTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), }, assert: func(t *testing.T, input *ConnOptions, output *ConnOptions) { assert.NotNil(t, output) assert.Equal(t, uint8(64), *output.DefaultFakeTTL) assert.Equal(t, 1000*time.Millisecond, *output.DNSTimeout) assert.Equal(t, 1000*time.Millisecond, *output.TCPTimeout) - assert.Equal(t, 1000*time.Millisecond, *output.UDPTimeout) + assert.Equal(t, 1000*time.Millisecond, *output.UDPIdleTimeout) assert.NotSame(t, input, output) }, }, @@ -262,7 +262,7 @@ func TestConnOptions_Merge(t *testing.T) { DefaultFakeTTL: lo.ToPtr(uint8(64)), DNSTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), TCPTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), - UDPTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), + UDPIdleTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), }, override: &ConnOptions{ DefaultFakeTTL: lo.ToPtr(uint8(128)), @@ -271,7 +271,7 @@ func TestConnOptions_Merge(t *testing.T) { assert.Equal(t, uint8(128), *output.DefaultFakeTTL) assert.Equal(t, 1000*time.Millisecond, *output.DNSTimeout) assert.Equal(t, 1000*time.Millisecond, *output.TCPTimeout) - assert.Equal(t, 1000*time.Millisecond, *output.UDPTimeout) + assert.Equal(t, 1000*time.Millisecond, *output.UDPIdleTimeout) }, }, } diff --git a/internal/server/tun/udp.go b/internal/server/tun/udp.go index e417b197..69038dc9 100644 --- a/internal/server/tun/udp.go +++ b/internal/server/tun/udp.go @@ -91,9 +91,9 @@ func (h *UDPHandler) Handle(ctx context.Context, lConn net.Conn, rule *config.Ru // Add to pool (pool handles LRU eviction and deadline) rConn := h.pool.Add(key, rawConn) - // Wrap lConn with TimeoutConn as well - timeout := *connOpts.UDPTimeout - lConnWrapped := &netutil.TimeoutConn{Conn: lConn, Timeout: timeout} + // Wrap lConn with IdleTimeoutConn as well + timeout := *connOpts.UDPIdleTimeout + lConnWrapped := &netutil.IdleTimeoutConn{Conn: lConn, Timeout: timeout} // Desync _, _ = h.desyncer.Desync(ctx, lConnWrapped, rConn, udpOpts) From fae0abf163f789f15ad985fb018d63bfd97dd911 Mon Sep 17 00:00:00 2001 From: xvzc Date: Wed, 18 Mar 2026 11:58:13 +0900 Subject: [PATCH 14/39] feat: add support for socks5, tun modes --- cmd/spoofdpi/main.go | 11 ++- cmd/spoofdpi/main_test.go | 13 ++- docs/user-guide/{general.md => app.md} | 70 ++++++++++++++-- docs/user-guide/connection.md | 106 ++++++++++++++++++++++++ docs/user-guide/https.md | 49 ++++++++++- docs/user-guide/overview.md | 19 +++-- docs/user-guide/policy.md | 33 +------- docs/user-guide/server.md | 79 ------------------ docs/user-guide/udp.md | 58 +++++++++++++ internal/config/cli.go | 13 --- internal/config/cli_test.go | 15 +--- internal/config/config.go | 5 +- internal/config/config_test.go | 4 +- internal/config/toml_test.go | 4 +- internal/config/types.go | 13 +-- internal/config/types_test.go | 21 ++--- internal/config/validate_test.go | 28 +++++++ internal/server/http/https.go | 4 - internal/server/http/server.go | 23 ----- internal/server/socks5/server.go | 33 -------- internal/server/socks5/udp_associate.go | 26 +++++- mkdocs.yml | 5 +- 22 files changed, 374 insertions(+), 258 deletions(-) rename docs/user-guide/{general.md => app.md} (62%) create mode 100644 docs/user-guide/connection.md delete mode 100644 docs/user-guide/server.md create mode 100644 docs/user-guide/udp.md diff --git a/cmd/spoofdpi/main.go b/cmd/spoofdpi/main.go index 52b52566..9e5f07ff 100644 --- a/cmd/spoofdpi/main.go +++ b/cmd/spoofdpi/main.go @@ -104,10 +104,6 @@ func runApp(ctx context.Context, configDir string, cfg *config.Config) { Uint8("count", uint8(*cfg.HTTPS.FakeCount)). Msg(" fake") - logger.Info(). - Bool("auto", *cfg.Policy.Auto). - Msgf("policy") - if *cfg.Conn.DNSTimeout > 0 { logger.Info(). Str("value", fmt.Sprintf("%dms", cfg.Conn.DNSTimeout.Milliseconds())). @@ -354,9 +350,16 @@ func createServer( cfg.Conn.Clone(), cfg.HTTPS.Clone(), ) + udpDesyncer := desync.NewUDPDesyncer( + logging.WithScope(logger, "dsn"), + udpWriter, + udpSniffer, + ) udpAssociateHandler := socks5.NewUdpAssociateHandler( logging.WithScope(logger, "hnd"), netutil.NewConnPool(4096, 60*time.Second), + udpDesyncer, + cfg.UDP.Clone(), ) bindHandler := socks5.NewBindHandler(logging.WithScope(logger, "hnd")) diff --git a/cmd/spoofdpi/main_test.go b/cmd/spoofdpi/main_test.go index 20fd04f2..135a11b2 100644 --- a/cmd/spoofdpi/main_test.go +++ b/cmd/spoofdpi/main_test.go @@ -23,8 +23,8 @@ func TestCreateResolver(t *testing.T) { Cache: lo.ToPtr(true), } cfg.Conn = &config.ConnOptions{ - DNSTimeout: lo.ToPtr(time.Duration(0)), - TCPTimeout: lo.ToPtr(time.Duration(0)), + DNSTimeout: lo.ToPtr(time.Duration(0)), + TCPTimeout: lo.ToPtr(time.Duration(0)), UDPIdleTimeout: lo.ToPtr(time.Duration(0)), } @@ -49,7 +49,7 @@ func TestCreateProxy_NoPcap(t *testing.T) { DefaultFakeTTL: lo.ToPtr(uint8(64)), DNSTimeout: lo.ToPtr(time.Duration(0)), TCPTimeout: lo.ToPtr(time.Duration(0)), - UDPIdleTimeout: lo.ToPtr(time.Duration(0)), + UDPIdleTimeout: lo.ToPtr(time.Duration(0)), } // HTTPS Config (Ensure FakeCount is 0 to disable PCAP) @@ -63,9 +63,7 @@ func TestCreateProxy_NoPcap(t *testing.T) { } // Policy Config - cfg.Policy = &config.PolicyOptions{ - Auto: lo.ToPtr(false), - } + cfg.Policy = &config.PolicyOptions{} // DNS Config cfg.DNS = &config.DNSOptions{ @@ -98,7 +96,7 @@ func TestCreateProxy_WithPolicy(t *testing.T) { DefaultFakeTTL: lo.ToPtr(uint8(64)), DNSTimeout: lo.ToPtr(time.Duration(0)), TCPTimeout: lo.ToPtr(time.Duration(0)), - UDPIdleTimeout: lo.ToPtr(time.Duration(0)), + UDPIdleTimeout: lo.ToPtr(time.Duration(0)), } // HTTPS Config @@ -108,7 +106,6 @@ func TestCreateProxy_WithPolicy(t *testing.T) { // Policy Config with one override cfg.Policy = &config.PolicyOptions{ - Auto: lo.ToPtr(false), Overrides: []config.Rule{ { Name: lo.ToPtr("test-rule"), diff --git a/docs/user-guide/general.md b/docs/user-guide/app.md similarity index 62% rename from docs/user-guide/general.md rename to docs/user-guide/app.md index bb2bbf74..835ff353 100644 --- a/docs/user-guide/general.md +++ b/docs/user-guide/app.md @@ -1,6 +1,60 @@ -# General Configuration +# App Configuration -General settings for the application, including logging and system integration. +Application-level settings including mode, logging, and system integration. + +## `app-mode` + +`type: string` + +### Description + +Specifies the proxy mode. `(default: "http")` + +### Allowed Values + +- `http`: HTTP proxy mode +- `socks5`: SOCKS5 proxy mode +- `tun`: TUN interface mode (transparent proxy) + +### Usage + +**Command-Line Flag** +```console +$ spoofdpi --app-mode socks5 +``` + +**TOML Config** +```toml +[app] +mode = "socks5" +``` + +--- + +## `listen-addr` + +`type: ` + +### Description + +Specifies the IP address and port to listen on. `(default: 127.0.0.1:8080 for http, 127.0.0.1:1080 for socks5)` + +If you want to run SpoofDPI remotely (e.g., on a physically separated machine), set the IP part to `0.0.0.0`. Otherwise, it is recommended to leave this option as default for security. + +### Usage + +**Command-Line Flag** +```console +$ spoofdpi --listen-addr "0.0.0.0:8080" +``` + +**TOML Config** +```toml +[app] +listen-addr = "0.0.0.0:8080" +``` + +--- ## `log-level` @@ -21,7 +75,7 @@ $ spoofdpi --log-level trace **TOML Config** ```toml -[general] +[app] log-level = "trace" ``` @@ -44,13 +98,13 @@ $ spoofdpi --silent **TOML Config** ```toml -[general] +[app] silent = true ``` --- -## `system-proxy` +## `network-config` `type: boolean` @@ -65,13 +119,13 @@ Specifies whether to automatically set up the system-wide proxy configuration. ` **Command-Line Flag** ```console -$ spoofdpi --system-proxy +$ spoofdpi --network-config ``` **TOML Config** ```toml -[general] -system-proxy = true +[app] +network-config = true ``` --- diff --git a/docs/user-guide/connection.md b/docs/user-guide/connection.md new file mode 100644 index 00000000..71f7db06 --- /dev/null +++ b/docs/user-guide/connection.md @@ -0,0 +1,106 @@ +# Connection Configuration + +Settings for network connection timeouts and packet configuration. + +## `dns-timeout` + +`type: uint16` + +### Description + +Specifies the timeout (in milliseconds) for DNS connections. `(default: 5000, max: 65535)` + +A value of `0` means no timeout. + +### Usage + +**Command-Line Flag** +```console +$ spoofdpi --dns-timeout 3000 +``` + +**TOML Config** +```toml +[connection] +dns-timeout = 3000 +``` + +--- + +## `tcp-timeout` + +`type: uint16` + +### Description + +Specifies the timeout (in milliseconds) for TCP connections. `(default: 10000, max: 65535)` + +A value of `0` means no timeout. + +### Usage + +**Command-Line Flag** +```console +$ spoofdpi --tcp-timeout 5000 +``` + +**TOML Config** +```toml +[connection] +tcp-timeout = 5000 +``` + +--- + +## `udp-idle-timeout` + +`type: uint16` + +### Description + +Specifies the idle timeout (in milliseconds) for UDP connections. `(default: 25000, max: 65535)` + +The connection will be closed if there is no read/write activity for this duration. Each read or write operation resets the timeout. + +A value of `0` means no timeout. + +### Usage + +**Command-Line Flag** +```console +$ spoofdpi --udp-idle-timeout 30000 +``` + +**TOML Config** +```toml +[connection] +udp-idle-timeout = 30000 +``` + +--- + +## `default-fake-ttl` + +`type: uint8` + +### Description + +Specifies the default [Time To Live (TTL)](https://en.wikipedia.org/wiki/Time_to_live) value for fake packets. `(default: 8)` + +This value is used for fake packets sent during disorder strategies. A lower value ensures fake packets expire before reaching the destination, while the real packets arrive successfully. + +!!! note + The fake TTL should be less than the number of hops to the destination. + +### Usage + +**Command-Line Flag** +```console +$ spoofdpi --default-fake-ttl 10 +``` + +**TOML Config** +```toml +[connection] +default-fake-ttl = 10 +``` diff --git a/docs/user-guide/https.md b/docs/user-guide/https.md index c56386fb..4bcf51c9 100644 --- a/docs/user-guide/https.md +++ b/docs/user-guide/https.md @@ -16,6 +16,7 @@ Specifies the default packet fragmentation strategy to use for the Client Hello - `random`: Splits the packet at a random position. - `chunk`: Splits the packet into fixed-size chunks (controlled by `https-chunk-size`). - `first-byte`: Splits only the first byte of the packet. +- `custom`: Uses custom segment plans (TOML only). - `none`: Disables fragmentation. ### Usage @@ -39,7 +40,7 @@ split-mode = "sni" ### Description -Specifies the chunk size in bytes for packet fragmentation. `(default: 0, max: 255)` +Specifies the chunk size in bytes for packet fragmentation. `(default: 35, max: 255)` This value is only applied when `https-split-mode` is set to `chunk`. Try lower values if the default fails to bypass the DPI. @@ -129,10 +130,10 @@ The value should be a sequence of bytes representing a valid (or semi-valid) TLS ### Usage **Command-Line Flag** -Provide a comma-separated string of hexadecimal bytes (e.g., `16,03,01,...`). +Provide a comma-separated string of hexadecimal bytes (e.g., `0x16, 0x03, 0x01, ...`). ```console -$ spoofdpi --https-fake-packet "16,03,01,00,a1,..." +$ spoofdpi --https-fake-packet "0x16, 0x03, 0x01, 0x00, 0xa1, ..." ``` **TOML Config** @@ -165,3 +166,45 @@ $ spoofdpi --https-skip [https] skip = true ``` + +--- + +## `custom-segments` + +`type: array of segment plans` + +### Description + +Defines custom segmentation plans for fine-grained control over how the Client Hello packet is split. This option is only used when `split-mode` is set to `"custom"`. + +Each segment plan specifies where to split the packet relative to a reference point (`from`) with an offset (`at`). + +!!! note + This option is only available via the TOML config file. + +!!! important + When using `custom` split-mode, the global `disorder` option is **ignored**. Use the `lazy` field in each segment plan to control packet ordering instead. + +### Segment Plan Fields + +| Field | Type | Required | Description | +| :------ | :------ | :------- | :---------- | +| `from` | String | Yes | Reference point: `"head"` (start of packet) or `"sni"` (start of SNI extension). | +| `at` | Int | Yes | Byte offset from the reference point. Negative values count backwards. | +| `lazy` | Boolean | No | If `true`, delays sending this segment (equivalent to disorder). `(default: false)` | +| `noise` | Int | No | Adds random noise (in bytes) to the split position. `(default: 0)` | + +### Usage + +**TOML Config** + +```toml +[https] +split-mode = "custom" +custom-segments = [ + { from = "head", at = 2 }, # Split 2 bytes from start + { from = "sni", at = 0 }, # Split at SNI start + { from = "sni", at = -5, lazy = true }, # Split 5 bytes before SNI, delayed + { from = "head", at = 100, noise = 10 }, # Split at ~100 bytes with ±10 noise +] +``` diff --git a/docs/user-guide/overview.md b/docs/user-guide/overview.md index b6cac093..2122d9db 100644 --- a/docs/user-guide/overview.md +++ b/docs/user-guide/overview.md @@ -25,14 +25,15 @@ If a specific path is not provided via a `--config` flag, SpoofDPI will search f ## Options -The configuration is organized into five main categories. Click on each category to view detailed options. +The configuration is organized into six main categories. Click on each category to view detailed options. | Category | Description | | :--- | :--- | -| **[General](general.md)** | General application options (logging, system proxy, etc.). | -| **[Server](server.md)** | Server connection options (address, timeout). | +| **[App](app.md)** | Application-level options (mode, address, logging, etc.). | +| **[Connection](connection.md)** | Connection timeout and packet TTL settings. | | **[DNS](dns.md)** | DNS resolution options. | | **[HTTPS](https.md)** | HTTPS/TLS packet manipulation options. | +| **[UDP](udp.md)** | UDP packet manipulation options. | | **[Policy](policy.md)** | Rule-based routing and automatic bypass policies. | ## Example @@ -44,7 +45,7 @@ The following two methods will achieve the exact same configuration. All settings are passed directly via flags. ```console -$ spoofdpi --dns-addr "1.1.1.1:53" --dns-https-url "https://dns.google/dns-query" --dns-mode "https" +$ spoofdpi --app-mode socks5 --dns-mode https --https-disorder ``` ### Method 2: Using a TOML Config File @@ -52,10 +53,14 @@ $ spoofdpi --dns-addr "1.1.1.1:53" --dns-https-url "https://dns.google/dns-query Place the settings in your `spoofdpi.toml` file: ```toml +[app] +mode = "socks5" + [dns] - addr = "1.1.1.1:53" - https-url = "https://dns.google/dns-query" - mode = "https" +mode = "https" + +[https] +disorder = true ``` Then, run spoofdpi without those flags (it will automatically load the file if placed in a standard path): diff --git a/docs/user-guide/policy.md b/docs/user-guide/policy.md index 09111529..1b6a52cf 100644 --- a/docs/user-guide/policy.md +++ b/docs/user-guide/policy.md @@ -2,51 +2,24 @@ By defining rules within the Policy section, you can granularly control how SpoofDPI handles connections to specific domains or IP addresses. You can define per-domain bypass strategies, DNS settings, or simply block connections. -## `auto` - -`type: boolean` - -### Description - -Automatically detect blocked sites and add them to the bypass list. `(default: false)` - -When enabled, SpoofDPI attempts to detect if a connection is being blocked and temporarily applies bypass rules for that destination. These generated rules utilize the configuration defined in `[policy.template]`. - -### Usage - -**Command-Line Flag** -```console -$ spoofdpi --policy-auto -``` - -**TOML Config** -```toml -[policy] -auto = true -``` - ---- - ## `template` -The `[policy.template]` section defines the default behavior for rules automatically generated when `auto = true`. If you enable automatic detection, you should configure this template to ensure the generated rules effectively bypass the DPI. +The `[policy.template]` section defines a base rule configuration. This template can be cloned and customized when programmatically adding rules. !!! note The template configuration is only available via the TOML config file. ### Structure -The template uses the same `Rule` structure as overrides, but typically only the `https` and `dns` sections are relevant, as the `match` criteria are determined dynamically. +The template uses the same `Rule` structure as overrides, but typically only the `https` and `dns` sections are relevant. ### Example ```toml [policy] - auto = true - - # This configuration is applied to automatically detected blocked sites [policy.template] https = { fake-count = 7, disorder = true } + dns = { mode = "https" } ``` --- diff --git a/docs/user-guide/server.md b/docs/user-guide/server.md deleted file mode 100644 index 8bdb6590..00000000 --- a/docs/user-guide/server.md +++ /dev/null @@ -1,79 +0,0 @@ -# Server Configuration - -Settings related to the proxy server connection and listener. - -## `listen-addr` - -`type: ` - -### Description - -Specifies the IP address and port to listen on. `(default: 127.0.0.1:8080)` - -If you want to run SpoofDPI remotely (e.g., on a physically separated machine), set the IP part to `0.0.0.0`. Otherwise, it is recommended to leave this option as default for security. - -### Usage - -**Command-Line Flag** -```console -$ spoofdpi --listen-addr "0.0.0.0:8080" -``` - -**TOML Config** -```toml -[server] -listen-addr = "0.0.0.0:8080" -``` - ---- - -## `timeout` - -`type: uint16` - -### Description - -Specifies the timeout (in milliseconds) for every TCP connection. `(default: 0, max: 65535)` - -A value of `0` means no timeout. You can set this option if you know what you are doing, but in most cases, leaving this option unset is recommended. - -### Usage - -**Command-Line Flag** -```console -$ spoofdpi --timeout 5000 -``` - -**TOML Config** -```toml -[server] -timeout = 5000 -``` - ---- - -## `default-ttl` - -`type: uint8` - -### Description - -Specifies the default [Time To Live (TTL)](https://en.wikipedia.org/wiki/Time_to_live) value for outgoing packets. `(default: 64)` - -This value is used to restore the TTL to its default state after applying disorder strategies. Changing this option is generally not required. - -!!! note - The default TTL value for macOS and Linux is usually `64`. - -### Usage - -**Command-Line Flag** -```console -$ spoofdpi --default-ttl 128 -``` - -**TOML Config** -```toml -[server] -default-ttl = 128 -``` diff --git a/docs/user-guide/udp.md b/docs/user-guide/udp.md new file mode 100644 index 00000000..c5f776c9 --- /dev/null +++ b/docs/user-guide/udp.md @@ -0,0 +1,58 @@ +# UDP Configuration + +Settings for UDP packet manipulation and bypass techniques. + +## `fake-count` + +`type: int` + +### Description + +Specifies the number of fake packets to be sent before actual UDP packets. `(default: 0)` + +Sending fake packets can trick DPI systems into inspecting invalid traffic, allowing real packets to pass through. + +!!! note + This feature requires root privileges and packet capture capabilities. + +### Usage + +**Command-Line Flag** +```console +$ spoofdpi --udp-fake-count 5 +``` + +**TOML Config** +```toml +[udp] +fake-count = 5 +``` + +--- + +## `fake-packet` + +`type: byte array` + +### Description + +Customizes the content of the fake packets used by `udp-fake-count`. `(default: 64 bytes of zeros)` + +The value should be a sequence of bytes representing the fake packet data. + +### Usage + +**Command-Line Flag** +Provide a comma-separated string of hexadecimal bytes (e.g., `0x00, 0x01, 0x02, ...`). + +```console +$ spoofdpi --udp-fake-packet "0x00, 0x01, 0x02, 0x03, 0x04" +``` + +**TOML Config** +Provide an array of integers (bytes). + +```toml +[udp] +fake-packet = [0x00, 0x01, 0x02, 0x03, 0x04] +``` diff --git a/internal/config/cli.go b/internal/config/cli.go index d84b7fc4..504bfc4a 100644 --- a/internal/config/cli.go +++ b/internal/config/cli.go @@ -342,19 +342,6 @@ func CreateCommand( }, }, - &cli.BoolFlag{ - Name: "policy-auto", - Usage: fmt.Sprintf(` - Automatically detect the blocked sites and add policies (default: %v)`, - *defaultCfg.Policy.Auto, - ), - OnlyOnce: true, - Action: func(ctx context.Context, cmd *cli.Command, v bool) error { - argsCfg.Policy.Auto = lo.ToPtr(v) - return nil - }, - }, - &cli.BoolFlag{ Name: "silent", Usage: fmt.Sprintf(` diff --git a/internal/config/cli_test.go b/internal/config/cli_test.go index 52a9cd16..88aeae6a 100644 --- a/internal/config/cli_test.go +++ b/internal/config/cli_test.go @@ -44,7 +44,6 @@ func TestCreateCommand_Flags(t *testing.T) { assert.False(t, *cfg.HTTPS.Skip) assert.Equal(t, 0, *cfg.UDP.FakeCount) assert.Equal(t, 64, len(cfg.UDP.FakePacket)) - assert.False(t, *cfg.Policy.Auto) }, }, { @@ -73,7 +72,6 @@ func TestCreateCommand_Flags(t *testing.T) { "--https-skip", "--udp-fake-count", "5", "--udp-fake-packet", "0x01, 0x02", - "--policy-auto", }, assert: func(t *testing.T, cfg *Config) { // General @@ -106,10 +104,6 @@ func TestCreateCommand_Flags(t *testing.T) { // UDP assert.Equal(t, 5, *cfg.UDP.FakeCount) assert.Equal(t, []byte{0x01, 0x02}, cfg.UDP.FakePacket) - assert.Equal(t, []byte{0x01, 0x02}, cfg.UDP.FakePacket) - - // Policy - assert.True(t, *cfg.Policy.Auto) }, }, { @@ -180,12 +174,12 @@ func TestCreateCommand_Flags(t *testing.T) { func TestCreateCommand_OverrideTOML(t *testing.T) { tomlContent := ` -[general] +[app] log-level = "debug" silent = true system-proxy = true -[server] +[connection] listen-addr = "127.0.0.1:8080" dns-timeout = 1000 tcp-timeout = 1000 @@ -208,7 +202,6 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { skip = true [policy] - auto = true [[policy.overrides]] name = "test-rule" priority = 100 @@ -272,7 +265,6 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { "--https-skip=false", "--udp-fake-count", "20", "--udp-fake-packet", "0xcc,0xdd", - "--policy-auto=false", } err = cmd.Run(context.Background(), args) @@ -312,9 +304,6 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { assert.Equal(t, []byte{0xcc, 0xdd}, capturedCfg.UDP.FakePacket) assert.Equal(t, []byte{0xcc, 0xdd}, capturedCfg.UDP.FakePacket) - // Policy - assert.False(t, *capturedCfg.Policy.Auto) - // Verify TOML-only fields are preserved require.Len(t, capturedCfg.Policy.Overrides, 1) override := capturedCfg.Policy.Overrides[0] diff --git a/internal/config/config.go b/internal/config/config.go index b210f519..1c70d267 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -22,7 +22,7 @@ type cloner[T any] interface { var _ merger[*Config] = (*Config)(nil) type Config struct { - App *AppOptions `toml:"general"` + App *AppOptions `toml:"app"` Conn *ConnOptions `toml:"connection"` DNS *DNSOptions `toml:"dns"` HTTPS *HTTPSOptions `toml:"https"` @@ -36,7 +36,7 @@ func (c *Config) UnmarshalTOML(data any) (err error) { return fmt.Errorf("non-table type config file") } - c.App = findStructFrom[AppOptions](m, "general", &err) + c.App = findStructFrom[AppOptions](m, "app", &err) c.Conn = findStructFrom[ConnOptions](m, "connection", &err) c.DNS = findStructFrom[DNSOptions](m, "dns", &err) c.HTTPS = findStructFrom[HTTPSOptions](m, "https", &err) @@ -166,7 +166,6 @@ func getDefault() *Config { //exhaustruct:enforce FakePacket: make([]byte, 64), }, Policy: &PolicyOptions{ - Auto: lo.ToPtr(false), Template: &Rule{}, Overrides: []Rule{}, }, diff --git a/internal/config/config_test.go b/internal/config/config_test.go index f42c2720..10b55b35 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -18,7 +18,7 @@ func TestConfig_UnmarshalTOML(t *testing.T) { { name: "valid config", input: map[string]any{ - "general": map[string]any{ + "app": map[string]any{ "listen-addr": "127.0.0.1:9090", }, "dns": map[string]any{ @@ -55,7 +55,7 @@ func TestConfig_UnmarshalTOML(t *testing.T) { { name: "validation error", input: map[string]any{ - "general": map[string]any{ + "app": map[string]any{ "listen-addr": "invalid-addr", }, }, diff --git a/internal/config/toml_test.go b/internal/config/toml_test.go index a9eb565a..de3adf4b 100644 --- a/internal/config/toml_test.go +++ b/internal/config/toml_test.go @@ -414,7 +414,7 @@ func TestFindSliceFrom(t *testing.T) { func TestFromTomlFile(t *testing.T) { t.Run("full valid config", func(t *testing.T) { tomlContent := ` - [general] + [app] log-level = "debug" silent = true network-config = true @@ -442,7 +442,6 @@ func TestFromTomlFile(t *testing.T) { skip = true [policy] - auto = true [[policy.overrides]] name = "test-rule" priority = 100 @@ -498,7 +497,6 @@ func TestFromTomlFile(t *testing.T) { assert.Equal(t, "https://1.1.1.1/dns-query", *cfg.DNS.HTTPSURL) assert.Equal(t, DNSQueryIPv4, *cfg.DNS.QType) assert.Equal(t, uint8(100), *cfg.Conn.DefaultFakeTTL) - assert.True(t, *cfg.Policy.Auto) assert.True(t, *cfg.HTTPS.Disorder) assert.Equal(t, uint8(5), *cfg.HTTPS.FakeCount) assert.Equal(t, []byte{0x01, 0x02, 0x03}, cfg.HTTPS.FakePacket.Raw()) diff --git a/internal/config/types.go b/internal/config/types.go index 8286c601..e9b6214c 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -34,7 +34,14 @@ func clonePrimitive[T primitive](x *T) *T { // └─────────────────┘ var _ merger[*AppOptions] = (*AppOptions)(nil) -var availableLogLevelValues = []string{"info", "warn", "trace", "error", "debug"} +var availableLogLevelValues = []string{ + "info", + "warn", + "trace", + "error", + "debug", + "disabled", +} type AppOptions struct { LogLevel *zerolog.Level `toml:"log-level"` @@ -622,7 +629,6 @@ var ( ) type PolicyOptions struct { - Auto *bool `toml:"auto"` Template *Rule `toml:"template"` Overrides []Rule `toml:"overries"` } @@ -633,7 +639,6 @@ func (o *PolicyOptions) UnmarshalTOML(data any) (err error) { return fmt.Errorf("non-table type policy config") } - o.Auto = findFrom(m, "auto", parseBoolFn(), &err) o.Template = findStructFrom[Rule](m, "template", &err) o.Overrides = findStructSliceFrom[Rule](m, "overrides", &err) @@ -651,7 +656,6 @@ func (o *PolicyOptions) Clone() *PolicyOptions { } return &PolicyOptions{ - Auto: clonePrimitive(o.Auto), Template: o.Template.Clone(), Overrides: overrides, } @@ -667,7 +671,6 @@ func (origin *PolicyOptions) Merge(overrides *PolicyOptions) *PolicyOptions { } return &PolicyOptions{ - Auto: lo.CoalesceOrEmpty(overrides.Auto, origin.Auto), Template: lo.CoalesceOrEmpty(overrides.Template.Clone(), origin.Template.Clone()), Overrides: lo.CoalesceSliceOrEmpty(overrides.Overrides, origin.Overrides), } diff --git a/internal/config/types_test.go b/internal/config/types_test.go index 8de5df49..3c1f818f 100644 --- a/internal/config/types_test.go +++ b/internal/config/types_test.go @@ -160,7 +160,7 @@ func TestConnOptions_UnmarshalTOML(t *testing.T) { "default-fake-ttl": int64(64), "dns-timeout": int64(1000), "tcp-timeout": int64(1000), - "udp-idle-timeout": int64(1000), + "udp-idle-timeout": int64(1000), }, wantErr: false, assert: func(t *testing.T, o ConnOptions) { @@ -212,7 +212,7 @@ func TestConnOptions_Clone(t *testing.T) { DefaultFakeTTL: lo.ToPtr(uint8(64)), DNSTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), TCPTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), - UDPIdleTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), + UDPIdleTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), }, assert: func(t *testing.T, input *ConnOptions, output *ConnOptions) { assert.NotNil(t, output) @@ -262,7 +262,7 @@ func TestConnOptions_Merge(t *testing.T) { DefaultFakeTTL: lo.ToPtr(uint8(64)), DNSTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), TCPTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), - UDPIdleTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), + UDPIdleTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), }, override: &ConnOptions{ DefaultFakeTTL: lo.ToPtr(uint8(128)), @@ -577,7 +577,6 @@ func TestPolicyOptions_UnmarshalTOML(t *testing.T) { { name: "valid policy options", input: map[string]any{ - "auto": true, "overrides": []map[string]any{ { "name": "rule1", @@ -589,7 +588,6 @@ func TestPolicyOptions_UnmarshalTOML(t *testing.T) { }, wantErr: false, assert: func(t *testing.T, o PolicyOptions) { - assert.True(t, *o.Auto) assert.Len(t, o.Overrides, 1) assert.Equal(t, "rule1", *o.Overrides[0].Name) }, @@ -633,7 +631,6 @@ func TestPolicyOptions_Clone(t *testing.T) { { name: "non-nil receiver", input: &PolicyOptions{ - Auto: lo.ToPtr(true), Overrides: []Rule{ { Name: lo.ToPtr("rule1"), @@ -643,7 +640,6 @@ func TestPolicyOptions_Clone(t *testing.T) { }, assert: func(t *testing.T, input *PolicyOptions, output *PolicyOptions) { assert.NotNil(t, output) - assert.True(t, *output.Auto) assert.Len(t, output.Overrides, 1) assert.NotSame(t, input, output) // Deep copy check for slice @@ -670,31 +666,28 @@ func TestPolicyOptions_Merge(t *testing.T) { { name: "nil receiver", base: nil, - override: &PolicyOptions{Auto: lo.ToPtr(true)}, + override: &PolicyOptions{Overrides: []Rule{{Name: lo.ToPtr("rule1")}}}, assert: func(t *testing.T, output *PolicyOptions) { - assert.True(t, *output.Auto) + assert.Len(t, output.Overrides, 1) }, }, { name: "nil override", - base: &PolicyOptions{Auto: lo.ToPtr(false)}, + base: &PolicyOptions{Overrides: []Rule{{Name: lo.ToPtr("rule1")}}}, override: nil, assert: func(t *testing.T, output *PolicyOptions) { - assert.False(t, *output.Auto) + assert.Len(t, output.Overrides, 1) }, }, { name: "merge values", base: &PolicyOptions{ - Auto: lo.ToPtr(false), Overrides: []Rule{{Name: lo.ToPtr("rule1")}}, }, override: &PolicyOptions{ - Auto: lo.ToPtr(true), Overrides: []Rule{{Name: lo.ToPtr("rule2")}}, }, assert: func(t *testing.T, output *PolicyOptions) { - assert.True(t, *output.Auto) assert.Len(t, output.Overrides, 1) assert.Equal(t, "rule2", *output.Overrides[0].Name) }, diff --git a/internal/config/validate_test.go b/internal/config/validate_test.go index 302c82d9..fa9aa89d 100644 --- a/internal/config/validate_test.go +++ b/internal/config/validate_test.go @@ -305,3 +305,31 @@ func TestCheckRule(t *testing.T) { }) } } + +func TestCheckLogLevel(t *testing.T) { + tcs := []struct { + name string + input string + wantErr bool + }{ + {"valid info", "info", false}, + {"valid debug", "debug", false}, + {"valid warn", "warn", false}, + {"valid error", "error", false}, + {"valid trace", "trace", false}, + {"valid disabled", "disabled", false}, + {"invalid unknown", "unknown", true}, + {"invalid empty", "", true}, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + err := checkLogLevel(tc.input) + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/server/http/https.go b/internal/server/http/https.go index 7bc35a6f..502307be 100644 --- a/internal/server/http/https.go +++ b/internal/server/http/https.go @@ -161,9 +161,5 @@ func (h *HTTPSHandler) sendClientHello( msg *proto.TLSMessage, opts *config.HTTPSOptions, ) (int, error) { - if lo.FromPtr(opts.Skip) { - return rConn.Write(msg.Raw()) - } - return h.desyncer.Desync(ctx, h.logger, rConn, msg, opts) } diff --git a/internal/server/http/server.go b/internal/server/http/server.go index 12f8837e..c4f351cb 100644 --- a/internal/server/http/server.go +++ b/internal/server/http/server.go @@ -217,27 +217,4 @@ func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { } logger.Warn().Err(handleErr).Msg("error handling request") - if !errors.Is(handleErr, netutil.ErrBlocked) { // Early exit if not blocked - return - } - - // ┌─────────────┐ - // │ AUTO config │ - // └─────────────┘ - if bestMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - logger.Info().Msg("skipping auto-config (duplicate policy)") - return - } - - // Perform auto config if enabled and RuleTemplate is not nil - if *p.policyOpts.Auto && p.policyOpts.Template != nil { - newRule := p.policyOpts.Template.Clone() - newRule.Match = &config.MatchAttrs{Domains: []string{host}} - - if err := p.ruleMatcher.Add(newRule); err != nil { - logger.Info().Err(err).Msg("failed to add config automatically") - } else { - logger.Info().Msg("automatically added to config") - } - } } diff --git a/internal/server/socks5/server.go b/internal/server/socks5/server.go index 7bc1748f..a8c9e35a 100644 --- a/internal/server/socks5/server.go +++ b/internal/server/socks5/server.go @@ -223,39 +223,6 @@ func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { } logger.Error().Err(err).Msg("failed to handle") - if errors.Is(err, netutil.ErrBlocked) { - p.handleAutoConfig(ctx, req, addrs, bestMatch) - } -} - -func (p *SOCKS5Proxy) handleAutoConfig( - ctx context.Context, - req *proto.SOCKS5Request, - addrs []net.IP, - matchedRule *config.Rule, -) { - logger := zerolog.Ctx(ctx) - - if matchedRule != nil { - logger.Trace().Msg("skipping auto-policy for this request (duplicate policy)") - return - } - - if *p.policyOpts.Auto && p.policyOpts.Template != nil { - newRule := p.policyOpts.Template.Clone() - targetDomain := req.FQDN // req.Domain -> req.FQDN - if targetDomain == "" && len(addrs) > 0 { - targetDomain = addrs[0].String() - } - - newRule.Match = &config.MatchAttrs{Domains: []string{targetDomain}} - - if err := p.ruleMatcher.Add(newRule); err != nil { - logger.Info().Err(err).Msg("failed to add config automatically") - } else { - logger.Info().Msg("automatically added to config") - } - } } func (p *SOCKS5Proxy) negotiate(logger zerolog.Logger, conn net.Conn) error { diff --git a/internal/server/socks5/udp_associate.go b/internal/server/socks5/udp_associate.go index c16d23a7..71d36983 100644 --- a/internal/server/socks5/udp_associate.go +++ b/internal/server/socks5/udp_associate.go @@ -9,23 +9,30 @@ import ( "github.com/rs/zerolog" "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/desync" "github.com/xvzc/SpoofDPI/internal/logging" "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/proto" ) type UdpAssociateHandler struct { - logger zerolog.Logger - pool *netutil.ConnPool + logger zerolog.Logger + pool *netutil.ConnPool + desyncer *desync.UDPDesyncer + defaultUDPOpts *config.UDPOptions } func NewUdpAssociateHandler( logger zerolog.Logger, pool *netutil.ConnPool, + desyncer *desync.UDPDesyncer, + defaultUDPOpts *config.UDPOptions, ) *UdpAssociateHandler { return &UdpAssociateHandler{ - logger: logger, - pool: pool, + logger: logger, + pool: pool, + desyncer: desyncer, + defaultUDPOpts: defaultUDPOpts, } } @@ -139,6 +146,17 @@ func (h *UdpAssociateHandler) Handle( // Add to pool (pool handles LRU eviction and deadline) conn := h.pool.Add(key, rawConn) + // Apply UDP options from rule if matched + udpOpts := h.defaultUDPOpts.Clone() + if rule != nil && rule.UDP != nil { + udpOpts = udpOpts.Merge(rule.UDP) + } + + // Send fake packets before real payload (UDP desync) + if h.desyncer != nil { + _, _ = h.desyncer.Desync(ctx, lNewConn, conn.Conn, udpOpts) + } + // Start a goroutine to read from the target and forward to the client go func(targetConn *netutil.PooledConn, clientAddr *net.UDPAddr) { respBuf := make([]byte, 65535) diff --git a/mkdocs.yml b/mkdocs.yml index f0c5002c..60d19877 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -56,10 +56,11 @@ nav: - 'User Guide': - 'Overview': user-guide/overview.md - 'Options': - - 'General': user-guide/general.md - - 'Server': user-guide/server.md + - 'App': user-guide/app.md + - 'Connection': user-guide/connection.md - 'DNS': user-guide/dns.md - 'HTTPS': user-guide/https.md + - 'UDP': user-guide/udp.md - 'Policy': user-guide/policy.md - 'Configuration Examples': user-guide/examples.md # - 'Geographical Recipes': From b9d6b930c47e3d90379390895f0bf22699a155e4 Mon Sep 17 00:00:00 2001 From: xvzc Date: Wed, 18 Mar 2026 12:02:50 +0900 Subject: [PATCH 15/39] fix: unsafe casting for trace id --- internal/session/session.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/session/session.go b/internal/session/session.go index 280fcb28..45408056 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -3,7 +3,6 @@ package session import ( "context" "math/rand/v2" - "unsafe" ) // We define unexported key types to prevent key collisions with other packages. @@ -88,5 +87,6 @@ func generateTraceID() string { b[i] = r + 0x30 } - return unsafe.String(unsafe.SliceData(b), 16) + return string(b) + // return unsafe.String(unsafe.SliceData(b), 16) } From 4793a14fb6d43c654c04a1c83858613cb1ab8f34 Mon Sep 17 00:00:00 2001 From: xvzc Date: Wed, 18 Mar 2026 12:08:10 +0900 Subject: [PATCH 16/39] style: reformat --- internal/packet/udp_sniffer.go | 3 ++- internal/server/socks5/bind.go | 3 ++- internal/server/socks5/connect.go | 3 ++- internal/server/tun/server.go | 7 ++++++- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/internal/packet/udp_sniffer.go b/internal/packet/udp_sniffer.go index 8f754d16..7cc09d0d 100644 --- a/internal/packet/udp_sniffer.go +++ b/internal/packet/udp_sniffer.go @@ -167,7 +167,8 @@ func generateUdpFilter(linkType layers.LinkType) []BPFInstruction { } else { // Check IP Version == 4 at the base offset // Load byte at baseOffset, mask 0xF0, check if 0x40 - instructions = append(instructions, + instructions = append( + instructions, BPFInstruction{Op: 0x30, Jt: 0, Jf: 0, K: baseOffset}, // Ldb [baseOffset] BPFInstruction{Op: 0x54, Jt: 0, Jf: 0, K: 0xf0}, // And 0xf0 BPFInstruction{Op: 0x15, Jt: 0, Jf: 3, K: 0x40}, // Jeq 0x40, True, False(Skip to End) diff --git a/internal/server/socks5/bind.go b/internal/server/socks5/bind.go index 2e90f884..1db9c011 100644 --- a/internal/server/socks5/bind.go +++ b/internal/server/socks5/bind.go @@ -71,7 +71,8 @@ func (h *BindHandler) Handle( Msg("accepted incoming connection") // 4. Second Reply: Send the address/port of the connecting host - if err := proto.SOCKS5SuccessResponse().Bind(rAddr.IP).Port(rAddr.Port).Write(conn); err != nil { + err = proto.SOCKS5SuccessResponse().Bind(rAddr.IP).Port(rAddr.Port).Write(conn) + if err != nil { logger.Error().Err(err).Msg("failed to write second bind reply") return err } diff --git a/internal/server/socks5/connect.go b/internal/server/socks5/connect.go index fbe0f61d..fc22cb9d 100644 --- a/internal/server/socks5/connect.go +++ b/internal/server/socks5/connect.go @@ -78,7 +78,8 @@ func (h *ConnectHandler) Handle( } // 3. Send Success Response - if err := proto.SOCKS5SuccessResponse().Bind(net.IPv4zero).Port(0).Write(lConn); err != nil { + err = proto.SOCKS5SuccessResponse().Bind(net.IPv4zero).Port(0).Write(lConn) + if err != nil { logger.Error().Err(err).Msg("failed to write socks5 success reply") return err } diff --git a/internal/server/tun/server.go b/internal/server/tun/server.go index 31a5b198..e173c335 100644 --- a/internal/server/tun/server.go +++ b/internal/server/tun/server.go @@ -117,7 +117,12 @@ func (s *TunServer) SetNetworkConfig() error { // This is crucial for the TUN interface to receive packets destined for its own subnet // Calculate the network address for /30 subnet (e.g., 10.0.0.1 -> 10.0.0.0/30) localIP := net.ParseIP(local) - networkAddr := net.IPv4(localIP[12], localIP[13], localIP[14], localIP[15]&0xFC) // Mask with /30 + networkAddr := net.IPv4( + localIP[12], + localIP[13], + localIP[14], + localIP[15]&0xFC, + ) // Mask with /30 if err := SetRoute(s.iface.Name(), []string{networkAddr.String() + "/30"}); err != nil { return fmt.Errorf("failed to set local route: %w", err) } From c803a16b480bbb415ed105b9420bab51be047387 Mon Sep 17 00:00:00 2001 From: xvzc Date: Wed, 18 Mar 2026 12:14:30 +0900 Subject: [PATCH 17/39] style: reformat --- internal/netutil/netutil_linux.go | 6 +++++- internal/packet/udp_sniffer.go | 9 ++++++--- internal/server/socks5/bind.go | 3 ++- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/internal/netutil/netutil_linux.go b/internal/netutil/netutil_linux.go index 1e370b34..b903b541 100644 --- a/internal/netutil/netutil_linux.go +++ b/internal/netutil/netutil_linux.go @@ -36,7 +36,11 @@ func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) } } - return fmt.Errorf("no suitable IP address found on interface %s for target %s", iface.Name, targetIP) + return fmt.Errorf( + "no suitable IP address found on interface %s for target %s", + iface.Name, + targetIP, + ) } // getDefaultGateway parses the system route table to find the default gateway on Linux diff --git a/internal/packet/udp_sniffer.go b/internal/packet/udp_sniffer.go index 7cc09d0d..42975786 100644 --- a/internal/packet/udp_sniffer.go +++ b/internal/packet/udp_sniffer.go @@ -167,11 +167,14 @@ func generateUdpFilter(linkType layers.LinkType) []BPFInstruction { } else { // Check IP Version == 4 at the base offset // Load byte at baseOffset, mask 0xF0, check if 0x40 + // Ldb [baseOffset] + // And 0xf0 + // Jeq 0x40, True, False(Skip to End) instructions = append( instructions, - BPFInstruction{Op: 0x30, Jt: 0, Jf: 0, K: baseOffset}, // Ldb [baseOffset] - BPFInstruction{Op: 0x54, Jt: 0, Jf: 0, K: 0xf0}, // And 0xf0 - BPFInstruction{Op: 0x15, Jt: 0, Jf: 3, K: 0x40}, // Jeq 0x40, True, False(Skip to End) + BPFInstruction{Op: 0x30, Jt: 0, Jf: 0, K: baseOffset}, + BPFInstruction{Op: 0x54, Jt: 0, Jf: 0, K: 0xf0}, + BPFInstruction{Op: 0x15, Jt: 0, Jf: 3, K: 0x40}, ) } diff --git a/internal/server/socks5/bind.go b/internal/server/socks5/bind.go index 1db9c011..b1dce109 100644 --- a/internal/server/socks5/bind.go +++ b/internal/server/socks5/bind.go @@ -45,7 +45,8 @@ func (h *BindHandler) Handle( lAddr := listener.Addr().(*net.TCPAddr) // 2. First Reply: Send the address/port we are listening on - if err := proto.SOCKS5SuccessResponse().Bind(lAddr.IP).Port(lAddr.Port).Write(conn); err != nil { + err = proto.SOCKS5SuccessResponse().Bind(lAddr.IP).Port(lAddr.Port).Write(conn) + if err != nil { logger.Error().Err(err).Msg("failed to write first bind reply") return err } From a657f7aebd08131ca4788f43235be8ef804e4d4f Mon Sep 17 00:00:00 2001 From: xvzc Date: Wed, 18 Mar 2026 12:16:51 +0900 Subject: [PATCH 18/39] style: reformat --- internal/server/socks5/udp_associate.go | 5 +++-- internal/server/tun/network_linux.go | 13 ++++++++++++- internal/server/tun/server.go | 4 +++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/internal/server/socks5/udp_associate.go b/internal/server/socks5/udp_associate.go index 71d36983..cfff6d2d 100644 --- a/internal/server/socks5/udp_associate.go +++ b/internal/server/socks5/udp_associate.go @@ -66,8 +66,9 @@ func (h *UdpAssociateHandler) Handle( Str("bind_addr", lAddr.String()). Msg("socks5 udp associate established") - // 2. Reply with the bound address - if err := proto.SOCKS5SuccessResponse().Bind(lAddr.IP).Port(lAddr.Port).Write(lConn); err != nil { + // 2. Reply with the bound address + err = proto.SOCKS5SuccessResponse().Bind(lAddr.IP).Port(lAddr.Port).Write(lConn) + if err != nil { logger.Error().Err(err).Msg("failed to write socks5 success reply") return err } diff --git a/internal/server/tun/network_linux.go b/internal/server/tun/network_linux.go index 50ce5a44..ba8f9190 100644 --- a/internal/server/tun/network_linux.go +++ b/internal/server/tun/network_linux.go @@ -209,7 +209,18 @@ func SetGatewayRoute(gateway, iface string) error { } // Add default route to the allocated table via the gateway - cmd = exec.Command("ip", "route", "add", "default", "via", gateway, "dev", iface, "table", tableIDStr) + cmd = exec.Command( + "ip", + "route", + "add", + "default", + "via", + gateway, + "dev", + iface, + "table", + tableIDStr, + ) out, err = cmd.CombinedOutput() if err != nil { if !strings.Contains(string(out), "File exists") { diff --git a/internal/server/tun/server.go b/internal/server/tun/server.go index e173c335..1f0411c4 100644 --- a/internal/server/tun/server.go +++ b/internal/server/tun/server.go @@ -123,7 +123,9 @@ func (s *TunServer) SetNetworkConfig() error { localIP[14], localIP[15]&0xFC, ) // Mask with /30 - if err := SetRoute(s.iface.Name(), []string{networkAddr.String() + "/30"}); err != nil { + + err = SetRoute(s.iface.Name(), []string{networkAddr.String() + "/30"}) + if err != nil { return fmt.Errorf("failed to set local route: %w", err) } From 7af0e95c1bc7aca478f7d84b40c136a4877b5df4 Mon Sep 17 00:00:00 2001 From: xvzc Date: Thu, 19 Mar 2026 06:27:22 +0900 Subject: [PATCH 19/39] refactor: optimize cache with generics --- cmd/spoofdpi/main.go | 7 +- go.mod | 2 +- internal/cache/cache.go | 16 +- internal/cache/lru_cache.go | 97 +++++---- internal/cache/ttl_cache.go | 97 ++++++--- internal/cache/ttl_cache_benchmark_test.go | 54 +++++ internal/desync/tls.go | 4 +- internal/desync/udp.go | 5 +- internal/dns/cache.go | 8 +- internal/netutil/conn.go | 74 ++++++- internal/netutil/conn_pool.go | 206 ------------------ internal/netutil/dial.go | 2 +- internal/netutil/dst.go | 52 +++++ internal/netutil/key.go | 80 +++++++ internal/netutil/key_benchmark_test.go | 30 +++ internal/netutil/netutil.go | 26 +-- internal/netutil/{addr.go => route.go} | 56 +---- .../netutil/{netutil_bsd.go => route_bsd.go} | 7 +- .../{netutil_linux.go => route_linux.go} | 23 +- internal/netutil/route_unsupported.go | 23 ++ internal/netutil/session_cache.go | 137 ++++++++++++ internal/packet/sniffer.go | 5 +- internal/packet/tcp_sniffer.go | 31 +-- internal/packet/udp_sniffer.go | 31 +-- internal/packet/udp_writer.go | 2 +- internal/server/socks5/udp_associate.go | 23 +- internal/server/tun/udp.go | 26 +-- 27 files changed, 693 insertions(+), 431 deletions(-) create mode 100644 internal/cache/ttl_cache_benchmark_test.go delete mode 100644 internal/netutil/conn_pool.go create mode 100644 internal/netutil/dst.go create mode 100644 internal/netutil/key.go create mode 100644 internal/netutil/key_benchmark_test.go rename internal/netutil/{addr.go => route.go} (66%) rename internal/netutil/{netutil_bsd.go => route_bsd.go} (91%) rename internal/netutil/{netutil_linux.go => route_linux.go} (71%) create mode 100644 internal/netutil/route_unsupported.go create mode 100644 internal/netutil/session_cache.go diff --git a/cmd/spoofdpi/main.go b/cmd/spoofdpi/main.go index 9e5f07ff..365511d1 100644 --- a/cmd/spoofdpi/main.go +++ b/cmd/spoofdpi/main.go @@ -167,7 +167,7 @@ func createResolver(logger zerolog.Logger, cfg *config.Config) dns.Resolver { cacheResolver := dns.NewCacheResolver( logging.WithScope(logger, "dns"), - cache.NewTTLCache( + cache.NewTTLCache[string]( cache.TTLCacheAttrs{ NumOfShards: 64, CleanupInterval: time.Duration(3 * time.Minute), @@ -245,7 +245,7 @@ func createPacketObjects( Str("mac", gatewayMACStr). Msg(" gateway (passive detection)") - hopCache := cache.NewLRUCache(4096) + hopCache := cache.NewLRUCache[netutil.IPKey](4096, nil) // TCP Objects tcpSniffer := packet.NewTCPSniffer( @@ -357,7 +357,7 @@ func createServer( ) udpAssociateHandler := socks5.NewUdpAssociateHandler( logging.WithScope(logger, "hnd"), - netutil.NewConnPool(4096, 60*time.Second), + netutil.NewSessionCache[netutil.NATKey](4096, 60*time.Second), udpDesyncer, cfg.UDP.Clone(), ) @@ -397,7 +397,6 @@ func createServer( udpDesyncer, cfg.UDP.Clone(), cfg.Conn.Clone(), - netutil.NewConnPool(4096, 60*time.Second), ) return tun.NewTunServer( diff --git a/go.mod b/go.mod index f6cabd6d..33d82d80 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/google/gopacket v1.1.19 github.com/miekg/dns v1.1.61 github.com/rs/zerolog v1.33.0 + github.com/samber/lo v1.52.0 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stretchr/testify v1.11.1 github.com/urfave/cli/v3 v3.6.1 @@ -24,7 +25,6 @@ require ( github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/samber/lo v1.52.0 // indirect golang.org/x/mod v0.28.0 // indirect golang.org/x/sync v0.17.0 // indirect golang.org/x/text v0.29.0 // indirect diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 9ae875c3..1d816b89 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -33,13 +33,15 @@ func (o *options) WithSkipExisting(skipExisting bool) *options { // Cache is the unified interface for all cache implementations. // The Set method accepts a variadic list of options. -type Cache interface { - // Get retrieves a value from the cache. - Get(key string) (any, bool) - // Set adds a value to the cache, applying any provided options. - Set(key string, value any, opts *options) bool - Delete(key string) - Range(f func(key string, value any) bool) +type Cache[K comparable] interface { + // Fetch retrieves a value from the cache. + Fetch(key K) (any, bool) + // Store adds a value to the cache, applying any provided options. + Store(key K, value any, opts *options) bool + Evict(key K) + Has(key K) bool + // ForEach iterates over the cache items. + ForEach(f func(key K, value any) error) error // Size returns the number of items in the cache. Size() int } diff --git a/internal/cache/lru_cache.go b/internal/cache/lru_cache.go index 5db51742..f2d83aa8 100644 --- a/internal/cache/lru_cache.go +++ b/internal/cache/lru_cache.go @@ -5,66 +5,73 @@ import ( "sync" ) -var _ Cache = (*LRUCache)(nil) +var _ Cache[string] = (*LRUCache[string])(nil) // lruEntry represents the value stored in the cache and the linked list node. -type lruEntry struct { - key string +type lruEntry[K comparable] struct { + key K value any // expiry time.Time field removed } // LRUCache is a concurrent, fixed-size cache with an LRU eviction policy. -type LRUCache struct { +type LRUCache[K comparable] struct { capacity int - mu sync.RWMutex + mu sync.Mutex // list is a doubly linked list used for tracking access order. // Front is Most Recently Used (MRU), Back is Least Recently Used (LRU). list *list.List // cache maps the key to the list element (*list.Element) which holds the lruEntry. - cache map[string]*list.Element + cache map[K]*list.Element + + onInvalidate func(key K, value any) } -// NewLRUCache creates a new LRU Cache instance with the given capacity. +// NewLRUCache creates a new LRU Cache instance. // Capacity must be greater than zero. -func NewLRUCache(capacity int) Cache { +func NewLRUCache[K comparable]( + capacity int, + onInvalidate func(key K, value any), +) Cache[K] { if capacity <= 0 { // Default to a sensible minimum capacity if input is invalid capacity = 100 } - return &LRUCache{ - capacity: capacity, - list: list.New(), - cache: make(map[string]*list.Element, capacity), + return &LRUCache[K]{ + capacity: capacity, + list: list.New(), + cache: make(map[K]*list.Element, capacity), + onInvalidate: onInvalidate, } } -// isExpired function removed - // evictOldest removes the least recently used item from the cache. -func (c *LRUCache) evictOldest() { +func (c *LRUCache[K]) evictOldest() { // Element is at the back of the list (LRU) tail := c.list.Back() if tail != nil { - c.removeElement(tail) + c.removeByElement(tail) } } -// removeElement removes a specific list element from both the list and the map. -func (c *LRUCache) removeElement(e *list.Element) { +// removeByElement removes a specific list element from both the list and the map. +func (c *LRUCache[K]) removeByElement(e *list.Element) { c.list.Remove(e) - entry := e.Value.(*lruEntry) + entry := e.Value.(*lruEntry[K]) delete(c.cache, entry.key) + if c.onInvalidate != nil { + c.onInvalidate(entry.key, entry.value) + } } -// Get retrieves a value from the cache. +// Fetch retrieves a value from the cache. // If found, the item is promoted to Most Recently Used (MRU). -func (c *LRUCache) Get(key string) (any, bool) { - // Use RLock for concurrent safe read operations - c.mu.RLock() - defer c.mu.RUnlock() +func (c *LRUCache[K]) Fetch(key K) (any, bool) { + // Use Lock since MoveToFront modifies the linked list + c.mu.Lock() + defer c.mu.Unlock() // Check if key exists in the map if element, ok := c.cache[key]; ok { @@ -73,15 +80,15 @@ func (c *LRUCache) Get(key string) (any, bool) { // Promote to Most Recently Used (MRU) c.list.MoveToFront(element) - entry := element.Value.(*lruEntry) + entry := element.Value.(*lruEntry[K]) return entry.value, true } return nil, false } -// Set adds a value to the cache, applying any provided options. -func (c *LRUCache) Set(key string, value any, opts *options) bool { +// Store adds a value to the cache, applying any provided options. +func (c *LRUCache[K]) Store(key K, value any, opts *options) bool { // Use Write Lock for modification operations c.mu.Lock() defer c.mu.Unlock() @@ -100,7 +107,7 @@ func (c *LRUCache) Set(key string, value any, opts *options) bool { } if ok { - entry := element.Value.(*lruEntry) + entry := element.Value.(*lruEntry[K]) entry.value = value c.list.MoveToFront(element) @@ -108,7 +115,7 @@ func (c *LRUCache) Set(key string, value any, opts *options) bool { } // Key is new: Create a new entry - entry := &lruEntry{ + entry := &lruEntry[K]{ key: key, value: value, } @@ -125,35 +132,43 @@ func (c *LRUCache) Set(key string, value any, opts *options) bool { return true } -// Range iterates over the cache items. -// If f returns false, the item is removed from the cache. -func (c *LRUCache) Range(f func(key string, value any) bool) { +// ForEach iterates over the cache items. +func (c *LRUCache[K]) ForEach(f func(key K, value any) error) error { c.mu.Lock() defer c.mu.Unlock() var next *list.Element for e := c.list.Front(); e != nil; e = next { next = e.Next() - entry := e.Value.(*lruEntry) - if !f(entry.key, entry.value) { - c.removeElement(e) + entry := e.Value.(*lruEntry[K]) + if err := f(entry.key, entry.value); err != nil { + return err } } + return nil } -// Delete removes an item from the cache. -func (c *LRUCache) Delete(key string) { +// Evict removes an item from the cache. +func (c *LRUCache[K]) Evict(key K) { c.mu.Lock() defer c.mu.Unlock() if element, ok := c.cache[key]; ok { - c.removeElement(element) + c.removeByElement(element) } } +// Has checks if an item exists in the cache without moving its MRU status. +func (c *LRUCache[K]) Has(key K) bool { + c.mu.Lock() + defer c.mu.Unlock() + _, ok := c.cache[key] + return ok +} + // Size returns the number of items in the cache. -func (c *LRUCache) Size() int { - c.mu.RLock() - defer c.mu.RUnlock() +func (c *LRUCache[K]) Size() int { + c.mu.Lock() + defer c.mu.Unlock() return c.list.Len() } diff --git a/internal/cache/ttl_cache.go b/internal/cache/ttl_cache.go index 87eaa6e1..b4a5678d 100644 --- a/internal/cache/ttl_cache.go +++ b/internal/cache/ttl_cache.go @@ -7,16 +7,16 @@ import ( "time" ) -var _ Cache = (*TTLCache)(nil) +var _ Cache[string] = (*TTLCache[string])(nil) // ttlCacheItem represents a single cached item using generics. -type ttlCacheItem struct { +type ttlCacheItem[K comparable] struct { value any // The cached data of type T. expiresAt time.Time // The time when the item expires. } // isExpired checks if the item has expired. -func (i ttlCacheItem) isExpired() bool { +func (i ttlCacheItem[K]) isExpired() bool { if i.expiresAt.IsZero() { // zero time means no expiration. return false @@ -25,40 +25,43 @@ func (i ttlCacheItem) isExpired() bool { } // ttlCacheShard represents a single, thread-safe shard of the cache. -type ttlCacheShard struct { - items map[string]ttlCacheItem // items holds the cache data for this shard. +type ttlCacheShard[K comparable] struct { + items map[K]ttlCacheItem[K] // items holds the cache data for this shard. mu sync.RWMutex } type TTLCacheAttrs struct { NumOfShards uint8 CleanupInterval time.Duration + HashFunc func(key any) uint64 } // TTLCache is a high-performance, sharded, generic TTL cache. -type TTLCache struct { - shards []*ttlCacheShard // A slice of cache shards. +type TTLCache[K comparable] struct { + shards []*ttlCacheShard[K] // A slice of cache shards. + hashFunc func(key any) uint64 } // NewTTLCache creates a new sharded TTL cache with a background janitor goroutine. // numShards specifies the number of shards to create and must be greater than 0. // cleanupInterval specifies how often the janitor should run. -func NewTTLCache( +func NewTTLCache[K comparable]( attrs TTLCacheAttrs, -) *TTLCache { +) *TTLCache[K] { if attrs.NumOfShards == 0 { panic( fmt.Errorf("number of shards must be greater than 0, got %d", attrs.NumOfShards), ) } - c := &TTLCache{ - shards: make([]*ttlCacheShard, attrs.NumOfShards), + c := &TTLCache[K]{ + shards: make([]*ttlCacheShard[K], attrs.NumOfShards), + hashFunc: attrs.HashFunc, } for i := range attrs.NumOfShards { - c.shards[i] = &ttlCacheShard{ - items: make(map[string]ttlCacheItem), + c.shards[i] = &ttlCacheShard[K]{ + items: make(map[K]ttlCacheItem[K]), } } @@ -70,7 +73,7 @@ func NewTTLCache( } // janitor runs the cleanup goroutine, calling ForceCleanup at the specified interval. -func (c *TTLCache) janitor(interval time.Duration) { +func (c *TTLCache[K]) janitor(interval time.Duration) { ticker := time.NewTicker(interval) defer ticker.Stop() for range ticker.C { @@ -79,9 +82,22 @@ func (c *TTLCache) janitor(interval time.Duration) { } // getShard maps a key to its corresponding cache shard using a hash function. -func (c *TTLCache) getShard(key string) *ttlCacheShard { +func (c *TTLCache[K]) getShard(key K) *ttlCacheShard[K] { + if c.hashFunc != nil { + hash := c.hashFunc(key) + return c.shards[hash%uint64(len(c.shards))] + } + hasher := fnv.New64a() - hasher.Write([]byte(key)) + // Optimally hash the memory without string allocation + switch v := any(key).(type) { + case string: + hasher.Write([]byte(v)) + case []byte: + hasher.Write(v) + default: + _, _ = fmt.Fprint(hasher, key) + } hash := hasher.Sum64() return c.shards[hash%uint64(len(c.shards))] } @@ -89,9 +105,9 @@ func (c *TTLCache) getShard(key string) *ttlCacheShard { // ┌─────────────┐ // │ PUBLIC APIs │ // └─────────────┘ -// Set adds an item to the cache, replacing any existing item. +// Store adds an item to the cache, replacing any existing item. // If ttl is 0 or negative, the item will never expire (passive-only). -func (c *TTLCache) Set(key string, value any, opts *options) bool { +func (c *TTLCache[K]) Store(key K, value any, opts *options) bool { shard := c.getShard(key) shard.mu.Lock() @@ -115,7 +131,7 @@ func (c *TTLCache) Set(key string, value any, opts *options) bool { } expiresAt := time.Now().Add(opts.ttl) - newItem := ttlCacheItem{ + newItem := ttlCacheItem[K]{ value: value, expiresAt: expiresAt, } @@ -124,10 +140,10 @@ func (c *TTLCache) Set(key string, value any, opts *options) bool { return true } -// Get retrieves an item from the cache. +// Fetch retrieves an item from the cache. // It returns the item (of type T) and true if found and not expired. // Otherwise, it returns the zero value of T and false. -func (c *TTLCache) Get(key string) (any, bool) { +func (c *TTLCache[K]) Fetch(key K) (any, bool) { shard := c.getShard(key) shard.mu.RLock() i, ok := shard.items[key] @@ -157,17 +173,31 @@ func (c *TTLCache) Get(key string) (any, bool) { return i.value, true } -// Delete removes an item from the cache. -func (c *TTLCache) Delete(key string) { +// Evict removes an item from the cache. +func (c *TTLCache[K]) Evict(key K) { shard := c.getShard(key) shard.mu.Lock() delete(shard.items, key) shard.mu.Unlock() } +// Has checks if an item exists in the cache and is not expired. +func (c *TTLCache[K]) Has(key K) bool { + shard := c.getShard(key) + shard.mu.RLock() + i, ok := shard.items[key] + shard.mu.RUnlock() + + if !ok { + return false + } + + return !i.isExpired() +} + // ForceCleanup actively scans all shards and deletes expired items. // This is called periodically by the janitor but can also be called manually. -func (c *TTLCache) ForceCleanup() { +func (c *TTLCache[K]) ForceCleanup() { now := time.Now() for _, shard := range c.shards { shard.mu.Lock() @@ -180,22 +210,23 @@ func (c *TTLCache) ForceCleanup() { } } -// Range iterates over the cache items. -// If f returns false, the item is removed from the cache. -func (c *TTLCache) Range(f func(key string, value any) bool) { +// ForEach iterates over the cache items. +func (c *TTLCache[K]) ForEach(f func(key K, value any) error) error { for _, shard := range c.shards { - shard.mu.Lock() - for key, i := range shard.items { - if !f(key, i.value) { - delete(shard.items, key) + shard.mu.RLock() + for key, i := range shard.items { // Pre-allocate values to avoid holding RLock unnecessarily? For now, keep simple. + if err := f(key, i.value); err != nil { + shard.mu.RUnlock() + return err } } - shard.mu.Unlock() + shard.mu.RUnlock() } + return nil } // Size returns the total number of items across all shards. -func (c *TTLCache) Size() int { +func (c *TTLCache[K]) Size() int { total := 0 for _, shard := range c.shards { shard.mu.RLock() diff --git a/internal/cache/ttl_cache_benchmark_test.go b/internal/cache/ttl_cache_benchmark_test.go new file mode 100644 index 00000000..acedc7f2 --- /dev/null +++ b/internal/cache/ttl_cache_benchmark_test.go @@ -0,0 +1,54 @@ +package cache + +import ( + "fmt" + "testing" + "time" +) + +type dummyIPKey [16]byte + +func generateDummyIPKey(i int) dummyIPKey { + var k dummyIPKey + k[0] = byte(i) + k[1] = byte(i >> 8) + return k +} + +func BenchmarkCacheKeys(b *testing.B) { + strCache := NewTTLCache[string](TTLCacheAttrs{ + NumOfShards: 1, + CleanupInterval: time.Minute, + }) + + ipCache := NewTTLCache[dummyIPKey](TTLCacheAttrs{ + NumOfShards: 1, + CleanupInterval: time.Minute, + }) + + b.Run("TTLCache_StringKey", func(b *testing.B) { + var keys []string + for i := 0; i < b.N; i++ { + keys = append(keys, "192.168.0."+fmt.Sprint(i)) + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + key := keys[i] + strCache.Store(key, 1, nil) + strCache.Fetch(key) + } + }) + + b.Run("TTLCache_GenericStructKey", func(b *testing.B) { + var keys []dummyIPKey + for i := 0; i < b.N; i++ { + keys = append(keys, generateDummyIPKey(i)) + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + key := keys[i] + ipCache.Store(key, 1, nil) + ipCache.Fetch(key) + } + }) +} diff --git a/internal/desync/tls.go b/internal/desync/tls.go index 04299530..cd037ac5 100644 --- a/internal/desync/tls.go +++ b/internal/desync/tls.go @@ -53,7 +53,9 @@ func (d *TLSDesyncer) Desync( } if d.sniffer != nil && d.writer != nil && lo.FromPtr(httpsOpts.FakeCount) > 0 { - oTTL := d.sniffer.GetOptimalTTL(conn.RemoteAddr().(*net.TCPAddr).IP.String()) + oTTL := d.sniffer.GetOptimalTTL( + netutil.NewIPKey(conn.RemoteAddr().(*net.TCPAddr).IP), + ) n, err := d.sendFakePackets(ctx, logger, conn, oTTL, httpsOpts) if err != nil { logger.Warn().Err(err).Msg("failed to send fake packets") diff --git a/internal/desync/udp.go b/internal/desync/udp.go index 228f68fe..93891369 100644 --- a/internal/desync/udp.go +++ b/internal/desync/udp.go @@ -7,6 +7,7 @@ import ( "github.com/rs/zerolog" "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/packet" ) @@ -41,8 +42,8 @@ func (d *UDPDesyncer) Desync( return 0, nil } - dstIP := rConn.RemoteAddr().(*net.UDPAddr).IP.String() - oTTL := d.sniffer.GetOptimalTTL(dstIP) + dstIP := rConn.RemoteAddr().(*net.UDPAddr).IP + oTTL := d.sniffer.GetOptimalTTL(netutil.NewIPKey(dstIP)) var totalSent int for range *opts.FakeCount { diff --git a/internal/dns/cache.go b/internal/dns/cache.go index 11b10e24..8122ab3b 100644 --- a/internal/dns/cache.go +++ b/internal/dns/cache.go @@ -15,13 +15,13 @@ import ( type CacheResolver struct { logger zerolog.Logger - ttlCache cache.Cache // Owns the cache + ttlCache cache.Cache[string] // Owns the cache } // NewCacheResolver wraps a "worker" resolver with a cache. func NewCacheResolver( logger zerolog.Logger, - cache cache.Cache, + cache cache.Cache[string], ) *CacheResolver { return &CacheResolver{ logger: logger, @@ -53,7 +53,7 @@ func (cr *CacheResolver) Resolve( // the cache might return the wrong one. // For now, assuming simplistic cache key = domain, but awareness of potential issue. // Ideally: key = domain + qtypes + spec-related-things - if item, ok := cr.ttlCache.Get(domain); ok { + if item, ok := cr.ttlCache.Fetch(domain); ok { logger.Debug().Str("domain", domain).Msgf("hit") return item.(*RecordSet).Clone(), nil } @@ -81,7 +81,7 @@ func (cr *CacheResolver) Resolve( Uint32("ttl", rSet.TTL). Msg("set") - _ = cr.ttlCache.Set( + _ = cr.ttlCache.Store( domain, rSet, cache.Options().WithTTL(time.Duration(rSet.TTL)*time.Second), diff --git a/internal/netutil/conn.go b/internal/netutil/conn.go index 1fd377c8..465679bd 100644 --- a/internal/netutil/conn.go +++ b/internal/netutil/conn.go @@ -9,6 +9,7 @@ import ( "net" "os" "sync" + "sync/atomic" "syscall" "time" @@ -227,9 +228,33 @@ func (b *BufferedConn) Peek(n int) ([]byte, error) { // This is useful for sessions which should stay alive as long as there is activity. type IdleTimeoutConn struct { net.Conn - Timeout time.Duration - LastActive time.Time - ExpiredAt time.Time // Calculated expiration time for cleanup + Key any + Timeout time.Duration + + lastActivity int64 // UnixNano atomic + expiredAt int64 // UnixNano atomic + + onActivity func() + onClose func() +} + +// NewIdleTimeoutConn wraps a net.Conn and securely initializes its internal atomic deadlines. +func NewIdleTimeoutConn(conn net.Conn, timeout time.Duration) *IdleTimeoutConn { + c := &IdleTimeoutConn{ + Conn: conn, + Timeout: timeout, + } + + now := time.Now() + atomic.StoreInt64(&c.lastActivity, now.UnixNano()) + if timeout > 0 { + expTime := now.Add(timeout) + atomic.StoreInt64(&c.expiredAt, expTime.UnixNano()) + _ = c.SetReadDeadline(expTime) + _ = c.SetWriteDeadline(expTime) + } + + return c } func (c *IdleTimeoutConn) Read(b []byte) (int, error) { @@ -246,17 +271,48 @@ func (c *IdleTimeoutConn) Write(b []byte) (int, error) { // Returns false if the connection was already expired. func (c *IdleTimeoutConn) ExtendDeadline() bool { now := time.Now() + nowUnix := now.UnixNano() - // Check if already expired - if !c.ExpiredAt.IsZero() && now.After(c.ExpiredAt) { + // 1. Check if already expired (Thread-safe atomic read) + expUnix := atomic.LoadInt64(&c.expiredAt) + if expUnix != 0 && nowUnix > expUnix { return false } - c.LastActive = now + // 2. Throttle OnActivity to drastically reduce LRU Cache Lock Contention + lastActUnix := atomic.LoadInt64(&c.lastActivity) + if nowUnix-lastActUnix > int64(time.Second) { + atomic.StoreInt64(&c.lastActivity, nowUnix) + if c.onActivity != nil { + c.onActivity() + } + } + + // 3. Throttle SetDeadline overhead (System Call) + // Extends only if remaining time is under 70% of timeout if c.Timeout > 0 { - c.ExpiredAt = now.Add(c.Timeout) - _ = c.SetReadDeadline(c.ExpiredAt) - _ = c.SetWriteDeadline(c.ExpiredAt) + if expUnix == 0 || (expUnix-nowUnix) < (c.Timeout.Nanoseconds()*7/10) { + newExpUnix := now.Add(c.Timeout).UnixNano() + atomic.StoreInt64(&c.expiredAt, newExpUnix) + + newExpTime := time.Unix(0, newExpUnix) + _ = c.SetReadDeadline(newExpTime) + _ = c.SetWriteDeadline(newExpTime) + } } + return true } + +// IsExpired safely checks if the connection has surpassed its calculated expiration time. +func (c *IdleTimeoutConn) IsExpired(now time.Time) bool { + expUnix := atomic.LoadInt64(&c.expiredAt) + return expUnix != 0 && now.UnixNano() > expUnix +} + +func (c *IdleTimeoutConn) Close() error { + if c.onClose != nil { + c.onClose() + } + return c.Conn.Close() +} diff --git a/internal/netutil/conn_pool.go b/internal/netutil/conn_pool.go deleted file mode 100644 index f730a85f..00000000 --- a/internal/netutil/conn_pool.go +++ /dev/null @@ -1,206 +0,0 @@ -package netutil - -import ( - "container/list" - "net" - "sync" - "time" -) - -// ConnPool manages UDP connections with LRU eviction policy and idle timeout. -type ConnPool struct { - capacity int - timeout time.Duration - cache map[string]*list.Element - ll *list.List - mu sync.Mutex - stopCh chan struct{} - stopOnce sync.Once -} - -// PooledConn wraps net.Conn with LRU tracking and deadline management. -type PooledConn struct { - net.Conn - pool *ConnPool - key string - timeout time.Duration - expiredAt time.Time -} - -type connEntry struct { - key string - conn *PooledConn -} - -// NewConnPool creates a new pool with the specified capacity and timeout. -// Starts a background goroutine for expired connection cleanup. -func NewConnPool(capacity int, timeout time.Duration) *ConnPool { - p := &ConnPool{ - capacity: capacity, - timeout: timeout, - cache: make(map[string]*list.Element), - ll: list.New(), - stopCh: make(chan struct{}), - } - - // Cleanup interval: half of timeout, min 10s, max 60s - cleanupInterval := timeout / 2 - if cleanupInterval < 10*time.Second { - cleanupInterval = 10 * time.Second - } - if cleanupInterval > 60*time.Second { - cleanupInterval = 60 * time.Second - } - - go p.cleanupLoop(cleanupInterval) - return p -} - -// Add adds a connection to the pool and returns the wrapped connection. -// If capacity is full, evicts the least recently used connection. -func (p *ConnPool) Add(key string, rawConn net.Conn) *PooledConn { - p.mu.Lock() - defer p.mu.Unlock() - - // Evict if capacity is reached - if p.ll.Len() >= p.capacity { - p.evictOldest() - } - - now := time.Now() - expiredAt := now.Add(p.timeout) - - wrapper := &PooledConn{ - Conn: rawConn, - pool: p, - key: key, - timeout: p.timeout, - expiredAt: expiredAt, - } - - _ = rawConn.SetDeadline(expiredAt) - - elem := p.ll.PushFront(&connEntry{key: key, conn: wrapper}) - p.cache[key] = elem - - return wrapper -} - -// Remove closes and removes the connection from the pool. -func (p *ConnPool) Remove(key string) { - p.mu.Lock() - defer p.mu.Unlock() - - if elem, ok := p.cache[key]; ok { - p.removeElement(elem) - } -} - -// Size returns the number of connections in the pool. -func (p *ConnPool) Size() int { - p.mu.Lock() - defer p.mu.Unlock() - return p.ll.Len() -} - -// Stop stops the background cleanup goroutine. -func (p *ConnPool) Stop() { - p.stopOnce.Do(func() { - close(p.stopCh) - }) -} - -// CloseAll closes all connections in the pool. -func (p *ConnPool) CloseAll() { - p.mu.Lock() - defer p.mu.Unlock() - - elem := p.ll.Front() - for elem != nil { - next := elem.Next() - p.removeElement(elem) - elem = next - } -} - -func (p *ConnPool) cleanupLoop(interval time.Duration) { - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - select { - case <-p.stopCh: - return - case <-ticker.C: - p.evictExpired() - } - } -} - -func (p *ConnPool) evictExpired() { - p.mu.Lock() - defer p.mu.Unlock() - - now := time.Now() - elem := p.ll.Back() - for elem != nil { - // Save next before potential removal - next := elem.Prev() - e := elem.Value.(*connEntry) - if now.After(e.conn.expiredAt) { - p.removeElement(elem) - } - elem = next - } -} - -func (p *ConnPool) evictOldest() { - if elem := p.ll.Back(); elem != nil { - p.removeElement(elem) - } -} - -func (p *ConnPool) removeElement(elem *list.Element) { - e := elem.Value.(*connEntry) - _ = e.conn.Conn.Close() - p.ll.Remove(elem) - delete(p.cache, e.key) -} - -func (p *ConnPool) touch(key string) { - p.mu.Lock() - defer p.mu.Unlock() - if elem, ok := p.cache[key]; ok { - p.ll.MoveToFront(elem) - } -} - -func (c *PooledConn) refreshDeadline() { - c.expiredAt = time.Now().Add(c.timeout) - _ = c.SetDeadline(c.expiredAt) - c.pool.touch(c.key) -} - -// Read reads data and refreshes the deadline on success. -func (c *PooledConn) Read(b []byte) (n int, err error) { - n, err = c.Conn.Read(b) - if n > 0 { - c.refreshDeadline() - } - return -} - -// Write writes data and refreshes the deadline on success. -func (c *PooledConn) Write(b []byte) (n int, err error) { - n, err = c.Conn.Write(b) - if n > 0 { - c.refreshDeadline() - } - return -} - -// Close removes the connection from the pool (underlying close handled by pool). -func (c *PooledConn) Close() error { - c.pool.Remove(c.key) - return nil -} diff --git a/internal/netutil/dial.go b/internal/netutil/dial.go index 6467913f..81baa9a3 100644 --- a/internal/netutil/dial.go +++ b/internal/netutil/dial.go @@ -53,7 +53,7 @@ func DialFastest( // If Iface is specified, bind to the interface if dst.Iface != nil { - if err := bindToInterface(dialer, dst.Iface, ip); err != nil { + if err := bindToInterface(network, dialer, dst.Iface, ip); err != nil { select { case results <- dialResult{conn: nil, err: err}: case <-ctx.Done(): diff --git a/internal/netutil/dst.go b/internal/netutil/dst.go new file mode 100644 index 00000000..89b42807 --- /dev/null +++ b/internal/netutil/dst.go @@ -0,0 +1,52 @@ +package netutil + +import ( + "fmt" + "net" + "strconv" + "time" +) + +type Destination struct { + Domain string + Addrs []net.IP + Port int + Timeout time.Duration + Iface *net.Interface + Gateway string +} + +func (d *Destination) String() string { + return net.JoinHostPort(d.Domain, strconv.Itoa(d.Port)) +} + +func ValidateDestination( + dstAddrs []net.IP, + dstPort int, + listenAddr *net.TCPAddr, +) (bool, error) { + if dstPort != int(listenAddr.Port) { + return true, nil + } + + var err error + var ifAddrs []net.Addr + ifAddrs, err = net.InterfaceAddrs() + + for _, dstAddr := range dstAddrs { + ip := dstAddr + if ip.IsLoopback() { + return false, fmt.Errorf("loopback addr detected %v", ip.String()) + } + + for _, addr := range ifAddrs { + if ipnet, ok := addr.(*net.IPNet); ok { + if ipnet.IP.Equal(ip) { + return false, fmt.Errorf("interface addr detected %v", ipnet.String()) + } + } + } + } + + return true, err +} diff --git a/internal/netutil/key.go b/internal/netutil/key.go new file mode 100644 index 00000000..bba313bd --- /dev/null +++ b/internal/netutil/key.go @@ -0,0 +1,80 @@ +package netutil + +import ( + "fmt" + "net" +) + +// NATKey represents a 4-tuple (SrcIP, SrcPort, DstIP, DstPort) for zero-allocation NAT session mapping +type NATKey struct { + SrcIP [16]byte + SrcPort uint16 + DstIP [16]byte + DstPort uint16 +} + +// String returns the string representation of the session key. +// Only used for debugging / logging. +func (k NATKey) String() string { + var srcIP, dstIP net.IP + + // Check if IPv4-mapped IPv6 + if isIPv4Mapped(k.SrcIP) { + srcIP = net.IP(k.SrcIP[12:16]) + } else { + srcIP = net.IP(k.SrcIP[:]) + } + + if isIPv4Mapped(k.DstIP) { + dstIP = net.IP(k.DstIP[12:16]) + } else { + dstIP = net.IP(k.DstIP[:]) + } + + return fmt.Sprintf("%v:%d>%v:%d", srcIP, k.SrcPort, dstIP, k.DstPort) +} + +// IPKey represents an IP address for zero-allocation cache mapping +type IPKey [16]byte + +// String returns the string representation of the IPKey. +func (k IPKey) String() string { + var srcIP net.IP + if isIPv4Mapped(k) { + srcIP = net.IP(k[12:16]) + } else { + srcIP = net.IP(k[:]) + } + return srcIP.String() +} + +// NewIPKey zero-alloc constructs an IPKey from net.IP +func NewIPKey(ip net.IP) IPKey { + var k IPKey + ip16 := ip.To16() + if ip16 != nil { + copy(k[:], ip16) + } + return k +} + +// NewNATKey zero-alloc constructs a NATKey from two UDPAddr +func NewNATKey(src *net.UDPAddr, dst *net.UDPAddr) NATKey { + var k NATKey + + // net.IP is a slice. Let's force it to 16 bytes for comparable struct key + srcIP16 := src.IP.To16() + if srcIP16 != nil { + copy(k.SrcIP[:], srcIP16) + } + + dstIP16 := dst.IP.To16() + if dstIP16 != nil { + copy(k.DstIP[:], dstIP16) + } + + k.SrcPort = uint16(src.Port) + k.DstPort = uint16(dst.Port) + + return k +} diff --git a/internal/netutil/key_benchmark_test.go b/internal/netutil/key_benchmark_test.go new file mode 100644 index 00000000..9b1fca6a --- /dev/null +++ b/internal/netutil/key_benchmark_test.go @@ -0,0 +1,30 @@ +package netutil + +import ( + "net" + "testing" +) + +func BenchmarkKeyAllocation(b *testing.B) { + // Dummy test data + clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345} + targetAddrStr := "142.250.190.46:443" + + uAddr, _ := net.ResolveUDPAddr("udp", targetAddrStr) + + b.Run("StringKey_Legacy", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + // This was the old way + _ = clientAddr.String() + ">" + targetAddrStr + } + }) + + b.Run("StructKey_NATKey", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + // This is the new way (zero allocation) + _ = NewNATKey(clientAddr, uAddr) + } + }) +} diff --git a/internal/netutil/netutil.go b/internal/netutil/netutil.go index f19e56ba..a79eacda 100644 --- a/internal/netutil/netutil.go +++ b/internal/netutil/netutil.go @@ -1,18 +1,14 @@ -//go:build !linux && !darwin && !freebsd - package netutil -import ( - "fmt" - "net" -) - -// bindToInterface is a no-op on unsupported platforms. -func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) error { - return nil -} - -// getDefaultGateway is not supported on this platform. -func getDefaultGateway() (string, error) { - return "", fmt.Errorf("getDefaultGateway not supported on this platform") +func isIPv4Mapped(ip [16]byte) bool { + // IPv4-mapped IPv6 address has the prefix 0:0:0:0:0:FFFF + for i := 0; i < 10; i++ { + if ip[i] != 0 { + return false + } + } + if ip[10] != 0xff || ip[11] != 0xff { + return false + } + return true } diff --git a/internal/netutil/addr.go b/internal/netutil/route.go similarity index 66% rename from internal/netutil/addr.go rename to internal/netutil/route.go index 0b51194a..756cd163 100644 --- a/internal/netutil/addr.go +++ b/internal/netutil/route.go @@ -3,59 +3,12 @@ package netutil import ( "fmt" "net" - "strconv" - "time" ) -type Destination struct { - Domain string - Addrs []net.IP - Port int - Timeout time.Duration - Iface *net.Interface - Gateway string -} - -func (d *Destination) String() string { - return net.JoinHostPort(d.Domain, strconv.Itoa(d.Port)) -} - -func ValidateDestination( - dstAddrs []net.IP, - dstPort int, - listenAddr *net.TCPAddr, -) (bool, error) { - if dstPort != int(listenAddr.Port) { - return true, nil - } - - var err error - var ifAddrs []net.Addr - ifAddrs, err = net.InterfaceAddrs() - - for _, dstAddr := range dstAddrs { - ip := dstAddr - if ip.IsLoopback() { - return false, fmt.Errorf("loopback addr detected %v", ip.String()) - } - - for _, addr := range ifAddrs { - if ipnet, ok := addr.(*net.IPNet); ok { - if ipnet.IP.Equal(ip) { - return false, fmt.Errorf("interface addr detected %v", ipnet.String()) - } - } - } - } - - return true, err -} - // FindSafeSubnet scans the 10.0.0.0/8 range to find an unused /30 subnet func FindSafeSubnet() (string, string, error) { - /* Retrieve all active interface addresses to prevent CIDR overlapping. - Checking against existing networks is faster than sending probe packets. - */ + // Retrieve all active interface addresses to prevent CIDR overlapping. + // Checking against existing networks is faster than sending probe packets. addrs, err := net.InterfaceAddrs() if err != nil { return "", "", err @@ -68,9 +21,8 @@ func FindSafeSubnet() (string, string, error) { } } - /* Iterate through the 10.0.0.0/8 private range with a /30 step. - A /30 subnet provides exactly two usable end-point IP addresses. - */ + // Iterate through the 10.0.0.0/8 private range with a /30 step. + // A /30 subnet provides exactly two usable end-point IP addresses. for i := 0; i < 256; i++ { for j := 0; j < 256; j++ { // Construct candidate IP pair: 10.i.j.1 and 10.i.j.2 diff --git a/internal/netutil/netutil_bsd.go b/internal/netutil/route_bsd.go similarity index 91% rename from internal/netutil/netutil_bsd.go rename to internal/netutil/route_bsd.go index ab5330b6..c84b6329 100644 --- a/internal/netutil/netutil_bsd.go +++ b/internal/netutil/route_bsd.go @@ -14,7 +14,12 @@ import ( // bindToInterface sets the dialer's Control function to bind the socket // to a specific network interface using IP_BOUND_IF on BSD systems. -func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) error { +func bindToInterface( + network string, + dialer *net.Dialer, + iface *net.Interface, + targetIP net.IP, +) error { if iface == nil { return nil } diff --git a/internal/netutil/netutil_linux.go b/internal/netutil/route_linux.go similarity index 71% rename from internal/netutil/netutil_linux.go rename to internal/netutil/route_linux.go index b903b541..3273e124 100644 --- a/internal/netutil/netutil_linux.go +++ b/internal/netutil/route_linux.go @@ -12,7 +12,12 @@ import ( // bindToInterface sets the dialer's LocalAddr to use the interface's IP as the source address. // On Linux, we only set LocalAddr because SO_BINDTODEVICE can cause issues with // socket lookup for incoming packets. -func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) error { +func bindToInterface( + network string, + dialer *net.Dialer, + iface *net.Interface, + targetIP net.IP, +) error { if iface == nil { return nil } @@ -27,10 +32,22 @@ func bindToInterface(dialer *net.Dialer, iface *net.Interface, targetIP net.IP) if ipnet, ok := addr.(*net.IPNet); ok { // Match IP version: use IPv4 source for IPv4 target, IPv6 for IPv6 if targetIP.To4() != nil && ipnet.IP.To4() != nil && !ipnet.IP.IsLoopback() { - dialer.LocalAddr = &net.TCPAddr{IP: ipnet.IP} + if strings.HasPrefix(network, "tcp") { + dialer.LocalAddr = &net.TCPAddr{IP: ipnet.IP} + } else if strings.HasPrefix(network, "udp") { + dialer.LocalAddr = &net.UDPAddr{IP: ipnet.IP} + } else { + dialer.LocalAddr = &net.IPAddr{IP: ipnet.IP} + } return nil } else if targetIP.To4() == nil && ipnet.IP.To4() == nil && !ipnet.IP.IsLoopback() { - dialer.LocalAddr = &net.TCPAddr{IP: ipnet.IP} + if strings.HasPrefix(network, "tcp") { + dialer.LocalAddr = &net.TCPAddr{IP: ipnet.IP} + } else if strings.HasPrefix(network, "udp") { + dialer.LocalAddr = &net.UDPAddr{IP: ipnet.IP} + } else { + dialer.LocalAddr = &net.IPAddr{IP: ipnet.IP} + } return nil } } diff --git a/internal/netutil/route_unsupported.go b/internal/netutil/route_unsupported.go new file mode 100644 index 00000000..7e13f7f7 --- /dev/null +++ b/internal/netutil/route_unsupported.go @@ -0,0 +1,23 @@ +//go:build !linux && !darwin && !freebsd + +package netutil + +import ( + "fmt" + "net" +) + +// bindToInterface is a no-op on unsupported platforms. +func bindToInterface( + network string, + dialer *net.Dialer, + iface *net.Interface, + targetIP net.IP, +) error { + return nil +} + +// getDefaultGateway is not supported on this platform. +func getDefaultGateway() (string, error) { + return "", fmt.Errorf("getDefaultGateway not supported on this platform") +} diff --git a/internal/netutil/session_cache.go b/internal/netutil/session_cache.go new file mode 100644 index 00000000..acd4f1f7 --- /dev/null +++ b/internal/netutil/session_cache.go @@ -0,0 +1,137 @@ +package netutil + +import ( + "net" + "sync" + "time" + + "github.com/xvzc/SpoofDPI/internal/cache" +) + +// SessionCache manages UDP connections with LRU eviction policy and idle timeout. +type SessionCache[K comparable] struct { + storage cache.Cache[K] + timeout time.Duration + stopCh chan struct{} + stopOnce sync.Once +} + +// NewSessionCache creates a new pool with the specified capacity and timeout. +// Starts a background goroutine for expired connection cleanup. +func NewSessionCache[K comparable]( + capacity int, + timeout time.Duration, +) *SessionCache[K] { + p := &SessionCache[K]{ + timeout: timeout, + stopCh: make(chan struct{}), + } + + onInvalidate := func(k K, v any) { + if conn, ok := v.(*IdleTimeoutConn); ok { + _ = conn.Conn.Close() + } + } + + p.storage = cache.NewLRUCache(capacity, onInvalidate) + + // Cleanup interval: half of timeout, min 10s, max 60s + cleanupInterval := timeout / 2 + cleanupInterval = max(cleanupInterval, 10*time.Second) + cleanupInterval = min(cleanupInterval, 60*time.Second) + + go p.cleanupLoop(cleanupInterval) + return p +} + +// Store adds a connection to the cache and returns the wrapped connection. +// If the key already exists, the old connection is closed and evicted first. +// If capacity is full, evicts the least recently used connection. +func (p *SessionCache[K]) Store(key K, rawConn net.Conn) *IdleTimeoutConn { + wrapper := NewIdleTimeoutConn(rawConn, p.timeout) + wrapper.Key = key + + wrapper.onActivity = func() { + p.storage.Fetch(key) + } + + wrapper.onClose = func() { + p.Evict(key) + } + + p.storage.Store(key, wrapper, nil) + + return wrapper +} + +// Fetch retrieves a connection from the pool, refreshing its LRU status. +func (p *SessionCache[K]) Fetch(key K) (*IdleTimeoutConn, bool) { + if val, ok := p.storage.Fetch(key); ok { + return val.(*IdleTimeoutConn), true + } + return nil, false +} + +// Evict closes and removes the connection from the pool. +func (p *SessionCache[K]) Evict(key K) { + p.storage.Evict(key) +} + +// Has checks if the connection exists in the cache. +func (p *SessionCache[K]) Has(key K) bool { + return p.storage.Has(key) +} + +// Size returns the number of connections in the pool. +func (p *SessionCache[K]) Size() int { + return p.storage.Size() +} + +// Stop stops the background cleanup goroutine. +func (p *SessionCache[K]) Stop() { + p.stopOnce.Do(func() { + close(p.stopCh) + }) +} + +// CloseAll closes all connections in the pool. +func (p *SessionCache[K]) CloseAll() { + var toRemove []K + _ = p.storage.ForEach(func(key K, value any) error { + toRemove = append(toRemove, key) + return nil + }) + for _, k := range toRemove { + p.Evict(k) // safely removes without deadlocking + } +} + +func (p *SessionCache[K]) cleanupLoop(interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-p.stopCh: + return + case <-ticker.C: + p.evictExpired() + } + } +} + +func (p *SessionCache[K]) evictExpired() { + now := time.Now() + var toRemove []K + _ = p.storage.ForEach(func(key K, value any) error { + if conn, ok := value.(*IdleTimeoutConn); ok { + if conn.IsExpired(now) { + toRemove = append(toRemove, key) + } + } + return nil + }) + for _, k := range toRemove { + p.Evict(k) + } +} diff --git a/internal/packet/sniffer.go b/internal/packet/sniffer.go index 5d36d236..07fe4c08 100644 --- a/internal/packet/sniffer.go +++ b/internal/packet/sniffer.go @@ -4,13 +4,14 @@ import ( "net" "github.com/xvzc/SpoofDPI/internal/cache" + "github.com/xvzc/SpoofDPI/internal/netutil" ) type Sniffer interface { StartCapturing() RegisterUntracked(addrs []net.IP) - GetOptimalTTL(key string) uint8 - Cache() cache.Cache + GetOptimalTTL(key netutil.IPKey) uint8 + Cache() cache.Cache[netutil.IPKey] } // estimateHops estimates the number of hops based on TTL. diff --git a/internal/packet/tcp_sniffer.go b/internal/packet/tcp_sniffer.go index 52f5b474..1e4bddb3 100644 --- a/internal/packet/tcp_sniffer.go +++ b/internal/packet/tcp_sniffer.go @@ -9,6 +9,7 @@ import ( "github.com/rs/zerolog" "github.com/xvzc/SpoofDPI/internal/cache" "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/netutil" ) var _ Sniffer = (*TCPSniffer)(nil) @@ -16,7 +17,7 @@ var _ Sniffer = (*TCPSniffer)(nil) type TCPSniffer struct { logger zerolog.Logger - nhopCache cache.Cache + nhopCache cache.Cache[netutil.IPKey] defaultTTL uint8 handle Handle @@ -24,7 +25,7 @@ type TCPSniffer struct { func NewTCPSniffer( logger zerolog.Logger, - cache cache.Cache, + cache cache.Cache[netutil.IPKey], handle Handle, defaultTTL uint8, ) *TCPSniffer { @@ -38,7 +39,7 @@ func NewTCPSniffer( // --- HopTracker Methods --- -func (ts *TCPSniffer) Cache() cache.Cache { +func (ts *TCPSniffer) Cache() cache.Cache[netutil.IPKey] { return ts.nhopCache } @@ -64,15 +65,19 @@ func (ts *TCPSniffer) StartCapturing() { // Addresses that are already being tracked are ignored. func (ts *TCPSniffer) RegisterUntracked(addrs []net.IP) { for _, v := range addrs { - ts.nhopCache.Set(v.String(), ts.defaultTTL, cache.Options().WithSkipExisting(true)) + ts.nhopCache.Store( + netutil.NewIPKey(v), + ts.defaultTTL, + cache.Options().WithSkipExisting(true), + ) } } // GetOptimalTTL retrieves the estimated hop count for a given key from the cache. // It returns the hop count and true if found, or 0 and false if not found. -func (ts *TCPSniffer) GetOptimalTTL(key string) uint8 { +func (ts *TCPSniffer) GetOptimalTTL(key netutil.IPKey) uint8 { hopCount := uint8(255) - if oTTL, ok := ts.nhopCache.Get(key); ok { + if oTTL, ok := ts.nhopCache.Fetch(key); ok { hopCount = oTTL.(uint8) } @@ -96,7 +101,7 @@ func (ts *TCPSniffer) processPacket(ctx context.Context, p gopacket.Packet) { } // Check for a TCP layer - var srcIP string + var srcIP []byte var ttlLeft uint8 // Handle IPv4 @@ -107,12 +112,12 @@ func (ts *TCPSniffer) processPacket(ctx context.Context, p gopacket.Packet) { return } - srcIP = ip.SrcIP.String() + srcIP = ip.SrcIP ttlLeft = ip.TTL } else if ipLayer := p.Layer(layers.LayerTypeIPv6); ipLayer != nil { // Handle IPv6 ip, _ := ipLayer.(*layers.IPv6) - srcIP = ip.SrcIP.String() + srcIP = ip.SrcIP ttlLeft = ip.HopLimit } else { return // No IP layer found @@ -120,16 +125,16 @@ func (ts *TCPSniffer) processPacket(ctx context.Context, p gopacket.Packet) { // Create the cache key: ServerIP:ServerPort // (The source of the SYN/ACK is the server) - key := srcIP + key := netutil.NewIPKey(srcIP) // Calculate hop count from the TTL nhops := estimateHops(ttlLeft) - stored, exists := ts.nhopCache.Get(key) + stored, exists := ts.nhopCache.Fetch(key) - if ts.nhopCache.Set(key, nhops, cache.Options().WithUpdateExistingOnly(true)) { + if ts.nhopCache.Store(key, nhops, cache.Options().WithUpdateExistingOnly(true)) { if !exists || stored != nhops { logger.Trace(). - Str("from", key). + Str("from", key.String()). Uint8("nhops", nhops). Uint8("ttlLeft", ttlLeft). Msgf("ttl(tcp) update") diff --git a/internal/packet/udp_sniffer.go b/internal/packet/udp_sniffer.go index 42975786..acd46d28 100644 --- a/internal/packet/udp_sniffer.go +++ b/internal/packet/udp_sniffer.go @@ -9,6 +9,7 @@ import ( "github.com/rs/zerolog" "github.com/xvzc/SpoofDPI/internal/cache" "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/netutil" ) var _ Sniffer = (*UDPSniffer)(nil) @@ -16,7 +17,7 @@ var _ Sniffer = (*UDPSniffer)(nil) type UDPSniffer struct { logger zerolog.Logger - nhopCache cache.Cache + nhopCache cache.Cache[netutil.IPKey] defaultTTL uint8 handle Handle @@ -24,7 +25,7 @@ type UDPSniffer struct { func NewUDPSniffer( logger zerolog.Logger, - cache cache.Cache, + cache cache.Cache[netutil.IPKey], handle Handle, defaultTTL uint8, ) *UDPSniffer { @@ -38,7 +39,7 @@ func NewUDPSniffer( // --- HopTracker Methods --- -func (us *UDPSniffer) Cache() cache.Cache { +func (us *UDPSniffer) Cache() cache.Cache[netutil.IPKey] { return us.nhopCache } @@ -63,15 +64,19 @@ func (us *UDPSniffer) StartCapturing() { // Addresses that are already being tracked are ignored. func (us *UDPSniffer) RegisterUntracked(addrs []net.IP) { for _, v := range addrs { - us.nhopCache.Set(v.String(), us.defaultTTL, cache.Options().WithSkipExisting(true)) + us.nhopCache.Store( + netutil.NewIPKey(v), + us.defaultTTL, + cache.Options().WithSkipExisting(true), + ) } } // GetOptimalTTL retrieves the estimated hop count for a given key from the cache. // It returns the hop count and true if found, or 0 and false if not found. -func (us *UDPSniffer) GetOptimalTTL(key string) uint8 { +func (us *UDPSniffer) GetOptimalTTL(key netutil.IPKey) uint8 { hopCount := uint8(255) - if oTTL, ok := us.nhopCache.Get(key); ok { + if oTTL, ok := us.nhopCache.Fetch(key); ok { hopCount = oTTL.(uint8) } @@ -87,7 +92,7 @@ func (us *UDPSniffer) processPacket(ctx context.Context, p gopacket.Packet) { return } - var srcIP string + var srcIP []byte var ttlLeft uint8 // Handle IPv4 @@ -103,27 +108,27 @@ func (us *UDPSniffer) processPacket(ctx context.Context, p gopacket.Packet) { return } - srcIP = ip.SrcIP.String() + srcIP = ip.SrcIP ttlLeft = ip.TTL } else if ipLayer := p.Layer(layers.LayerTypeIPv6); ipLayer != nil { // Handle IPv6 ip, _ := ipLayer.(*layers.IPv6) - srcIP = ip.SrcIP.String() + srcIP = ip.SrcIP ttlLeft = ip.HopLimit } else { return // No IP layer found } - key := srcIP + key := netutil.NewIPKey(srcIP) // Calculate hop count from the TTL nhops := estimateHops(ttlLeft) - stored, exists := us.nhopCache.Get(key) + stored, exists := us.nhopCache.Fetch(key) - if us.nhopCache.Set(key, nhops, nil) { + if us.nhopCache.Store(key, nhops, nil) { if !exists || stored != nhops { logger.Trace(). - Str("from", key). + Str("from", key.String()). Uint8("nhops", nhops). Uint8("ttlLeft", ttlLeft). Msgf("ttl(udp) update") diff --git a/internal/packet/udp_writer.go b/internal/packet/udp_writer.go index ff92bf97..19de4fb1 100644 --- a/internal/packet/udp_writer.go +++ b/internal/packet/udp_writer.go @@ -79,7 +79,7 @@ func (uw *UDPWriter) WriteCraftedPacket( srcMAC, dstMAC, srcUDP.IP, - srcUDP.IP, + dstUDP.IP, srcPort, dstPort, ttl, diff --git a/internal/server/socks5/udp_associate.go b/internal/server/socks5/udp_associate.go index cfff6d2d..90fa1c3e 100644 --- a/internal/server/socks5/udp_associate.go +++ b/internal/server/socks5/udp_associate.go @@ -17,14 +17,14 @@ import ( type UdpAssociateHandler struct { logger zerolog.Logger - pool *netutil.ConnPool + pool *netutil.SessionCache[netutil.NATKey] desyncer *desync.UDPDesyncer defaultUDPOpts *config.UDPOptions } func NewUdpAssociateHandler( logger zerolog.Logger, - pool *netutil.ConnPool, + pool *netutil.SessionCache[netutil.NATKey], desyncer *desync.UDPDesyncer, defaultUDPOpts *config.UDPOptions, ) *UdpAssociateHandler { @@ -120,9 +120,6 @@ func (h *UdpAssociateHandler) Handle( continue } - // Key: Client Addr -> Target Addr - key := clientAddr.String() + ">" + targetAddrStr - // Resolve address to construct Destination uAddr, err := net.ResolveUDPAddr("udp", targetAddrStr) if err != nil { @@ -133,11 +130,23 @@ func (h *UdpAssociateHandler) Handle( continue } + // Key: Client Addr -> Target Addr (Zero Allocation Struct) + key := netutil.NewNATKey(clientAddr, uAddr) + dst := &netutil.Destination{ Addrs: []net.IP{uAddr.IP}, Port: uAddr.Port, } + // Check if connection already exists in the pool + if conn, ok := h.pool.Fetch(key); ok { + // Write payload to target + if _, err := conn.Write(payload); err != nil { + logger.Warn().Err(err).Msg("failed to write udp to target") + } + continue + } + rawConn, err := netutil.DialFastest(ctx, "udp", dst) if err != nil { logger.Warn().Err(err).Str("addr", targetAddrStr).Msg("failed to dial udp target") @@ -145,7 +154,7 @@ func (h *UdpAssociateHandler) Handle( } // Add to pool (pool handles LRU eviction and deadline) - conn := h.pool.Add(key, rawConn) + conn := h.pool.Store(key, rawConn) // Apply UDP options from rule if matched udpOpts := h.defaultUDPOpts.Clone() @@ -159,7 +168,7 @@ func (h *UdpAssociateHandler) Handle( } // Start a goroutine to read from the target and forward to the client - go func(targetConn *netutil.PooledConn, clientAddr *net.UDPAddr) { + go func(targetConn *netutil.IdleTimeoutConn, clientAddr *net.UDPAddr) { respBuf := make([]byte, 65535) for { n, _, err := targetConn.Conn.(*net.UDPConn).ReadFromUDP(respBuf) diff --git a/internal/server/tun/udp.go b/internal/server/tun/udp.go index 69038dc9..99a09a8c 100644 --- a/internal/server/tun/udp.go +++ b/internal/server/tun/udp.go @@ -18,7 +18,6 @@ type UDPHandler struct { defaultUDPOpts *config.UDPOptions defaultConnOpts *config.ConnOptions desyncer *desync.UDPDesyncer - pool *netutil.ConnPool iface string gateway string } @@ -28,14 +27,12 @@ func NewUDPHandler( desyncer *desync.UDPDesyncer, defaultUDPOpts *config.UDPOptions, defaultConnOpts *config.ConnOptions, - pool *netutil.ConnPool, ) *UDPHandler { return &UDPHandler{ logger: logger, desyncer: desyncer, defaultUDPOpts: defaultUDPOpts, defaultConnOpts: defaultConnOpts, - pool: pool, } } @@ -79,27 +76,26 @@ func (h *UDPHandler) Handle(ctx context.Context, lConn net.Conn, rule *config.Ru connOpts = connOpts.Merge(rule.Conn) } - // Key for connection pool - key := lConn.RemoteAddr().String() + ">" + lConn.LocalAddr().String() - // Dial remote connection rawConn, err := netutil.DialFastest(ctx, "udp", dst) if err != nil { + logger.Error().Msgf("error dialing to %s", dst.String()) return } - // Add to pool (pool handles LRU eviction and deadline) - rConn := h.pool.Add(key, rawConn) + timeout := *connOpts.UDPIdleTimeout + + // Wrap rConn with IdleTimeoutConn + rConnWrapped := netutil.NewIdleTimeoutConn(rawConn, timeout) // Wrap lConn with IdleTimeoutConn as well - timeout := *connOpts.UDPIdleTimeout - lConnWrapped := &netutil.IdleTimeoutConn{Conn: lConn, Timeout: timeout} + lConnWrapped := netutil.NewIdleTimeoutConn(lConn, timeout) // Desync - _, _ = h.desyncer.Desync(ctx, lConnWrapped, rConn, udpOpts) + _, _ = h.desyncer.Desync(ctx, lConnWrapped, rConnWrapped, udpOpts) logger.Debug(). - Msgf("new remote conn (%s -> %s)", lConn.RemoteAddr(), rConn.RemoteAddr()) + Msgf("new remote conn (%s -> %s)", lConn.RemoteAddr(), rConnWrapped.RemoteAddr()) resCh := make(chan netutil.TransferResult, 2) @@ -107,15 +103,15 @@ func (h *UDPHandler) Handle(ctx context.Context, lConn net.Conn, rule *config.Ru defer cancel() startedAt := time.Now() - go netutil.TunnelConns(ctx, resCh, lConnWrapped, rConn, netutil.TunnelDirOut) - go netutil.TunnelConns(ctx, resCh, rConn, lConnWrapped, netutil.TunnelDirIn) + go netutil.TunnelConns(ctx, resCh, lConnWrapped, rConnWrapped, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, rConnWrapped, lConnWrapped, netutil.TunnelDirIn) err = netutil.WaitAndLogTunnel( ctx, logger, resCh, startedAt, - netutil.DescribeRoute(lConnWrapped, rConn), + netutil.DescribeRoute(lConnWrapped, rConnWrapped), nil, ) if err != nil { From 5fc7fae3ca15c93f97a0ec3bcad5df6e4f3cbf25 Mon Sep 17 00:00:00 2001 From: xvzc Date: Thu, 19 Mar 2026 06:58:01 +0900 Subject: [PATCH 20/39] docs: update docs --- docs/user-guide/app.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/user-guide/app.md b/docs/user-guide/app.md index 835ff353..5c6a2932 100644 --- a/docs/user-guide/app.md +++ b/docs/user-guide/app.md @@ -104,7 +104,7 @@ silent = true --- -## `network-config` +## `auto-configure-network` `type: boolean` @@ -119,13 +119,13 @@ Specifies whether to automatically set up the system-wide proxy configuration. ` **Command-Line Flag** ```console -$ spoofdpi --network-config +$ spoofdpi --auto-configure-network ``` **TOML Config** ```toml [app] -network-config = true +auto-configure-network = true ``` --- From 934c5193ef675903901fb024029a072c55bdc50f Mon Sep 17 00:00:00 2001 From: xvzc Date: Thu, 19 Mar 2026 06:58:15 +0900 Subject: [PATCH 21/39] refactor: rename network-config flag to auto-configure-network --- cmd/spoofdpi/main.go | 2 +- internal/config/cli.go | 6 +++--- internal/config/cli_test.go | 10 +++++----- internal/config/config.go | 10 +++++----- internal/config/toml_test.go | 4 ++-- internal/config/types.go | 32 ++++++++++++++++---------------- internal/config/types_test.go | 10 +++++----- 7 files changed, 37 insertions(+), 37 deletions(-) diff --git a/cmd/spoofdpi/main.go b/cmd/spoofdpi/main.go index 365511d1..8f5660d7 100644 --- a/cmd/spoofdpi/main.go +++ b/cmd/spoofdpi/main.go @@ -74,7 +74,7 @@ func runApp(ctx context.Context, configDir string, cfg *config.Config) { <-ready // System Proxy Config - if *cfg.App.SetNetworkConfig { + if *cfg.App.AutoConfigureNetwork { if err := srv.SetNetworkConfig(); err != nil { logger.Fatal().Err(err).Msg("failed to set system network config") } diff --git a/internal/config/cli.go b/internal/config/cli.go index 504bfc4a..182830f5 100644 --- a/internal/config/cli.go +++ b/internal/config/cli.go @@ -356,14 +356,14 @@ func CreateCommand( }, &cli.BoolFlag{ - Name: "network-config", + Name: "auto-configure-network", Usage: fmt.Sprintf(` Automatically set system-wide proxy configuration (default: %v)`, - *defaultCfg.App.SetNetworkConfig, + *defaultCfg.App.AutoConfigureNetwork, ), OnlyOnce: true, Action: func(ctx context.Context, cmd *cli.Command, v bool) error { - argsCfg.App.SetNetworkConfig = lo.ToPtr(v) + argsCfg.App.AutoConfigureNetwork = lo.ToPtr(v) return nil }, }, diff --git a/internal/config/cli_test.go b/internal/config/cli_test.go index 88aeae6a..d1e8ff4b 100644 --- a/internal/config/cli_test.go +++ b/internal/config/cli_test.go @@ -26,7 +26,7 @@ func TestCreateCommand_Flags(t *testing.T) { // Verify defaults are preserved assert.Equal(t, zerolog.InfoLevel, *cfg.App.LogLevel) assert.False(t, *cfg.App.Silent) - assert.False(t, *cfg.App.SetNetworkConfig) + assert.False(t, *cfg.App.AutoConfigureNetwork) assert.Equal(t, "127.0.0.1:8080", cfg.App.ListenAddr.String()) assert.Equal(t, uint8(8), *cfg.Conn.DefaultFakeTTL) assert.Equal(t, int64(5000), cfg.Conn.DNSTimeout.Milliseconds()) @@ -53,7 +53,7 @@ func TestCreateCommand_Flags(t *testing.T) { "--clean", // Ensure no config file interferes "--log-level", "debug", "--silent", - "--network-config", + "--auto-configure-network", "--listen-addr", "127.0.0.1:9090", "--default-fake-ttl", "128", "--dns-timeout", "5000", @@ -77,7 +77,7 @@ func TestCreateCommand_Flags(t *testing.T) { // General assert.Equal(t, zerolog.DebugLevel, *cfg.App.LogLevel) assert.True(t, *cfg.App.Silent) - assert.True(t, *cfg.App.SetNetworkConfig) + assert.True(t, *cfg.App.AutoConfigureNetwork) // Server assert.Equal(t, "127.0.0.1:9090", cfg.App.ListenAddr.String()) @@ -246,7 +246,7 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { "--config", configPath, "--log-level", "error", "--silent=false", - "--network-config=false", + "--auto-configure-network=false", "--listen-addr", "127.0.0.1:9090", "--dns-timeout", "2000", "--tcp-timeout", "2000", @@ -275,7 +275,7 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { // General assert.Equal(t, zerolog.ErrorLevel, *capturedCfg.App.LogLevel) assert.False(t, *capturedCfg.App.Silent) - assert.False(t, *capturedCfg.App.SetNetworkConfig) + assert.False(t, *capturedCfg.App.AutoConfigureNetwork) // Server assert.Equal(t, "127.0.0.1:9090", capturedCfg.App.ListenAddr.String()) diff --git a/internal/config/config.go b/internal/config/config.go index 1c70d267..081108e7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -133,11 +133,11 @@ func (c *Config) ShouldEnablePcap() bool { func getDefault() *Config { //exhaustruct:enforce return &Config{ App: &AppOptions{ - LogLevel: lo.ToPtr(zerolog.InfoLevel), - Silent: lo.ToPtr(false), - SetNetworkConfig: lo.ToPtr(false), - Mode: lo.ToPtr(AppModeHTTP), - ListenAddr: nil, + LogLevel: lo.ToPtr(zerolog.InfoLevel), + Silent: lo.ToPtr(false), + AutoConfigureNetwork: lo.ToPtr(false), + Mode: lo.ToPtr(AppModeHTTP), + ListenAddr: nil, }, Conn: &ConnOptions{ DefaultFakeTTL: lo.ToPtr(uint8(8)), diff --git a/internal/config/toml_test.go b/internal/config/toml_test.go index de3adf4b..fb05cfbb 100644 --- a/internal/config/toml_test.go +++ b/internal/config/toml_test.go @@ -417,7 +417,7 @@ func TestFromTomlFile(t *testing.T) { [app] log-level = "debug" silent = true - network-config = true + auto-configure-network = true mode = "socks5" listen-addr = "127.0.0.1:8080" [connection] @@ -489,7 +489,7 @@ func TestFromTomlFile(t *testing.T) { assert.Equal(t, time.Duration(1000*time.Millisecond), *cfg.Conn.UDPIdleTimeout) assert.Equal(t, zerolog.DebugLevel, *cfg.App.LogLevel) assert.True(t, *cfg.App.Silent) - assert.True(t, *cfg.App.SetNetworkConfig) + assert.True(t, *cfg.App.AutoConfigureNetwork) assert.Equal(t, AppModeSOCKS5, *cfg.App.Mode) assert.Equal(t, "8.8.8.8:53", cfg.DNS.Addr.String()) assert.True(t, *cfg.DNS.Cache) diff --git a/internal/config/types.go b/internal/config/types.go index e9b6214c..3d3759b3 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -44,11 +44,11 @@ var availableLogLevelValues = []string{ } type AppOptions struct { - LogLevel *zerolog.Level `toml:"log-level"` - Silent *bool `toml:"silent"` - SetNetworkConfig *bool `toml:"network-config"` - Mode *AppModeType `toml:"mode"` - ListenAddr *net.TCPAddr `toml:"listen-addr"` + LogLevel *zerolog.Level `toml:"log-level"` + Silent *bool `toml:"silent"` + AutoConfigureNetwork *bool `toml:"auto-configure-network"` + Mode *AppModeType `toml:"mode"` + ListenAddr *net.TCPAddr `toml:"listen-addr"` } func (o *AppOptions) UnmarshalTOML(data any) (err error) { @@ -58,7 +58,7 @@ func (o *AppOptions) UnmarshalTOML(data any) (err error) { } o.Silent = findFrom(m, "silent", parseBoolFn(), &err) - o.SetNetworkConfig = findFrom(m, "network-config", parseBoolFn(), &err) + o.AutoConfigureNetwork = findFrom(m, "auto-configure-network", parseBoolFn(), &err) if p := findFrom(m, "log-level", parseStringFn(checkLogLevel), &err); isOk(p, err) { o.LogLevel = lo.ToPtr(MustParseLogLevel(*p)) } @@ -92,11 +92,11 @@ func (o *AppOptions) Clone() *AppOptions { } return &AppOptions{ - LogLevel: newLevel, - Silent: clonePrimitive(o.Silent), - SetNetworkConfig: clonePrimitive(o.SetNetworkConfig), - Mode: clonePrimitive(o.Mode), - ListenAddr: newAddr, + LogLevel: newLevel, + Silent: clonePrimitive(o.Silent), + AutoConfigureNetwork: clonePrimitive(o.AutoConfigureNetwork), + Mode: clonePrimitive(o.Mode), + ListenAddr: newAddr, } } @@ -112,9 +112,9 @@ func (origin *AppOptions) Merge(overrides *AppOptions) *AppOptions { return &AppOptions{ LogLevel: lo.CoalesceOrEmpty(overrides.LogLevel, origin.LogLevel), Silent: lo.CoalesceOrEmpty(overrides.Silent, origin.Silent), - SetNetworkConfig: lo.CoalesceOrEmpty( - overrides.SetNetworkConfig, - origin.SetNetworkConfig, + AutoConfigureNetwork: lo.CoalesceOrEmpty( + overrides.AutoConfigureNetwork, + origin.AutoConfigureNetwork, ), Mode: lo.CoalesceOrEmpty(overrides.Mode, origin.Mode), ListenAddr: lo.CoalesceOrEmpty(overrides.ListenAddr, origin.ListenAddr), @@ -387,8 +387,8 @@ var availableHTTPSModeValues = []string{ "random", "chunk", "first-byte", - "none", "custom", + "none", } const ( @@ -396,8 +396,8 @@ const ( HTTPSSplitModeRandom HTTPSSplitModeChunk HTTPSSplitModeFirstByte - HTTPSSplitModeNone HTTPSSplitModeCustom + HTTPSSplitModeNone ) func (k HTTPSSplitModeType) String() string { diff --git a/internal/config/types_test.go b/internal/config/types_test.go index 3c1f818f..0a021a0f 100644 --- a/internal/config/types_test.go +++ b/internal/config/types_test.go @@ -25,16 +25,16 @@ func TestAppOptions_UnmarshalTOML(t *testing.T) { { name: "valid general options", input: map[string]any{ - "log-level": "debug", - "silent": true, - "network-config": true, - "mode": "socks5", + "log-level": "debug", + "silent": true, + "auto-configure-network": true, + "mode": "socks5", }, wantErr: false, assert: func(t *testing.T, o AppOptions) { assert.Equal(t, zerolog.DebugLevel, *o.LogLevel) assert.True(t, *o.Silent) - assert.True(t, *o.SetNetworkConfig) + assert.True(t, *o.AutoConfigureNetwork) assert.Equal(t, AppModeSOCKS5, *o.Mode) }, }, From ff8efc04bba21d7be83195c873e7a22275b96d04 Mon Sep 17 00:00:00 2001 From: xvzc Date: Thu, 19 Mar 2026 07:32:23 +0900 Subject: [PATCH 22/39] fix(socks5): close associated UDP connection on TCP disconnection --- internal/server/socks5/udp_associate.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/server/socks5/udp_associate.go b/internal/server/socks5/udp_associate.go index 90fa1c3e..ab66d0f4 100644 --- a/internal/server/socks5/udp_associate.go +++ b/internal/server/socks5/udp_associate.go @@ -82,6 +82,7 @@ func (h *UdpAssociateHandler) Handle( go func() { _, _ = io.Copy(io.Discard, lConn) // Block until TCP closes close(done) + lNewConn.Close() // Force ReadFromUDP to unblock and avoid goroutine leak }() buf := make([]byte, 65535) From 6dd4393ec78f38d7eaedc1c8ae261155d829d7ec Mon Sep 17 00:00:00 2001 From: xvzc Date: Thu, 19 Mar 2026 07:53:21 +0900 Subject: [PATCH 23/39] fix: validate source address for UDP associate --- internal/dns/https.go | 3 +-- internal/server/socks5/udp_associate.go | 21 ++++++++++----------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/internal/dns/https.go b/internal/dns/https.go index ab97bf08..1442e1fb 100644 --- a/internal/dns/https.go +++ b/internal/dns/https.go @@ -48,8 +48,7 @@ func NewHTTPSResolver( // Configure HTTP/2 transport explicitly if err := http2.ConfigureTransport(tr); err != nil { - // Log error instead of panic if strict http2 is not required, otherwise panic - panic(fmt.Sprintf("failed to configure http2: %v", err)) + logger.Warn().Err(err).Msg("failed to configure http2 expressly, falling back to default / http/1.1") } return &HTTPSResolver{ diff --git a/internal/server/socks5/udp_associate.go b/internal/server/socks5/udp_associate.go index ab66d0f4..4bf11635 100644 --- a/internal/server/socks5/udp_associate.go +++ b/internal/server/socks5/udp_associate.go @@ -81,12 +81,12 @@ func (h *UdpAssociateHandler) Handle( done := make(chan struct{}) go func() { _, _ = io.Copy(io.Discard, lConn) // Block until TCP closes - close(done) - lNewConn.Close() // Force ReadFromUDP to unblock and avoid goroutine leak + close(done) // Close the channel to signal UDP handler to exit + lNewConn.Close() // Force ReadFromUDP to unblock and avoid goroutine leak }() buf := make([]byte, 65535) - var clientAddr *net.UDPAddr + tcpRemoteIP := lConn.RemoteAddr().(*net.TCPAddr).IP for { // Wait for data @@ -104,13 +104,12 @@ func (h *UdpAssociateHandler) Handle( } } - // Initial Client Identification - if clientAddr == nil { - clientAddr = addr - } - - // Only accept packets from the client that established the association - if !addr.IP.Equal(clientAddr.IP) || addr.Port != clientAddr.Port { + // Security: Only accept UDP packets from the same IP that established the TCP connection + if !addr.IP.Equal(tcpRemoteIP) { + logger.Debug(). + Str("expected", tcpRemoteIP.String()). + Str("actual", addr.IP.String()). + Msg("dropped udp packet from unexpected ip") continue } @@ -132,7 +131,7 @@ func (h *UdpAssociateHandler) Handle( } // Key: Client Addr -> Target Addr (Zero Allocation Struct) - key := netutil.NewNATKey(clientAddr, uAddr) + key := netutil.NewNATKey(addr, uAddr) dst := &netutil.Destination{ Addrs: []net.IP{uAddr.IP}, From 22785a5673bc1f5b38c59391160f99a48431b8fc Mon Sep 17 00:00:00 2001 From: xvzc Date: Thu, 19 Mar 2026 09:34:50 +0900 Subject: [PATCH 24/39] fix(packet): skip ethernet layer if MAC addresses are missing --- internal/packet/tcp_writer.go | 12 +++++++----- internal/packet/udp_writer.go | 14 ++++++++------ 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/internal/packet/tcp_writer.go b/internal/packet/tcp_writer.go index b94fa025..4ed28756 100644 --- a/internal/packet/tcp_writer.go +++ b/internal/packet/tcp_writer.go @@ -173,12 +173,14 @@ func (tw *TCPWriter) createIPv6Layers( ) ([]gopacket.SerializableLayer, error) { var packetLayers []gopacket.SerializableLayer - eth := &layers.Ethernet{ - SrcMAC: srcMAC, - DstMAC: dstMAC, - EthernetType: layers.EthernetTypeIPv6, + if srcMAC != nil { + eth := &layers.Ethernet{ + SrcMAC: srcMAC, + DstMAC: dstMAC, + EthernetType: layers.EthernetTypeIPv6, + } + packetLayers = append(packetLayers, eth) } - packetLayers = append(packetLayers, eth) ipLayer := &layers.IPv6{ Version: 6, diff --git a/internal/packet/udp_writer.go b/internal/packet/udp_writer.go index 19de4fb1..873562f3 100644 --- a/internal/packet/udp_writer.go +++ b/internal/packet/udp_writer.go @@ -123,7 +123,7 @@ func (uw *UDPWriter) createIPv4Layers( ) ([]gopacket.SerializableLayer, error) { var packetLayers []gopacket.SerializableLayer - if srcMAC != nil { + if srcMAC != nil && dstMAC != nil { eth := &layers.Ethernet{ SrcMAC: srcMAC, DstMAC: dstMAC, @@ -167,12 +167,14 @@ func (uw *UDPWriter) createIPv6Layers( ) ([]gopacket.SerializableLayer, error) { var packetLayers []gopacket.SerializableLayer - eth := &layers.Ethernet{ - SrcMAC: srcMAC, - DstMAC: dstMAC, - EthernetType: layers.EthernetTypeIPv6, + if srcMAC != nil && dstMAC != nil { + eth := &layers.Ethernet{ + SrcMAC: srcMAC, + DstMAC: dstMAC, + EthernetType: layers.EthernetTypeIPv6, + } + packetLayers = append(packetLayers, eth) } - packetLayers = append(packetLayers, eth) ipLayer := &layers.IPv6{ Version: 6, From e0c1bdb00324670f80dba3531b4ea92ff9f9e4eb Mon Sep 17 00:00:00 2001 From: xvzc Date: Thu, 19 Mar 2026 09:36:29 +0900 Subject: [PATCH 25/39] fix: ensure graceful cleanup using global context --- cmd/spoofdpi/main.go | 43 ++--- cmd/spoofdpi/main_test.go | 5 +- internal/netutil/session_cache.go | 55 +++---- internal/server/http/server.go | 14 +- internal/server/server.go | 9 +- internal/server/socks5/server.go | 19 +-- internal/server/socks5/udp_associate.go | 39 ++--- internal/server/tun/server.go | 32 ++-- temp.txt | 208 ++++++++++++++++++++++++ 9 files changed, 300 insertions(+), 124 deletions(-) create mode 100644 temp.txt diff --git a/cmd/spoofdpi/main.go b/cmd/spoofdpi/main.go index 8f5660d7..6365f4ae 100644 --- a/cmd/spoofdpi/main.go +++ b/cmd/spoofdpi/main.go @@ -34,21 +34,25 @@ var ( func main() { cmd := config.CreateCommand(runApp, version, commit, build) - ctx := session.WithNewTraceID(context.Background()) - if err := cmd.Run(ctx, os.Args); err != nil { + appctx, cancel := signal.NotifyContext( + session.WithNewTraceID(context.Background()), + syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGHUP, + ) + defer cancel() + if err := cmd.Run(appctx, os.Args); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } } -func runApp(ctx context.Context, configDir string, cfg *config.Config) { +func runApp(appctx context.Context, configDir string, cfg *config.Config) { if !*cfg.App.Silent { printBanner() } - logging.SetGlobalLogger(ctx, *cfg.App.LogLevel) + logging.SetGlobalLogger(appctx, *cfg.App.LogLevel) - logger := log.Logger.With().Ctx(ctx).Logger() + logger := log.Logger.With().Ctx(appctx).Logger() logger.Info().Str("version", version).Msg("started spoofdpi") if configDir != "" { logger.Info(). @@ -58,7 +62,7 @@ func runApp(ctx context.Context, configDir string, cfg *config.Config) { resolver := createResolver(logger, cfg) - srv, err := createServer(logger, cfg, resolver) + srv, err := createServer(appctx, logger, cfg, resolver) if err != nil { logger.Fatal().Err(err).Msg("failed to create server") } @@ -66,7 +70,7 @@ func runApp(ctx context.Context, configDir string, cfg *config.Config) { // Start server ready := make(chan struct{}) go func() { - if err := srv.Start(ctx, ready); err != nil { + if err := srv.ListenAndServe(appctx, ready); err != nil { logger.Fatal().Err(err).Msgf("failed to start server: %T", srv) } }() @@ -124,25 +128,7 @@ func runApp(ctx context.Context, configDir string, cfg *config.Config) { logger.Info().Msgf("server started on %s", srv.Addr()) - sigs := make(chan os.Signal, 1) - done := make(chan bool, 1) - - signal.Notify( - sigs, - syscall.SIGINT, - syscall.SIGTERM, - syscall.SIGQUIT, - syscall.SIGHUP) - - go func() { - <-sigs - done <- true - }() - - <-done - - // Graceful shutdown - _ = srv.Stop() + <-appctx.Done() } func createResolver(logger zerolog.Logger, cfg *config.Config) dns.Resolver { @@ -283,6 +269,7 @@ func createPacketObjects( } func createServer( + appctx context.Context, logger zerolog.Logger, cfg *config.Config, resolver dns.Resolver, @@ -355,9 +342,11 @@ func createServer( udpWriter, udpSniffer, ) + udpPool := netutil.NewSessionCache[netutil.NATKey](4096, 60*time.Second) + udpPool.RunCleanupLoop(appctx) udpAssociateHandler := socks5.NewUdpAssociateHandler( logging.WithScope(logger, "hnd"), - netutil.NewSessionCache[netutil.NATKey](4096, 60*time.Second), + udpPool, udpDesyncer, cfg.UDP.Clone(), ) diff --git a/cmd/spoofdpi/main_test.go b/cmd/spoofdpi/main_test.go index 135a11b2..5eb77e47 100644 --- a/cmd/spoofdpi/main_test.go +++ b/cmd/spoofdpi/main_test.go @@ -1,6 +1,7 @@ package main import ( + "context" "net" "testing" "time" @@ -77,7 +78,7 @@ func TestCreateProxy_NoPcap(t *testing.T) { logger := zerolog.Nop() resolver := createResolver(logger, cfg) - p, err := createServer(logger, cfg, resolver) + p, err := createServer(context.Background(), logger, cfg, resolver) require.NoError(t, err) assert.NotNil(t, p) } @@ -134,7 +135,7 @@ func TestCreateProxy_WithPolicy(t *testing.T) { logger := zerolog.Nop() resolver := createResolver(logger, cfg) - p, err := createServer(logger, cfg, resolver) + p, err := createServer(context.Background(), logger, cfg, resolver) require.NoError(t, err) assert.NotNil(t, p) } diff --git a/internal/netutil/session_cache.go b/internal/netutil/session_cache.go index acd4f1f7..709a0e7d 100644 --- a/internal/netutil/session_cache.go +++ b/internal/netutil/session_cache.go @@ -1,8 +1,8 @@ package netutil import ( + "context" "net" - "sync" "time" "github.com/xvzc/SpoofDPI/internal/cache" @@ -10,21 +10,17 @@ import ( // SessionCache manages UDP connections with LRU eviction policy and idle timeout. type SessionCache[K comparable] struct { - storage cache.Cache[K] - timeout time.Duration - stopCh chan struct{} - stopOnce sync.Once + storage cache.Cache[K] + timeout time.Duration } // NewSessionCache creates a new pool with the specified capacity and timeout. -// Starts a background goroutine for expired connection cleanup. func NewSessionCache[K comparable]( capacity int, timeout time.Duration, ) *SessionCache[K] { p := &SessionCache[K]{ timeout: timeout, - stopCh: make(chan struct{}), } onInvalidate := func(k K, v any) { @@ -35,13 +31,31 @@ func NewSessionCache[K comparable]( p.storage = cache.NewLRUCache(capacity, onInvalidate) + return p +} + +// RunCleanupLoop runs the background cleanup goroutine. +// It exits when appctx is cancelled, closing all remaining cached connections. +func (p *SessionCache[K]) RunCleanupLoop(appctx context.Context) { // Cleanup interval: half of timeout, min 10s, max 60s - cleanupInterval := timeout / 2 + cleanupInterval := p.timeout / 2 cleanupInterval = max(cleanupInterval, 10*time.Second) cleanupInterval = min(cleanupInterval, 60*time.Second) - go p.cleanupLoop(cleanupInterval) - return p + go func() { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-appctx.Done(): + p.CloseAll() + return + case <-ticker.C: + p.evictExpired() + } + } + }() } // Store adds a connection to the cache and returns the wrapped connection. @@ -87,13 +101,6 @@ func (p *SessionCache[K]) Size() int { return p.storage.Size() } -// Stop stops the background cleanup goroutine. -func (p *SessionCache[K]) Stop() { - p.stopOnce.Do(func() { - close(p.stopCh) - }) -} - // CloseAll closes all connections in the pool. func (p *SessionCache[K]) CloseAll() { var toRemove []K @@ -106,20 +113,6 @@ func (p *SessionCache[K]) CloseAll() { } } -func (p *SessionCache[K]) cleanupLoop(interval time.Duration) { - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - select { - case <-p.stopCh: - return - case <-ticker.C: - p.evictExpired() - } - } -} - func (p *SessionCache[K]) evictExpired() { now := time.Now() var toRemove []K diff --git a/internal/server/http/server.go b/internal/server/http/server.go index c4f351cb..6204f493 100644 --- a/internal/server/http/server.go +++ b/internal/server/http/server.go @@ -55,7 +55,7 @@ func NewHTTPProxy( } } -func (p *HTTPProxy) Start(ctx context.Context, ready chan<- struct{}) error { +func (p *HTTPProxy) ListenAndServe(appctx context.Context, ready chan<- struct{}) error { listener, err := net.ListenTCP("tcp", p.appOpts.ListenAddr) if err != nil { return fmt.Errorf( @@ -66,6 +66,11 @@ func (p *HTTPProxy) Start(ctx context.Context, ready chan<- struct{}) error { } p.listener = listener + go func() { + <-appctx.Done() + listener.Close() + }() + if ready != nil { close(ready) } @@ -87,13 +92,6 @@ func (p *HTTPProxy) Start(ctx context.Context, ready chan<- struct{}) error { } } -func (p *HTTPProxy) Stop() error { - if p.listener != nil { - return p.listener.Close() - } - return nil -} - func (p *HTTPProxy) SetNetworkConfig() error { return SetSystemProxy(p.logger, uint16(p.appOpts.ListenAddr.Port)) } diff --git a/internal/server/server.go b/internal/server/server.go index 94d84b3f..cc26d5f5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,16 +2,13 @@ package server import "context" -// Server represents a core component that processes network traffic +// Server represents a core component that processes network traffic. +// ListenAndServe blocks until ctx is cancelled, then releases all resources. type Server interface { - // Start begins the execution of the server module - Start(ctx context.Context, ready chan<- struct{}) error + ListenAndServe(ctx context.Context, ready chan<- struct{}) error SetNetworkConfig() error UnsetNetworkConfig() error - // Stop gracefully terminates the server and releases resources - Stop() error - // Addr returns the network address or interface name the server is bound to Addr() string } diff --git a/internal/server/socks5/server.go b/internal/server/socks5/server.go index a8c9e35a..8741d148 100644 --- a/internal/server/socks5/server.go +++ b/internal/server/socks5/server.go @@ -34,8 +34,6 @@ type SOCKS5Proxy struct { appOpts *config.AppOptions connOpts *config.ConnOptions policyOpts *config.PolicyOptions - - listener net.Listener } func NewSOCKS5Proxy( @@ -62,7 +60,7 @@ func NewSOCKS5Proxy( } } -func (p *SOCKS5Proxy) Start(ctx context.Context, ready chan<- struct{}) error { +func (p *SOCKS5Proxy) ListenAndServe(appctx context.Context, ready chan<- struct{}) error { listener, err := net.ListenTCP("tcp", p.appOpts.ListenAddr) if err != nil { return fmt.Errorf( @@ -71,7 +69,11 @@ func (p *SOCKS5Proxy) Start(ctx context.Context, ready chan<- struct{}) error { err, ) } - p.listener = listener + + go func() { + <-appctx.Done() + listener.Close() + }() if ready != nil { close(ready) @@ -89,15 +91,8 @@ func (p *SOCKS5Proxy) Start(ctx context.Context, ready chan<- struct{}) error { continue } - go p.handleConnection(session.WithNewTraceID(ctx), conn) - } -} - -func (p *SOCKS5Proxy) Stop() error { - if p.listener != nil { - return p.listener.Close() + go p.handleConnection(session.WithNewTraceID(appctx), conn) } - return nil } func (p *SOCKS5Proxy) SetNetworkConfig() error { diff --git a/internal/server/socks5/udp_associate.go b/internal/server/socks5/udp_associate.go index 4bf11635..c6995bce 100644 --- a/internal/server/socks5/udp_associate.go +++ b/internal/server/socks5/udp_associate.go @@ -90,7 +90,7 @@ func (h *UdpAssociateHandler) Handle( for { // Wait for data - n, addr, err := lNewConn.ReadFromUDP(buf) + n, lAddr, err := lNewConn.ReadFromUDP(buf) if err != nil { // Normal closure check select { @@ -105,37 +105,37 @@ func (h *UdpAssociateHandler) Handle( } // Security: Only accept UDP packets from the same IP that established the TCP connection - if !addr.IP.Equal(tcpRemoteIP) { + if !lAddr.IP.Equal(tcpRemoteIP) { logger.Debug(). Str("expected", tcpRemoteIP.String()). - Str("actual", addr.IP.String()). + Str("actual", lAddr.IP.String()). Msg("dropped udp packet from unexpected ip") continue } // Outbound: Client -> Proxy -> Target - targetAddrStr, payload, err := parseUDPHeader(buf[:n]) + dstUDPAddrStr, payload, err := parseUDPHeader(buf[:n]) if err != nil { logger.Warn().Err(err).Msg("failed to parse socks5 udp header") continue } // Resolve address to construct Destination - uAddr, err := net.ResolveUDPAddr("udp", targetAddrStr) + dstUDPAddr, err := net.ResolveUDPAddr("udp", dstUDPAddrStr) if err != nil { logger.Warn(). Err(err). - Str("addr", targetAddrStr). + Str("addr", dstUDPAddrStr). Msg("failed to resolve udp target") continue } // Key: Client Addr -> Target Addr (Zero Allocation Struct) - key := netutil.NewNATKey(addr, uAddr) + key := netutil.NewNATKey(lAddr, dstUDPAddr) dst := &netutil.Destination{ - Addrs: []net.IP{uAddr.IP}, - Port: uAddr.Port, + Addrs: []net.IP{dstUDPAddr.IP}, + Port: dstUDPAddr.Port, } // Check if connection already exists in the pool @@ -147,14 +147,15 @@ func (h *UdpAssociateHandler) Handle( continue } - rawConn, err := netutil.DialFastest(ctx, "udp", dst) + rConnRaw, err := netutil.DialFastest(ctx, "udp", dst) if err != nil { - logger.Warn().Err(err).Str("addr", targetAddrStr).Msg("failed to dial udp target") + logger.Warn().Err(err).Str("addr", dstUDPAddrStr).Msg("failed to dial udp target") continue } // Add to pool (pool handles LRU eviction and deadline) - conn := h.pool.Store(key, rawConn) + // returns IdleTimeoutConn with the actual net.Conn inside + rConn := h.pool.Store(key, rConnRaw) // Apply UDP options from rule if matched udpOpts := h.defaultUDPOpts.Clone() @@ -164,14 +165,14 @@ func (h *UdpAssociateHandler) Handle( // Send fake packets before real payload (UDP desync) if h.desyncer != nil { - _, _ = h.desyncer.Desync(ctx, lNewConn, conn.Conn, udpOpts) + _, _ = h.desyncer.Desync(ctx, lNewConn, rConn.Conn, udpOpts) } // Start a goroutine to read from the target and forward to the client - go func(targetConn *netutil.IdleTimeoutConn, clientAddr *net.UDPAddr) { + go func(rConn *netutil.IdleTimeoutConn, lAddr *net.UDPAddr) { respBuf := make([]byte, 65535) for { - n, _, err := targetConn.Conn.(*net.UDPConn).ReadFromUDP(respBuf) + n, _, err := rConn.Conn.(*net.UDPConn).ReadFromUDP(respBuf) if err != nil { // Connection closed or network issues return @@ -179,21 +180,21 @@ func (h *UdpAssociateHandler) Handle( // Inbound: Target -> Proxy -> Client // Wrap with SOCKS5 Header - remoteAddr := targetConn.Conn.(*net.UDPConn).RemoteAddr().(*net.UDPAddr) + remoteAddr := rConn.Conn.(*net.UDPConn).RemoteAddr().(*net.UDPAddr) header := createUDPHeaderFromAddr(remoteAddr) response := append(header, respBuf[:n]...) - if _, err := lNewConn.WriteToUDP(response, clientAddr); err != nil { + if _, err := lNewConn.WriteToUDP(response, lAddr); err != nil { // If we can't write back to the client, it might be gone or network issue. // Exit this goroutine to avoid busy looping. logger.Warn().Err(err).Msg("failed to write udp to client") return } } - }(conn, clientAddr) + }(rConn, lAddr) // Write payload to target - if _, err := conn.Write(payload); err != nil { + if _, err := rConn.Write(payload); err != nil { logger.Warn().Err(err).Msg("failed to write udp to target") } } diff --git a/internal/server/tun/server.go b/internal/server/tun/server.go index 1f0411c4..486311bc 100644 --- a/internal/server/tun/server.go +++ b/internal/server/tun/server.go @@ -62,25 +62,19 @@ func NewTunServer( } } -func (s *TunServer) Start(ctx context.Context, ready chan<- struct{}) error { +func (s *TunServer) ListenAndServe(appctx context.Context, ready chan<- struct{}) error { iface, err := NewTunDevice() if err != nil { return fmt.Errorf("failed to create tun device: %w", err) } s.iface = iface + defer iface.Close() if ready != nil { close(ready) } - return s.handle(ctx, iface) -} - -func (s *TunServer) Stop() error { - if s.iface != nil { - return s.iface.Close() - } - return nil + return s.handle(appctx, iface) } func (s *TunServer) SetNetworkConfig() error { @@ -187,8 +181,8 @@ func (s *TunServer) matchRuleByAddr(addr net.Addr) *config.Rule { return s.matcher.Search(selector) } -func (s *TunServer) handle(ctx context.Context, iface *water.Interface) error { - logger := logging.WithLocalScope(ctx, s.logger, "tun") +func (s *TunServer) handle(appctx context.Context, iface *water.Interface) error { + logger := logging.WithLocalScope(appctx, s.logger, "tun") // 1. Create gVisor stack stk := stack.New(stack.Options{ @@ -263,15 +257,15 @@ func (s *TunServer) handle(ctx context.Context, iface *water.Interface) error { stk.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket) // 6. Start packet pump - go s.tunToStack(ctx, logger, iface, ep) - go s.stackToTun(ctx, logger, iface, ep) + go s.tunToStack(appctx, logger, iface, ep) + go s.stackToTun(appctx, logger, iface, ep) - <-ctx.Done() + <-appctx.Done() return nil } func (s *TunServer) tunToStack( - ctx context.Context, + appctx context.Context, logger zerolog.Logger, iface *water.Interface, ep *channel.Endpoint, @@ -285,7 +279,7 @@ func (s *TunServer) tunToStack( } select { - case <-ctx.Done(): + case <-appctx.Done(): return default: if err != io.EOF { @@ -340,7 +334,7 @@ func (n *notifier) WriteNotify() { } func (s *TunServer) stackToTun( - ctx context.Context, + appctx context.Context, logger zerolog.Logger, iface *water.Interface, ep *channel.Endpoint, @@ -351,7 +345,7 @@ func (s *TunServer) stackToTun( for { select { - case <-ctx.Done(): + case <-appctx.Done(): return default: } @@ -361,7 +355,7 @@ func (s *TunServer) stackToTun( select { case <-ch: continue - case <-ctx.Done(): + case <-appctx.Done(): return } } diff --git a/temp.txt b/temp.txt new file mode 100644 index 00000000..b94fa025 --- /dev/null +++ b/temp.txt @@ -0,0 +1,208 @@ +package packet + +import ( + "context" + "errors" + "math/rand/v2" + "net" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/rs/zerolog" +) + +var _ Writer = (*TCPWriter)(nil) + +type TCPWriter struct { + logger zerolog.Logger + + handle Handle + iface *net.Interface + gatewayMAC net.HardwareAddr +} + +func NewTCPWriter( + logger zerolog.Logger, + handle Handle, + iface *net.Interface, + gatewayMAC net.HardwareAddr, +) *TCPWriter { + return &TCPWriter{ + logger: logger, + handle: handle, + iface: iface, + gatewayMAC: gatewayMAC, + } +} + +// --- Injector Methods --- + +// WriteCraftedPacket crafts and injects a full TCP packet from a payload. +// It uses the pre-configured gateway MAC address. +func (tw *TCPWriter) WriteCraftedPacket( + ctx context.Context, + src net.Addr, + dst net.Addr, + ttl uint8, + payload []byte, +) (int, error) { + // set variables for src/dst + srcMAC := tw.iface.HardwareAddr + dstMAC := tw.gatewayMAC + + srcTCP, ok := src.(*net.TCPAddr) + if !ok { + return 0, errors.New("src is not *net.TCPAddr") + } + + dstTCP, ok := dst.(*net.TCPAddr) + if !ok { + return 0, errors.New("dst is not *net.TCPAddr") + } + + srcPort := srcTCP.Port + dstPort := dstTCP.Port + + var err error + var packetLayers []gopacket.SerializableLayer + if dstTCP.IP.To4() != nil { + packetLayers, err = tw.createIPv4Layers( + srcMAC, + dstMAC, + srcTCP.IP, + dstTCP.IP, + srcPort, + dstPort, + ttl, + ) + } else { + packetLayers, err = tw.createIPv6Layers( + srcMAC, + dstMAC, + srcTCP.IP, + srcTCP.IP, + srcPort, + dstPort, + ttl, + ) + } + + if err != nil { + return 0, err + } + + // serialize the packet L2(optional) + L3 + L4 + payload + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, // Recalculate checksums + FixLengths: true, // Fix lengths + } + + packetLayers = append(packetLayers, gopacket.Payload(payload)) + + err = gopacket.SerializeLayers(buf, opts, packetLayers...) + if err != nil { + return 0, err + } + + // inject the raw L2 packet + if err := tw.handle.WritePacketData(buf.Bytes()); err != nil { + return 0, err + } + + return len(payload), nil +} + +func (tw *TCPWriter) createIPv4Layers( + srcMAC net.HardwareAddr, + dstMAC net.HardwareAddr, + srcIP net.IP, + dstIP net.IP, + srcPort int, + dstPort int, + ttl uint8, +) ([]gopacket.SerializableLayer, error) { + var packetLayers []gopacket.SerializableLayer + + if srcMAC != nil { + eth := &layers.Ethernet{ + SrcMAC: srcMAC, + DstMAC: dstMAC, + EthernetType: layers.EthernetTypeIPv4, + } + packetLayers = append(packetLayers, eth) + } + + // define ip layer + ipLayer := &layers.IPv4{ + Version: 4, + TTL: ttl, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + packetLayers = append(packetLayers, ipLayer) + + // define tcp layer + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), // Use a random high port + DstPort: layers.TCPPort(dstPort), + Seq: rand.Uint32(), // A random sequence number + PSH: true, // Push the payload + ACK: true, // Assuming this is part of an established flow + Ack: rand.Uint32(), + Window: 12345, + } + packetLayers = append(packetLayers, tcpLayer) + + if err := tcpLayer.SetNetworkLayerForChecksum(ipLayer); err != nil { + return nil, err + } + + return packetLayers, nil +} + +func (tw *TCPWriter) createIPv6Layers( + srcMAC net.HardwareAddr, + dstMAC net.HardwareAddr, + srcIP net.IP, + dstIP net.IP, + srcPort int, + dstPort int, + ttl uint8, +) ([]gopacket.SerializableLayer, error) { + var packetLayers []gopacket.SerializableLayer + + eth := &layers.Ethernet{ + SrcMAC: srcMAC, + DstMAC: dstMAC, + EthernetType: layers.EthernetTypeIPv6, + } + packetLayers = append(packetLayers, eth) + + ipLayer := &layers.IPv6{ + Version: 6, + HopLimit: ttl, + NextHeader: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + packetLayers = append(packetLayers, ipLayer) + + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + Seq: rand.Uint32(), + PSH: true, + ACK: true, + Ack: rand.Uint32(), + Window: 12345, + } + packetLayers = append(packetLayers, tcpLayer) + + if err := tcpLayer.SetNetworkLayerForChecksum(ipLayer); err != nil { + return nil, err + } + + return packetLayers, nil +} From 4533f68665f9074d1216ccd037117edbccec0161 Mon Sep 17 00:00:00 2001 From: xvzc Date: Thu, 19 Mar 2026 09:39:26 +0900 Subject: [PATCH 26/39] chore: remove unnecessary file --- temp.txt | 208 ------------------------------------------------------- 1 file changed, 208 deletions(-) delete mode 100644 temp.txt diff --git a/temp.txt b/temp.txt deleted file mode 100644 index b94fa025..00000000 --- a/temp.txt +++ /dev/null @@ -1,208 +0,0 @@ -package packet - -import ( - "context" - "errors" - "math/rand/v2" - "net" - - "github.com/google/gopacket" - "github.com/google/gopacket/layers" - "github.com/rs/zerolog" -) - -var _ Writer = (*TCPWriter)(nil) - -type TCPWriter struct { - logger zerolog.Logger - - handle Handle - iface *net.Interface - gatewayMAC net.HardwareAddr -} - -func NewTCPWriter( - logger zerolog.Logger, - handle Handle, - iface *net.Interface, - gatewayMAC net.HardwareAddr, -) *TCPWriter { - return &TCPWriter{ - logger: logger, - handle: handle, - iface: iface, - gatewayMAC: gatewayMAC, - } -} - -// --- Injector Methods --- - -// WriteCraftedPacket crafts and injects a full TCP packet from a payload. -// It uses the pre-configured gateway MAC address. -func (tw *TCPWriter) WriteCraftedPacket( - ctx context.Context, - src net.Addr, - dst net.Addr, - ttl uint8, - payload []byte, -) (int, error) { - // set variables for src/dst - srcMAC := tw.iface.HardwareAddr - dstMAC := tw.gatewayMAC - - srcTCP, ok := src.(*net.TCPAddr) - if !ok { - return 0, errors.New("src is not *net.TCPAddr") - } - - dstTCP, ok := dst.(*net.TCPAddr) - if !ok { - return 0, errors.New("dst is not *net.TCPAddr") - } - - srcPort := srcTCP.Port - dstPort := dstTCP.Port - - var err error - var packetLayers []gopacket.SerializableLayer - if dstTCP.IP.To4() != nil { - packetLayers, err = tw.createIPv4Layers( - srcMAC, - dstMAC, - srcTCP.IP, - dstTCP.IP, - srcPort, - dstPort, - ttl, - ) - } else { - packetLayers, err = tw.createIPv6Layers( - srcMAC, - dstMAC, - srcTCP.IP, - srcTCP.IP, - srcPort, - dstPort, - ttl, - ) - } - - if err != nil { - return 0, err - } - - // serialize the packet L2(optional) + L3 + L4 + payload - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ - ComputeChecksums: true, // Recalculate checksums - FixLengths: true, // Fix lengths - } - - packetLayers = append(packetLayers, gopacket.Payload(payload)) - - err = gopacket.SerializeLayers(buf, opts, packetLayers...) - if err != nil { - return 0, err - } - - // inject the raw L2 packet - if err := tw.handle.WritePacketData(buf.Bytes()); err != nil { - return 0, err - } - - return len(payload), nil -} - -func (tw *TCPWriter) createIPv4Layers( - srcMAC net.HardwareAddr, - dstMAC net.HardwareAddr, - srcIP net.IP, - dstIP net.IP, - srcPort int, - dstPort int, - ttl uint8, -) ([]gopacket.SerializableLayer, error) { - var packetLayers []gopacket.SerializableLayer - - if srcMAC != nil { - eth := &layers.Ethernet{ - SrcMAC: srcMAC, - DstMAC: dstMAC, - EthernetType: layers.EthernetTypeIPv4, - } - packetLayers = append(packetLayers, eth) - } - - // define ip layer - ipLayer := &layers.IPv4{ - Version: 4, - TTL: ttl, - Protocol: layers.IPProtocolTCP, - SrcIP: srcIP, - DstIP: dstIP, - } - packetLayers = append(packetLayers, ipLayer) - - // define tcp layer - tcpLayer := &layers.TCP{ - SrcPort: layers.TCPPort(srcPort), // Use a random high port - DstPort: layers.TCPPort(dstPort), - Seq: rand.Uint32(), // A random sequence number - PSH: true, // Push the payload - ACK: true, // Assuming this is part of an established flow - Ack: rand.Uint32(), - Window: 12345, - } - packetLayers = append(packetLayers, tcpLayer) - - if err := tcpLayer.SetNetworkLayerForChecksum(ipLayer); err != nil { - return nil, err - } - - return packetLayers, nil -} - -func (tw *TCPWriter) createIPv6Layers( - srcMAC net.HardwareAddr, - dstMAC net.HardwareAddr, - srcIP net.IP, - dstIP net.IP, - srcPort int, - dstPort int, - ttl uint8, -) ([]gopacket.SerializableLayer, error) { - var packetLayers []gopacket.SerializableLayer - - eth := &layers.Ethernet{ - SrcMAC: srcMAC, - DstMAC: dstMAC, - EthernetType: layers.EthernetTypeIPv6, - } - packetLayers = append(packetLayers, eth) - - ipLayer := &layers.IPv6{ - Version: 6, - HopLimit: ttl, - NextHeader: layers.IPProtocolTCP, - SrcIP: srcIP, - DstIP: dstIP, - } - packetLayers = append(packetLayers, ipLayer) - - tcpLayer := &layers.TCP{ - SrcPort: layers.TCPPort(srcPort), - DstPort: layers.TCPPort(dstPort), - Seq: rand.Uint32(), - PSH: true, - ACK: true, - Ack: rand.Uint32(), - Window: 12345, - } - packetLayers = append(packetLayers, tcpLayer) - - if err := tcpLayer.SetNetworkLayerForChecksum(ipLayer); err != nil { - return nil, err - } - - return packetLayers, nil -} From 60affa6d44dcaa1ea2a9da501b087fbfbb3aa961 Mon Sep 17 00:00:00 2001 From: xvzc Date: Thu, 19 Mar 2026 11:09:29 +0900 Subject: [PATCH 27/39] refactor(tun): resolve gateway and interface prior to server startup --- cmd/spoofdpi/main.go | 52 +++++++++++----- internal/dns/https.go | 4 +- internal/netutil/route.go | 6 -- internal/server/http/server.go | 7 ++- internal/server/socks5/server.go | 7 ++- internal/server/socks5/udp_associate.go | 2 +- internal/server/tun/server.go | 80 +++++++++++-------------- internal/server/tun/udp.go | 4 ++ 8 files changed, 90 insertions(+), 72 deletions(-) diff --git a/cmd/spoofdpi/main.go b/cmd/spoofdpi/main.go index 6365f4ae..cbf75d01 100644 --- a/cmd/spoofdpi/main.go +++ b/cmd/spoofdpi/main.go @@ -75,20 +75,6 @@ func runApp(appctx context.Context, configDir string, cfg *config.Config) { } }() - <-ready - - // System Proxy Config - if *cfg.App.AutoConfigureNetwork { - if err := srv.SetNetworkConfig(); err != nil { - logger.Fatal().Err(err).Msg("failed to set system network config") - } - defer func() { - if err := srv.UnsetNetworkConfig(); err != nil { - logger.Error().Err(err).Msg("failed to unset system network config") - } - }() - } - logger.Info().Msg("dns info") logger.Info().Msgf(" query type '%s'", cfg.DNS.QType.String()) logger.Info().Msgf(" resolvers") @@ -128,6 +114,20 @@ func runApp(appctx context.Context, configDir string, cfg *config.Config) { logger.Info().Msgf("server started on %s", srv.Addr()) + <-ready + + // System Proxy Config + if *cfg.App.AutoConfigureNetwork { + if err := srv.SetNetworkConfig(); err != nil { + logger.Fatal().Err(err).Msg("failed to set system network config") + } + defer func() { + if err := srv.UnsetNetworkConfig(); err != nil { + logger.Error().Err(err).Msg("failed to unset system network config") + } + }() + } + <-appctx.Done() } @@ -364,6 +364,22 @@ func createServer( cfg.Policy.Clone(), ), nil case config.AppModeTUN: + // Find default interface and gateway before modifying routes + defaultIface, defaultGateway, err := netutil.GetDefaultInterfaceAndGateway() + if err != nil { + return nil, fmt.Errorf("failed to get default interface: %w", err) + } + logger.Info(). + Str("interface", defaultIface). + Str("gateway", defaultGateway). + Msg("determined default interface and gateway") + // s.defaultIface = defaultIface + // s.defaultGateway = defaultGateway + + // Update handlers with network info + // s.tcpHandler.SetNetworkInfo(defaultIface, defaultGateway) + // s.udpHandler.SetNetworkInfo(defaultIface, defaultGateway) + // tcpHandler := tun.NewTCPHandler( logging.WithScope(logger, "hnd"), ruleMatcher, // For domain-based TLS matching @@ -371,8 +387,8 @@ func createServer( cfg.Conn.Clone(), desyncer, tcpSniffer, // For TTL tracking - "", // iface and gateway will be set later - "", + defaultIface, + defaultGateway, ) udpDesyncer := desync.NewUDPDesyncer( @@ -386,6 +402,8 @@ func createServer( udpDesyncer, cfg.UDP.Clone(), cfg.Conn.Clone(), + defaultIface, + defaultGateway, ) return tun.NewTunServer( @@ -394,6 +412,8 @@ func createServer( ruleMatcher, // For IP-based matching in server.go tcpHandler, udpHandler, + defaultIface, + defaultGateway, ), nil default: return nil, fmt.Errorf("unknown server mode: %s", *cfg.App.Mode) diff --git a/internal/dns/https.go b/internal/dns/https.go index 1442e1fb..b3e3937f 100644 --- a/internal/dns/https.go +++ b/internal/dns/https.go @@ -48,7 +48,9 @@ func NewHTTPSResolver( // Configure HTTP/2 transport explicitly if err := http2.ConfigureTransport(tr); err != nil { - logger.Warn().Err(err).Msg("failed to configure http2 expressly, falling back to default / http/1.1") + logger.Warn(). + Err(err). + Msg("failed to configure http2 expressly, falling back to default / http/1.1") } return &HTTPSResolver{ diff --git a/internal/netutil/route.go b/internal/netutil/route.go index 756cd163..88ba1af0 100644 --- a/internal/netutil/route.go +++ b/internal/netutil/route.go @@ -94,9 +94,3 @@ func GetDefaultInterfaceAndGateway() (string, string, error) { return ifaceName, gateway, nil } - -// GetDefaultInterface returns the name of the default network interface -func GetDefaultInterface() (string, error) { - ifaceName, _, err := GetDefaultInterfaceAndGateway() - return ifaceName, err -} diff --git a/internal/server/http/server.go b/internal/server/http/server.go index 6204f493..ba31e7f4 100644 --- a/internal/server/http/server.go +++ b/internal/server/http/server.go @@ -55,7 +55,10 @@ func NewHTTPProxy( } } -func (p *HTTPProxy) ListenAndServe(appctx context.Context, ready chan<- struct{}) error { +func (p *HTTPProxy) ListenAndServe( + appctx context.Context, + ready chan<- struct{}, +) error { listener, err := net.ListenTCP("tcp", p.appOpts.ListenAddr) if err != nil { return fmt.Errorf( @@ -68,7 +71,7 @@ func (p *HTTPProxy) ListenAndServe(appctx context.Context, ready chan<- struct{} go func() { <-appctx.Done() - listener.Close() + _ = listener.Close() }() if ready != nil { diff --git a/internal/server/socks5/server.go b/internal/server/socks5/server.go index 8741d148..3014dc8c 100644 --- a/internal/server/socks5/server.go +++ b/internal/server/socks5/server.go @@ -60,7 +60,10 @@ func NewSOCKS5Proxy( } } -func (p *SOCKS5Proxy) ListenAndServe(appctx context.Context, ready chan<- struct{}) error { +func (p *SOCKS5Proxy) ListenAndServe( + appctx context.Context, + ready chan<- struct{}, +) error { listener, err := net.ListenTCP("tcp", p.appOpts.ListenAddr) if err != nil { return fmt.Errorf( @@ -72,7 +75,7 @@ func (p *SOCKS5Proxy) ListenAndServe(appctx context.Context, ready chan<- struct go func() { <-appctx.Done() - listener.Close() + _ = listener.Close() }() if ready != nil { diff --git a/internal/server/socks5/udp_associate.go b/internal/server/socks5/udp_associate.go index c6995bce..8a0495fe 100644 --- a/internal/server/socks5/udp_associate.go +++ b/internal/server/socks5/udp_associate.go @@ -82,7 +82,7 @@ func (h *UdpAssociateHandler) Handle( go func() { _, _ = io.Copy(io.Discard, lConn) // Block until TCP closes close(done) // Close the channel to signal UDP handler to exit - lNewConn.Close() // Force ReadFromUDP to unblock and avoid goroutine leak + _ = lNewConn.Close() // Force ReadFromUDP to unblock and avoid goroutine leak }() buf := make([]byte, 65535) diff --git a/internal/server/tun/server.go b/internal/server/tun/server.go index 486311bc..91cb9640 100644 --- a/internal/server/tun/server.go +++ b/internal/server/tun/server.go @@ -41,9 +41,9 @@ type TunServer struct { tcpHandler *TCPHandler udpHandler *UDPHandler - iface *water.Interface - defaultIface string - defaultGateway string + tunDevice *water.Interface + iface string + gateway string } func NewTunServer( @@ -52,6 +52,8 @@ func NewTunServer( matcher matcher.RuleMatcher, tcpHandler *TCPHandler, udpHandler *UDPHandler, + iface string, + gateway string, ) server.Server { return &TunServer{ logger: logger, @@ -59,51 +61,44 @@ func NewTunServer( matcher: matcher, tcpHandler: tcpHandler, udpHandler: udpHandler, + iface: iface, + gateway: gateway, } } -func (s *TunServer) ListenAndServe(appctx context.Context, ready chan<- struct{}) error { - iface, err := NewTunDevice() +func (s *TunServer) ListenAndServe( + appctx context.Context, + ready chan<- struct{}, +) error { + var err error + s.tunDevice, err = NewTunDevice() if err != nil { return fmt.Errorf("failed to create tun device: %w", err) } - s.iface = iface - defer iface.Close() + + go func() { + <-appctx.Done() + _ = s.tunDevice.Close() + }() if ready != nil { close(ready) } - return s.handle(appctx, iface) + return s.handle(appctx) } func (s *TunServer) SetNetworkConfig() error { - if s.iface == nil { + if s.tunDevice == nil { return fmt.Errorf("tun device not initialized") } - // Find default interface and gateway before modifying routes - defaultIface, defaultGateway, err := netutil.GetDefaultInterfaceAndGateway() - if err != nil { - return fmt.Errorf("failed to get default interface: %w", err) - } - s.logger.Info(). - Str("interface", defaultIface). - Str("gateway", defaultGateway). - Msg("determined default interface and gateway") - s.defaultIface = defaultIface - s.defaultGateway = defaultGateway - - // Update handlers with network info - s.tcpHandler.SetNetworkInfo(defaultIface, defaultGateway) - s.udpHandler.SetNetworkInfo(defaultIface, defaultGateway) - local, remote, err := netutil.FindSafeSubnet() if err != nil { return fmt.Errorf("failed to find safe subnet: %w", err) } - if err := SetInterfaceAddress(s.iface.Name(), local, remote); err != nil { + if err := SetInterfaceAddress(s.tunDevice.Name(), local, remote); err != nil { return fmt.Errorf("failed to set interface address: %w", err) } @@ -118,38 +113,38 @@ func (s *TunServer) SetNetworkConfig() error { localIP[15]&0xFC, ) // Mask with /30 - err = SetRoute(s.iface.Name(), []string{networkAddr.String() + "/30"}) + err = SetRoute(s.tunDevice.Name(), []string{networkAddr.String() + "/30"}) if err != nil { return fmt.Errorf("failed to set local route: %w", err) } // Add a host route to the gateway via the physical interface // This ensures SpoofDPI's outbound traffic goes through en0, not utun8 - if err := SetGatewayRoute(defaultGateway, defaultIface); err != nil { - s.logger.Warn().Err(err).Msg("failed to set gateway route") + if err := SetGatewayRoute(s.gateway, s.iface); err != nil { + s.logger.Error().Err(err).Msg("failed to set gateway route") } - return SetRoute(s.iface.Name(), []string{"0.0.0.0/0"}) // Default Route + return SetRoute(s.tunDevice.Name(), []string{"0.0.0.0/0"}) // Default Route } func (s *TunServer) UnsetNetworkConfig() error { - if s.iface == nil { + if s.tunDevice == nil { return nil } // Remove the gateway route - if s.defaultGateway != "" && s.defaultIface != "" { - if err := UnsetGatewayRoute(s.defaultGateway, s.defaultIface); err != nil { + if s.gateway != "" && s.iface != "" { + if err := UnsetGatewayRoute(s.gateway, s.iface); err != nil { s.logger.Warn().Err(err).Msg("failed to unset gateway route") } } - return UnsetRoute(s.iface.Name(), []string{"0.0.0.0/0"}) // Default Route + return UnsetRoute(s.tunDevice.Name(), []string{"0.0.0.0/0"}) // Default Route } func (s *TunServer) Addr() string { - if s.iface != nil { - return s.iface.Name() + if s.tunDevice != nil { + return s.tunDevice.Name() } return "tun" } @@ -181,7 +176,7 @@ func (s *TunServer) matchRuleByAddr(addr net.Addr) *config.Rule { return s.matcher.Search(selector) } -func (s *TunServer) handle(appctx context.Context, iface *water.Interface) error { +func (s *TunServer) handle(appctx context.Context) error { logger := logging.WithLocalScope(appctx, s.logger, "tun") // 1. Create gVisor stack @@ -257,22 +252,20 @@ func (s *TunServer) handle(appctx context.Context, iface *water.Interface) error stk.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket) // 6. Start packet pump - go s.tunToStack(appctx, logger, iface, ep) - go s.stackToTun(appctx, logger, iface, ep) + go s.tunToStack(appctx, logger, ep) + s.stackToTun(appctx, logger, ep) - <-appctx.Done() return nil } func (s *TunServer) tunToStack( appctx context.Context, logger zerolog.Logger, - iface *water.Interface, ep *channel.Endpoint, ) { buf := make([]byte, 2000) for { - n, err := iface.Read(buf) + n, err := s.tunDevice.Read(buf) if err != nil { if errors.Is(err, fs.ErrClosed) || errors.Is(err, os.ErrClosed) { return @@ -336,7 +329,6 @@ func (n *notifier) WriteNotify() { func (s *TunServer) stackToTun( appctx context.Context, logger zerolog.Logger, - iface *water.Interface, ep *channel.Endpoint, ) { ch := make(chan struct{}, 1) @@ -362,7 +354,7 @@ func (s *TunServer) stackToTun( views := pkt.ToView().AsSlice() if len(views) > 0 { - _, _ = iface.Write(views) + _, _ = s.tunDevice.Write(views) } pkt.DecRef() } diff --git a/internal/server/tun/udp.go b/internal/server/tun/udp.go index 99a09a8c..39b0bee8 100644 --- a/internal/server/tun/udp.go +++ b/internal/server/tun/udp.go @@ -27,12 +27,16 @@ func NewUDPHandler( desyncer *desync.UDPDesyncer, defaultUDPOpts *config.UDPOptions, defaultConnOpts *config.ConnOptions, + iface string, + gateway string, ) *UDPHandler { return &UDPHandler{ logger: logger, desyncer: desyncer, defaultUDPOpts: defaultUDPOpts, defaultConnOpts: defaultConnOpts, + iface: iface, + gateway: gateway, } } From bcb42ab330e1f912610b0d3156c7fe45470a08b7 Mon Sep 17 00:00:00 2001 From: xvzc Date: Thu, 19 Mar 2026 14:47:44 +0900 Subject: [PATCH 28/39] fix(cache): trigger onInvalidate on duplicate key updates LRUCache previously overwrote existing values without notification when storing a duplicate key. To prevent potential resource leaks, the onInvalidate callback is now invoked prior to storing the new value, ensuring the previous entry is properly cleaned up. --- internal/cache/lru_cache.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/internal/cache/lru_cache.go b/internal/cache/lru_cache.go index f2d83aa8..e7323c7d 100644 --- a/internal/cache/lru_cache.go +++ b/internal/cache/lru_cache.go @@ -106,8 +106,12 @@ func (c *LRUCache[K]) Store(key K, value any, opts *options) bool { return false } - if ok { + if ok { // Key already exists entry := element.Value.(*lruEntry[K]) + if c.onInvalidate != nil { + // Invoke onInvalidate to ensure associated resources are properly released. + c.onInvalidate(entry.key, entry.value) + } entry.value = value c.list.MoveToFront(element) From c5b627db0e1ca0cb7ffe582cfacb803730e8f380 Mon Sep 17 00:00:00 2001 From: xvzc Date: Thu, 19 Mar 2026 14:55:18 +0900 Subject: [PATCH 29/39] refactor(socks5): improve variable naming and logging --- internal/netutil/key.go | 10 +-- internal/server/socks5/udp_associate.go | 82 +++++++++++++++---------- 2 files changed, 54 insertions(+), 38 deletions(-) diff --git a/internal/netutil/key.go b/internal/netutil/key.go index bba313bd..615103ea 100644 --- a/internal/netutil/key.go +++ b/internal/netutil/key.go @@ -59,22 +59,22 @@ func NewIPKey(ip net.IP) IPKey { } // NewNATKey zero-alloc constructs a NATKey from two UDPAddr -func NewNATKey(src *net.UDPAddr, dst *net.UDPAddr) NATKey { +func NewNATKey(srcIP net.IP, srcPort int, dstIP net.IP, dstPort int) NATKey { var k NATKey // net.IP is a slice. Let's force it to 16 bytes for comparable struct key - srcIP16 := src.IP.To16() + srcIP16 := srcIP.To16() if srcIP16 != nil { copy(k.SrcIP[:], srcIP16) } - dstIP16 := dst.IP.To16() + dstIP16 := dstIP.To16() if dstIP16 != nil { copy(k.DstIP[:], dstIP16) } - k.SrcPort = uint16(src.Port) - k.DstPort = uint16(dst.Port) + k.SrcPort = uint16(srcPort) + k.DstPort = uint16(dstPort) return k } diff --git a/internal/server/socks5/udp_associate.go b/internal/server/socks5/udp_associate.go index 8a0495fe..f4d57705 100644 --- a/internal/server/socks5/udp_associate.go +++ b/internal/server/socks5/udp_associate.go @@ -46,21 +46,21 @@ func (h *UdpAssociateHandler) Handle( logger := logging.WithLocalScope(ctx, h.logger, "udp_associate") // 1. Listen on a random UDP port - lAddrTCP := lConn.LocalAddr().(*net.TCPAddr) // SOCKS5 listens on TCP - lNewConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: lAddrTCP.IP, Port: 0}) + lTCPAddr := lConn.LocalAddr().(*net.TCPAddr) // SOCKS5 listens on TCP + lUDPConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: lTCPAddr.IP, Port: 0}) if err != nil { logger.Error().Err(err).Msg("failed to create udp listener") _ = proto.SOCKS5FailureResponse().Write(lConn) return err } - defer netutil.CloseConns(lNewConn) + defer netutil.CloseConns(lUDPConn) logger.Debug(). - Str("addr", lNewConn.LocalAddr().String()). - Str("network", lNewConn.LocalAddr().Network()). + Str("addr", lUDPConn.LocalAddr().String()). + Str("network", lUDPConn.LocalAddr().Network()). Msg("new conn") - lAddr := lNewConn.LocalAddr().(*net.UDPAddr) + lAddr := lUDPConn.LocalAddr().(*net.UDPAddr) logger.Debug(). Str("bind_addr", lAddr.String()). @@ -82,15 +82,15 @@ func (h *UdpAssociateHandler) Handle( go func() { _, _ = io.Copy(io.Discard, lConn) // Block until TCP closes close(done) // Close the channel to signal UDP handler to exit - _ = lNewConn.Close() // Force ReadFromUDP to unblock and avoid goroutine leak + _ = lUDPConn.Close() // Force ReadFromUDP to unblock and avoid goroutine leak }() buf := make([]byte, 65535) - tcpRemoteIP := lConn.RemoteAddr().(*net.TCPAddr).IP + rTCPAddr := lConn.RemoteAddr().(*net.TCPAddr).IP for { // Wait for data - n, lAddr, err := lNewConn.ReadFromUDP(buf) + n, srcAddr, err := lUDPConn.ReadFromUDP(buf) if err != nil { // Normal closure check select { @@ -105,57 +105,65 @@ func (h *UdpAssociateHandler) Handle( } // Security: Only accept UDP packets from the same IP that established the TCP connection - if !lAddr.IP.Equal(tcpRemoteIP) { + if !srcAddr.IP.Equal(rTCPAddr) { logger.Debug(). - Str("expected", tcpRemoteIP.String()). - Str("actual", lAddr.IP.String()). + Str("expected", rTCPAddr.String()). + Str("actual", srcAddr.IP.String()). Msg("dropped udp packet from unexpected ip") continue } // Outbound: Client -> Proxy -> Target - dstUDPAddrStr, payload, err := parseUDPHeader(buf[:n]) + dstAddrStr, payload, err := parseUDPHeader(buf[:n]) if err != nil { logger.Warn().Err(err).Msg("failed to parse socks5 udp header") continue } // Resolve address to construct Destination - dstUDPAddr, err := net.ResolveUDPAddr("udp", dstUDPAddrStr) + dstAddr, err := net.ResolveUDPAddr("udp", dstAddrStr) if err != nil { logger.Warn(). Err(err). - Str("addr", dstUDPAddrStr). + Str("addr", dstAddrStr). Msg("failed to resolve udp target") continue } // Key: Client Addr -> Target Addr (Zero Allocation Struct) - key := netutil.NewNATKey(lAddr, dstUDPAddr) - - dst := &netutil.Destination{ - Addrs: []net.IP{dstUDPAddr.IP}, - Port: dstUDPAddr.Port, - } + key := netutil.NewNATKey(srcAddr.IP, srcAddr.Port, dstAddr.IP, dstAddr.Port) // Check if connection already exists in the pool - if conn, ok := h.pool.Fetch(key); ok { + if cachedConn, ok := h.pool.Fetch(key); ok { + logger.Debug(). + Str("key", fmt.Sprintf("%s > %s", srcAddr.String(), dstAddr.String())). + Msg("session cache hit") + // Write payload to target - if _, err := conn.Write(payload); err != nil { + if _, err := cachedConn.Write(payload); err != nil { logger.Warn().Err(err).Msg("failed to write udp to target") } continue + } else { + logger.Debug(). + Str("key", fmt.Sprintf("%s > %s", srcAddr.String(), dstAddr.String())). + Msg("session cache miss") + } + + dst := &netutil.Destination{ + Addrs: []net.IP{dstAddr.IP}, + Port: dstAddr.Port, } - rConnRaw, err := netutil.DialFastest(ctx, "udp", dst) + rRawConn, err := netutil.DialFastest(ctx, "udp", dst) if err != nil { - logger.Warn().Err(err).Str("addr", dstUDPAddrStr).Msg("failed to dial udp target") + logger.Warn().Err(err).Str("addr", dstAddrStr).Msg("failed to dial udp target") continue } // Add to pool (pool handles LRU eviction and deadline) // returns IdleTimeoutConn with the actual net.Conn inside - rConn := h.pool.Store(key, rConnRaw) + rConn := h.pool.Store(key, rRawConn) // Apply UDP options from rule if matched udpOpts := h.defaultUDPOpts.Clone() @@ -165,14 +173,19 @@ func (h *UdpAssociateHandler) Handle( // Send fake packets before real payload (UDP desync) if h.desyncer != nil { - _, _ = h.desyncer.Desync(ctx, lNewConn, rConn.Conn, udpOpts) + _, _ = h.desyncer.Desync(ctx, lUDPConn, rConn.Conn, udpOpts) } - // Start a goroutine to read from the target and forward to the client - go func(rConn *netutil.IdleTimeoutConn, lAddr *net.UDPAddr) { + // Start a goroutine to read from the target and forward to the client. + // rConn is a connected UDP socket, so all responses come from the single remote. + // Using rConn.Read() (via IdleTimeoutConn) properly extends the idle deadline + // on each inbound packet, preventing premature timeout on asymmetric flows. + // dstAddr is already *net.UDPAddr (resolved above), same as rConn.RemoteAddr(). + go func(rConn *netutil.IdleTimeoutConn, lAddr *net.UDPAddr, remoteAddr *net.UDPAddr) { respBuf := make([]byte, 65535) for { - n, _, err := rConn.Conn.(*net.UDPConn).ReadFromUDP(respBuf) + // Read via IdleTimeoutConn so each inbound packet extends the deadline. + n, err := rConn.Read(respBuf) if err != nil { // Connection closed or network issues return @@ -180,18 +193,21 @@ func (h *UdpAssociateHandler) Handle( // Inbound: Target -> Proxy -> Client // Wrap with SOCKS5 Header - remoteAddr := rConn.Conn.(*net.UDPConn).RemoteAddr().(*net.UDPAddr) header := createUDPHeaderFromAddr(remoteAddr) response := append(header, respBuf[:n]...) - if _, err := lNewConn.WriteToUDP(response, lAddr); err != nil { + if _, err := lUDPConn.WriteToUDP(response, lAddr); err != nil { // If we can't write back to the client, it might be gone or network issue. // Exit this goroutine to avoid busy looping. logger.Warn().Err(err).Msg("failed to write udp to client") return } } - }(rConn, lAddr) + }( + rConn, + srcAddr, + dstAddr, + ) // Write payload to target if _, err := rConn.Write(payload); err != nil { From d7a6a63109bcd67a67c55dc9436a14416158a60c Mon Sep 17 00:00:00 2001 From: xvzc Date: Thu, 19 Mar 2026 15:06:01 +0900 Subject: [PATCH 30/39] test: remove unused test files --- internal/cache/ttl_cache_benchmark_test.go | 54 ---------------------- internal/netutil/key_benchmark_test.go | 30 ------------ 2 files changed, 84 deletions(-) delete mode 100644 internal/cache/ttl_cache_benchmark_test.go delete mode 100644 internal/netutil/key_benchmark_test.go diff --git a/internal/cache/ttl_cache_benchmark_test.go b/internal/cache/ttl_cache_benchmark_test.go deleted file mode 100644 index acedc7f2..00000000 --- a/internal/cache/ttl_cache_benchmark_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package cache - -import ( - "fmt" - "testing" - "time" -) - -type dummyIPKey [16]byte - -func generateDummyIPKey(i int) dummyIPKey { - var k dummyIPKey - k[0] = byte(i) - k[1] = byte(i >> 8) - return k -} - -func BenchmarkCacheKeys(b *testing.B) { - strCache := NewTTLCache[string](TTLCacheAttrs{ - NumOfShards: 1, - CleanupInterval: time.Minute, - }) - - ipCache := NewTTLCache[dummyIPKey](TTLCacheAttrs{ - NumOfShards: 1, - CleanupInterval: time.Minute, - }) - - b.Run("TTLCache_StringKey", func(b *testing.B) { - var keys []string - for i := 0; i < b.N; i++ { - keys = append(keys, "192.168.0."+fmt.Sprint(i)) - } - b.ReportAllocs() - for i := 0; i < b.N; i++ { - key := keys[i] - strCache.Store(key, 1, nil) - strCache.Fetch(key) - } - }) - - b.Run("TTLCache_GenericStructKey", func(b *testing.B) { - var keys []dummyIPKey - for i := 0; i < b.N; i++ { - keys = append(keys, generateDummyIPKey(i)) - } - b.ReportAllocs() - for i := 0; i < b.N; i++ { - key := keys[i] - ipCache.Store(key, 1, nil) - ipCache.Fetch(key) - } - }) -} diff --git a/internal/netutil/key_benchmark_test.go b/internal/netutil/key_benchmark_test.go deleted file mode 100644 index 9b1fca6a..00000000 --- a/internal/netutil/key_benchmark_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package netutil - -import ( - "net" - "testing" -) - -func BenchmarkKeyAllocation(b *testing.B) { - // Dummy test data - clientAddr := &net.UDPAddr{IP: net.ParseIP("192.168.1.100"), Port: 12345} - targetAddrStr := "142.250.190.46:443" - - uAddr, _ := net.ResolveUDPAddr("udp", targetAddrStr) - - b.Run("StringKey_Legacy", func(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - // This was the old way - _ = clientAddr.String() + ">" + targetAddrStr - } - }) - - b.Run("StructKey_NATKey", func(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - // This is the new way (zero allocation) - _ = NewNATKey(clientAddr, uAddr) - } - }) -} From 0f9799fad6c518b7970e02d93737e7a8db3ad6f2 Mon Sep 17 00:00:00 2001 From: xvzc Date: Thu, 19 Mar 2026 15:10:17 +0900 Subject: [PATCH 31/39] refactor(netutil): rename SessionCache to ConnRegistry --- cmd/spoofdpi/main.go | 2 +- .../{session_cache.go => conn_registry.go} | 28 +++++++++---------- internal/server/socks5/udp_associate.go | 4 +-- 3 files changed, 17 insertions(+), 17 deletions(-) rename internal/netutil/{session_cache.go => conn_registry.go} (78%) diff --git a/cmd/spoofdpi/main.go b/cmd/spoofdpi/main.go index cbf75d01..c3e33286 100644 --- a/cmd/spoofdpi/main.go +++ b/cmd/spoofdpi/main.go @@ -342,7 +342,7 @@ func createServer( udpWriter, udpSniffer, ) - udpPool := netutil.NewSessionCache[netutil.NATKey](4096, 60*time.Second) + udpPool := netutil.NewConnRegistry[netutil.NATKey](4096, 60*time.Second) udpPool.RunCleanupLoop(appctx) udpAssociateHandler := socks5.NewUdpAssociateHandler( logging.WithScope(logger, "hnd"), diff --git a/internal/netutil/session_cache.go b/internal/netutil/conn_registry.go similarity index 78% rename from internal/netutil/session_cache.go rename to internal/netutil/conn_registry.go index 709a0e7d..050a4a0c 100644 --- a/internal/netutil/session_cache.go +++ b/internal/netutil/conn_registry.go @@ -8,18 +8,18 @@ import ( "github.com/xvzc/SpoofDPI/internal/cache" ) -// SessionCache manages UDP connections with LRU eviction policy and idle timeout. -type SessionCache[K comparable] struct { +// ConnRegistry manages UDP connections with LRU eviction policy and idle timeout. +type ConnRegistry[K comparable] struct { storage cache.Cache[K] timeout time.Duration } -// NewSessionCache creates a new pool with the specified capacity and timeout. -func NewSessionCache[K comparable]( +// NewConnRegistry creates a new pool with the specified capacity and timeout. +func NewConnRegistry[K comparable]( capacity int, timeout time.Duration, -) *SessionCache[K] { - p := &SessionCache[K]{ +) *ConnRegistry[K] { + p := &ConnRegistry[K]{ timeout: timeout, } @@ -36,7 +36,7 @@ func NewSessionCache[K comparable]( // RunCleanupLoop runs the background cleanup goroutine. // It exits when appctx is cancelled, closing all remaining cached connections. -func (p *SessionCache[K]) RunCleanupLoop(appctx context.Context) { +func (p *ConnRegistry[K]) RunCleanupLoop(appctx context.Context) { // Cleanup interval: half of timeout, min 10s, max 60s cleanupInterval := p.timeout / 2 cleanupInterval = max(cleanupInterval, 10*time.Second) @@ -61,7 +61,7 @@ func (p *SessionCache[K]) RunCleanupLoop(appctx context.Context) { // Store adds a connection to the cache and returns the wrapped connection. // If the key already exists, the old connection is closed and evicted first. // If capacity is full, evicts the least recently used connection. -func (p *SessionCache[K]) Store(key K, rawConn net.Conn) *IdleTimeoutConn { +func (p *ConnRegistry[K]) Store(key K, rawConn net.Conn) *IdleTimeoutConn { wrapper := NewIdleTimeoutConn(rawConn, p.timeout) wrapper.Key = key @@ -79,7 +79,7 @@ func (p *SessionCache[K]) Store(key K, rawConn net.Conn) *IdleTimeoutConn { } // Fetch retrieves a connection from the pool, refreshing its LRU status. -func (p *SessionCache[K]) Fetch(key K) (*IdleTimeoutConn, bool) { +func (p *ConnRegistry[K]) Fetch(key K) (*IdleTimeoutConn, bool) { if val, ok := p.storage.Fetch(key); ok { return val.(*IdleTimeoutConn), true } @@ -87,22 +87,22 @@ func (p *SessionCache[K]) Fetch(key K) (*IdleTimeoutConn, bool) { } // Evict closes and removes the connection from the pool. -func (p *SessionCache[K]) Evict(key K) { +func (p *ConnRegistry[K]) Evict(key K) { p.storage.Evict(key) } // Has checks if the connection exists in the cache. -func (p *SessionCache[K]) Has(key K) bool { +func (p *ConnRegistry[K]) Has(key K) bool { return p.storage.Has(key) } // Size returns the number of connections in the pool. -func (p *SessionCache[K]) Size() int { +func (p *ConnRegistry[K]) Size() int { return p.storage.Size() } // CloseAll closes all connections in the pool. -func (p *SessionCache[K]) CloseAll() { +func (p *ConnRegistry[K]) CloseAll() { var toRemove []K _ = p.storage.ForEach(func(key K, value any) error { toRemove = append(toRemove, key) @@ -113,7 +113,7 @@ func (p *SessionCache[K]) CloseAll() { } } -func (p *SessionCache[K]) evictExpired() { +func (p *ConnRegistry[K]) evictExpired() { now := time.Now() var toRemove []K _ = p.storage.ForEach(func(key K, value any) error { diff --git a/internal/server/socks5/udp_associate.go b/internal/server/socks5/udp_associate.go index f4d57705..db9aab06 100644 --- a/internal/server/socks5/udp_associate.go +++ b/internal/server/socks5/udp_associate.go @@ -17,14 +17,14 @@ import ( type UdpAssociateHandler struct { logger zerolog.Logger - pool *netutil.SessionCache[netutil.NATKey] + pool *netutil.ConnRegistry[netutil.NATKey] desyncer *desync.UDPDesyncer defaultUDPOpts *config.UDPOptions } func NewUdpAssociateHandler( logger zerolog.Logger, - pool *netutil.SessionCache[netutil.NATKey], + pool *netutil.ConnRegistry[netutil.NATKey], desyncer *desync.UDPDesyncer, defaultUDPOpts *config.UDPOptions, ) *UdpAssociateHandler { From 25306ab0c3ccd4014b5195422ad71db22c16ed10 Mon Sep 17 00:00:00 2001 From: xvzc Date: Sat, 21 Mar 2026 12:23:31 +0900 Subject: [PATCH 32/39] fix(tcp_writer): also check dstMAC to determine ethernet layer inclusion --- internal/packet/tcp_writer.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/packet/tcp_writer.go b/internal/packet/tcp_writer.go index 4ed28756..146557e9 100644 --- a/internal/packet/tcp_writer.go +++ b/internal/packet/tcp_writer.go @@ -124,7 +124,7 @@ func (tw *TCPWriter) createIPv4Layers( ) ([]gopacket.SerializableLayer, error) { var packetLayers []gopacket.SerializableLayer - if srcMAC != nil { + if srcMAC != nil && dstMAC != nil { eth := &layers.Ethernet{ SrcMAC: srcMAC, DstMAC: dstMAC, @@ -173,7 +173,7 @@ func (tw *TCPWriter) createIPv6Layers( ) ([]gopacket.SerializableLayer, error) { var packetLayers []gopacket.SerializableLayer - if srcMAC != nil { + if srcMAC != nil && dstMAC != nil { eth := &layers.Ethernet{ SrcMAC: srcMAC, DstMAC: dstMAC, From 804063cbe8a91425943288a1acc96d9ebfa8967e Mon Sep 17 00:00:00 2001 From: xvzc Date: Sat, 21 Mar 2026 12:26:03 +0900 Subject: [PATCH 33/39] refactor: rename lAddr to lUDPAddr --- internal/server/socks5/udp_associate.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/server/socks5/udp_associate.go b/internal/server/socks5/udp_associate.go index db9aab06..b76dbfdc 100644 --- a/internal/server/socks5/udp_associate.go +++ b/internal/server/socks5/udp_associate.go @@ -60,14 +60,14 @@ func (h *UdpAssociateHandler) Handle( Str("network", lUDPConn.LocalAddr().Network()). Msg("new conn") - lAddr := lUDPConn.LocalAddr().(*net.UDPAddr) + lUDPAddr := lUDPConn.LocalAddr().(*net.UDPAddr) logger.Debug(). - Str("bind_addr", lAddr.String()). + Str("bind_addr", lUDPAddr.String()). Msg("socks5 udp associate established") // 2. Reply with the bound address - err = proto.SOCKS5SuccessResponse().Bind(lAddr.IP).Port(lAddr.Port).Write(lConn) + err = proto.SOCKS5SuccessResponse().Bind(lUDPAddr.IP).Port(lUDPAddr.Port).Write(lConn) if err != nil { logger.Error().Err(err).Msg("failed to write socks5 success reply") return err From 34928d08dea37663d88e8f4cebf9241b31bc4366 Mon Sep 17 00:00:00 2001 From: xvzc Date: Sat, 21 Mar 2026 21:25:58 +0900 Subject: [PATCH 34/39] refactor: rename network config cleanup function to unset --- cmd/spoofdpi/main.go | 15 +++++--- internal/server/http/network.go | 8 +--- internal/server/http/network_darwin.go | 48 ++++++++++-------------- internal/server/http/server.go | 8 +--- internal/server/server.go | 3 +- internal/server/socks5/network.go | 8 +--- internal/server/socks5/network_darwin.go | 47 +++++++++-------------- internal/server/socks5/server.go | 8 +--- internal/server/tun/server.go | 41 +++++++++++--------- 9 files changed, 78 insertions(+), 108 deletions(-) diff --git a/cmd/spoofdpi/main.go b/cmd/spoofdpi/main.go index c3e33286..62d364ce 100644 --- a/cmd/spoofdpi/main.go +++ b/cmd/spoofdpi/main.go @@ -118,14 +118,17 @@ func runApp(appctx context.Context, configDir string, cfg *config.Config) { // System Proxy Config if *cfg.App.AutoConfigureNetwork { - if err := srv.SetNetworkConfig(); err != nil { + unset, err := srv.SetNetworkConfig() + if err != nil { logger.Fatal().Err(err).Msg("failed to set system network config") } - defer func() { - if err := srv.UnsetNetworkConfig(); err != nil { - logger.Error().Err(err).Msg("failed to unset system network config") - } - }() + if unset != nil { + defer func() { + if err := unset(); err != nil { + logger.Error().Err(err).Msg("failed to unset system network config") + } + }() + } } <-appctx.Done() diff --git a/internal/server/http/network.go b/internal/server/http/network.go index 13e47a26..bf8f070b 100644 --- a/internal/server/http/network.go +++ b/internal/server/http/network.go @@ -6,10 +6,6 @@ import ( "github.com/rs/zerolog" ) -func SetSystemProxy(logger zerolog.Logger, port uint16) error { - return nil -} - -func UnsetSystemProxy(logger zerolog.Logger) error { - return nil +func setSystemProxy(logger zerolog.Logger, port uint16) (func() error, error) { + return func() error { return nil }, nil } diff --git a/internal/server/http/network_darwin.go b/internal/server/http/network_darwin.go index 71e90fd3..6714a04b 100644 --- a/internal/server/http/network_darwin.go +++ b/internal/server/http/network_darwin.go @@ -5,7 +5,6 @@ package http import ( "errors" "fmt" - "net" "os/exec" "strconv" "strings" @@ -23,12 +22,10 @@ const ( " -system-proxy=false." ) -var pacListener net.Listener - -func SetSystemProxy(logger zerolog.Logger, port uint16) error { +func setSystemProxy(logger zerolog.Logger, port uint16) (func() error, error) { network, err := getDefaultNetwork() if err != nil { - return err + return nil, err } portStr := strconv.Itoa(int(port)) @@ -36,47 +33,40 @@ func SetSystemProxy(logger zerolog.Logger, port uint16) error { return "PROXY 127.0.0.1:%s; DIRECT"; }`, portStr) - pacURL, l, err := netutil.RunPACServer(pacContent) + pacURL, pacListener, err := netutil.RunPACServer(pacContent) if err != nil { - return fmt.Errorf("error creating pac server: %w", err) + return nil, fmt.Errorf("error creating pac server: %w", err) } - pacListener = l // Enable Auto Proxy Configuration // networksetup -setautoproxyurl if err := networkSetup("-setautoproxyurl", network, pacURL); err != nil { - return fmt.Errorf("setting autoproxyurl: %w", err) + _ = pacListener.Close() + return nil, fmt.Errorf("setting autoproxyurl: %w", err) } // networksetup -setproxyautodiscovery if err := networkSetup("-setproxyautodiscovery", network, "on"); err != nil { - return fmt.Errorf("setting proxyautodiscovery: %w", err) - } - - return nil -} - -func UnsetSystemProxy(logger zerolog.Logger) error { - if pacListener != nil { _ = pacListener.Close() - pacListener = nil + return nil, fmt.Errorf("setting proxyautodiscovery: %w", err) } - network, err := getDefaultNetwork() - if err != nil { - return err - } + unset := func() error { + _ = pacListener.Close() - // Disable Auto Proxy Configuration - if err := networkSetup("-setautoproxystate", network, "off"); err != nil { - return fmt.Errorf("unsetting autoproxystate: %w", err) - } + // Disable Auto Proxy Configuration + if err := networkSetup("-setautoproxystate", network, "off"); err != nil { + return fmt.Errorf("unsetting autoproxystate: %w", err) + } - if err := networkSetup("-setproxyautodiscovery", network, "off"); err != nil { - return fmt.Errorf("unsetting proxyautodiscovery: %w", err) + if err := networkSetup("-setproxyautodiscovery", network, "off"); err != nil { + return fmt.Errorf("unsetting proxyautodiscovery: %w", err) + } + + return nil } - return nil + return unset, nil } func getDefaultNetwork() (string, error) { diff --git a/internal/server/http/server.go b/internal/server/http/server.go index ba31e7f4..4c7d4cb2 100644 --- a/internal/server/http/server.go +++ b/internal/server/http/server.go @@ -95,12 +95,8 @@ func (p *HTTPProxy) ListenAndServe( } } -func (p *HTTPProxy) SetNetworkConfig() error { - return SetSystemProxy(p.logger, uint16(p.appOpts.ListenAddr.Port)) -} - -func (p *HTTPProxy) UnsetNetworkConfig() error { - return UnsetSystemProxy(p.logger) +func (p *HTTPProxy) SetNetworkConfig() (func() error, error) { + return setSystemProxy(p.logger, uint16(p.appOpts.ListenAddr.Port)) } func (p *HTTPProxy) Addr() string { diff --git a/internal/server/server.go b/internal/server/server.go index cc26d5f5..46d2225a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,8 +6,7 @@ import "context" // ListenAndServe blocks until ctx is cancelled, then releases all resources. type Server interface { ListenAndServe(ctx context.Context, ready chan<- struct{}) error - SetNetworkConfig() error - UnsetNetworkConfig() error + SetNetworkConfig() (func() error, error) // Addr returns the network address or interface name the server is bound to Addr() string diff --git a/internal/server/socks5/network.go b/internal/server/socks5/network.go index c7100f56..d53f5133 100644 --- a/internal/server/socks5/network.go +++ b/internal/server/socks5/network.go @@ -6,10 +6,6 @@ import ( "github.com/rs/zerolog" ) -func SetSystemProxy(logger zerolog.Logger, port uint16) error { - return nil -} - -func UnsetSystemProxy(logger zerolog.Logger) error { - return nil +func setSystemProxy(logger zerolog.Logger, port uint16) (func() error, error) { + return func() error { return nil }, nil } diff --git a/internal/server/socks5/network_darwin.go b/internal/server/socks5/network_darwin.go index c23e07fb..e5a65790 100644 --- a/internal/server/socks5/network_darwin.go +++ b/internal/server/socks5/network_darwin.go @@ -5,7 +5,6 @@ package socks5 import ( "errors" "fmt" - "net" "os/exec" "strconv" "strings" @@ -23,12 +22,10 @@ const ( " -system-proxy=false." ) -var pacListener net.Listener - -func SetSystemProxy(logger zerolog.Logger, port uint16) error { +func setSystemProxy(logger zerolog.Logger, port uint16) (func() error, error) { network, err := getDefaultNetwork() if err != nil { - return err + return nil, err } portStr := strconv.Itoa(int(port)) @@ -36,47 +33,39 @@ func SetSystemProxy(logger zerolog.Logger, port uint16) error { return "SOCKS5 127.0.0.1:%s; DIRECT"; }`, portStr) - pacURL, l, err := netutil.RunPACServer(pacContent) + pacURL, pacListener, err := netutil.RunPACServer(pacContent) if err != nil { - return fmt.Errorf("error creating pac server: %w", err) + return nil, fmt.Errorf("error creating pac server: %w", err) } - pacListener = l // Enable Auto Proxy Configuration // networksetup -setautoproxyurl if err := networkSetup("-setautoproxyurl", network, pacURL); err != nil { - return fmt.Errorf("setting autoproxyurl: %w", err) + _ = pacListener.Close() + return nil, fmt.Errorf("setting autoproxyurl: %w", err) } // networksetup -setproxyautodiscovery if err := networkSetup("-setproxyautodiscovery", network, "on"); err != nil { - return fmt.Errorf("setting proxyautodiscovery: %w", err) - } - - return nil -} - -func UnsetSystemProxy(logger zerolog.Logger) error { - if pacListener != nil { _ = pacListener.Close() - pacListener = nil + return nil, fmt.Errorf("setting proxyautodiscovery: %w", err) } - network, err := getDefaultNetwork() - if err != nil { - return err - } + unset := func() error { + _ = pacListener.Close() - // Disable Auto Proxy Configuration - if err := networkSetup("-setautoproxystate", network, "off"); err != nil { - return fmt.Errorf("unsetting autoproxystate: %w", err) - } + if err := networkSetup("-setautoproxystate", network, "off"); err != nil { + return fmt.Errorf("unsetting autoproxystate: %w", err) + } - if err := networkSetup("-setproxyautodiscovery", network, "off"); err != nil { - return fmt.Errorf("unsetting proxyautodiscovery: %w", err) + if err := networkSetup("-setproxyautodiscovery", network, "off"); err != nil { + return fmt.Errorf("unsetting proxyautodiscovery: %w", err) + } + + return nil } - return nil + return unset, nil } func getDefaultNetwork() (string, error) { diff --git a/internal/server/socks5/server.go b/internal/server/socks5/server.go index 3014dc8c..bded9e9b 100644 --- a/internal/server/socks5/server.go +++ b/internal/server/socks5/server.go @@ -98,12 +98,8 @@ func (p *SOCKS5Proxy) ListenAndServe( } } -func (p *SOCKS5Proxy) SetNetworkConfig() error { - return SetSystemProxy(p.logger, uint16(p.appOpts.ListenAddr.Port)) -} - -func (p *SOCKS5Proxy) UnsetNetworkConfig() error { - return UnsetSystemProxy(p.logger) +func (p *SOCKS5Proxy) SetNetworkConfig() (func() error, error) { + return setSystemProxy(p.logger, uint16(p.appOpts.ListenAddr.Port)) } func (p *SOCKS5Proxy) Addr() string { diff --git a/internal/server/tun/server.go b/internal/server/tun/server.go index 91cb9640..aff39d42 100644 --- a/internal/server/tun/server.go +++ b/internal/server/tun/server.go @@ -71,7 +71,7 @@ func (s *TunServer) ListenAndServe( ready chan<- struct{}, ) error { var err error - s.tunDevice, err = NewTunDevice() + s.tunDevice, err = newTunDevice() if err != nil { return fmt.Errorf("failed to create tun device: %w", err) } @@ -88,18 +88,18 @@ func (s *TunServer) ListenAndServe( return s.handle(appctx) } -func (s *TunServer) SetNetworkConfig() error { +func (s *TunServer) SetNetworkConfig() (func() error, error) { if s.tunDevice == nil { - return fmt.Errorf("tun device not initialized") + return nil, fmt.Errorf("tun device not initialized") } local, remote, err := netutil.FindSafeSubnet() if err != nil { - return fmt.Errorf("failed to find safe subnet: %w", err) + return nil, fmt.Errorf("failed to find safe subnet: %w", err) } if err := SetInterfaceAddress(s.tunDevice.Name(), local, remote); err != nil { - return fmt.Errorf("failed to set interface address: %w", err) + return nil, fmt.Errorf("failed to set interface address: %w", err) } // Add route for the TUN interface subnet to ensure packets can return @@ -115,7 +115,7 @@ func (s *TunServer) SetNetworkConfig() error { err = SetRoute(s.tunDevice.Name(), []string{networkAddr.String() + "/30"}) if err != nil { - return fmt.Errorf("failed to set local route: %w", err) + return nil, fmt.Errorf("failed to set local route: %w", err) } // Add a host route to the gateway via the physical interface @@ -124,22 +124,27 @@ func (s *TunServer) SetNetworkConfig() error { s.logger.Error().Err(err).Msg("failed to set gateway route") } - return SetRoute(s.tunDevice.Name(), []string{"0.0.0.0/0"}) // Default Route -} - -func (s *TunServer) UnsetNetworkConfig() error { - if s.tunDevice == nil { - return nil + err = SetRoute(s.tunDevice.Name(), []string{"0.0.0.0/0"}) // Default Route + if err != nil { + return nil, fmt.Errorf("failed to set default route: %w", err) } - // Remove the gateway route - if s.gateway != "" && s.iface != "" { - if err := UnsetGatewayRoute(s.gateway, s.iface); err != nil { - s.logger.Warn().Err(err).Msg("failed to unset gateway route") + unset := func() error { + if s.tunDevice == nil { + return nil } + + // Remove the gateway route + if s.gateway != "" && s.iface != "" { + if err := UnsetGatewayRoute(s.gateway, s.iface); err != nil { + s.logger.Warn().Err(err).Msg("failed to unset gateway route") + } + } + + return UnsetRoute(s.tunDevice.Name(), []string{"0.0.0.0/0"}) // Default Route } - return UnsetRoute(s.tunDevice.Name(), []string{"0.0.0.0/0"}) // Default Route + return unset, nil } func (s *TunServer) Addr() string { @@ -360,7 +365,7 @@ func (s *TunServer) stackToTun( } } -func NewTunDevice() (*water.Interface, error) { +func newTunDevice() (*water.Interface, error) { config := water.Config{ DeviceType: water.TUN, } From b71e1493ad7f6a57e3421ea9df3f0a67ede51b45 Mon Sep 17 00:00:00 2001 From: xvzc Date: Sat, 21 Mar 2026 21:52:31 +0900 Subject: [PATCH 35/39] style: remove trailing whitespace in network_darwin --- internal/server/http/network_darwin.go | 2 +- internal/server/socks5/network_darwin.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/server/http/network_darwin.go b/internal/server/http/network_darwin.go index 6714a04b..7a703ad9 100644 --- a/internal/server/http/network_darwin.go +++ b/internal/server/http/network_darwin.go @@ -62,7 +62,7 @@ func setSystemProxy(logger zerolog.Logger, port uint16) (func() error, error) { if err := networkSetup("-setproxyautodiscovery", network, "off"); err != nil { return fmt.Errorf("unsetting proxyautodiscovery: %w", err) } - + return nil } diff --git a/internal/server/socks5/network_darwin.go b/internal/server/socks5/network_darwin.go index e5a65790..4f6ea69e 100644 --- a/internal/server/socks5/network_darwin.go +++ b/internal/server/socks5/network_darwin.go @@ -61,7 +61,7 @@ func setSystemProxy(logger zerolog.Logger, port uint16) (func() error, error) { if err := networkSetup("-setproxyautodiscovery", network, "off"); err != nil { return fmt.Errorf("unsetting proxyautodiscovery: %w", err) } - + return nil } From c892f273e7e1f07b3332d4d8f5e4df74aa92475e Mon Sep 17 00:00:00 2001 From: xvzc Date: Sat, 21 Mar 2026 21:52:31 +0900 Subject: [PATCH 36/39] fix(socks5): send success response after upstream dial --- internal/server/socks5/connect.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/internal/server/socks5/connect.go b/internal/server/socks5/connect.go index fc22cb9d..b83117ba 100644 --- a/internal/server/socks5/connect.go +++ b/internal/server/socks5/connect.go @@ -77,22 +77,22 @@ func (h *ConnectHandler) Handle( return netutil.ErrBlocked } - // 3. Send Success Response - err = proto.SOCKS5SuccessResponse().Bind(net.IPv4zero).Port(0).Write(lConn) - if err != nil { - logger.Error().Err(err).Msg("failed to write socks5 success reply") - return err - } - - // logger := logging.WithLocalScope(ctx, h.logger, "connect(tcp)") dst.Timeout = *connOpts.TCPTimeout rConn, err := netutil.DialFastest(ctx, "tcp", dst) if err != nil { + _ = proto.SOCKS5FailureResponse().Write(lConn) return err } defer netutil.CloseConns(rConn) + // 3. Send Success Response + err = proto.SOCKS5SuccessResponse().Bind(net.IPv4zero).Port(0).Write(lConn) + if err != nil { + logger.Error().Err(err).Msg("failed to write socks5 success reply") + return err + } + logger.Debug().Msgf("new remote conn -> %s", rConn.RemoteAddr()) // Wrap lConn with a buffered reader to peek for TLS From 9242e82e195b7ec8a209b2b18d137766fd05dddd Mon Sep 17 00:00:00 2001 From: xvzc Date: Sat, 21 Mar 2026 21:59:37 +0900 Subject: [PATCH 37/39] fix(netutil): fix goroutine leak by returning and closing http.Server in RunPACServer --- internal/netutil/pac.go | 4 ++-- internal/server/http/network_darwin.go | 8 ++++---- internal/server/socks5/network_darwin.go | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/internal/netutil/pac.go b/internal/netutil/pac.go index 67b27a11..03b936cc 100644 --- a/internal/netutil/pac.go +++ b/internal/netutil/pac.go @@ -6,7 +6,7 @@ import ( "net/http" ) -func RunPACServer(content string) (string, net.Listener, error) { +func RunPACServer(content string) (string, *http.Server, error) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return "", nil, err @@ -30,5 +30,5 @@ func RunPACServer(content string) (string, net.Listener, error) { addr := listener.Addr().(*net.TCPAddr) url := fmt.Sprintf("http://127.0.0.1:%d/proxy.pac", addr.Port) - return url, listener, nil + return url, server, nil } diff --git a/internal/server/http/network_darwin.go b/internal/server/http/network_darwin.go index 7a703ad9..b2a3d6a8 100644 --- a/internal/server/http/network_darwin.go +++ b/internal/server/http/network_darwin.go @@ -33,7 +33,7 @@ func setSystemProxy(logger zerolog.Logger, port uint16) (func() error, error) { return "PROXY 127.0.0.1:%s; DIRECT"; }`, portStr) - pacURL, pacListener, err := netutil.RunPACServer(pacContent) + pacURL, pacServer, err := netutil.RunPACServer(pacContent) if err != nil { return nil, fmt.Errorf("error creating pac server: %w", err) } @@ -41,18 +41,18 @@ func setSystemProxy(logger zerolog.Logger, port uint16) (func() error, error) { // Enable Auto Proxy Configuration // networksetup -setautoproxyurl if err := networkSetup("-setautoproxyurl", network, pacURL); err != nil { - _ = pacListener.Close() + _ = pacServer.Close() return nil, fmt.Errorf("setting autoproxyurl: %w", err) } // networksetup -setproxyautodiscovery if err := networkSetup("-setproxyautodiscovery", network, "on"); err != nil { - _ = pacListener.Close() + _ = pacServer.Close() return nil, fmt.Errorf("setting proxyautodiscovery: %w", err) } unset := func() error { - _ = pacListener.Close() + _ = pacServer.Close() // Disable Auto Proxy Configuration if err := networkSetup("-setautoproxystate", network, "off"); err != nil { diff --git a/internal/server/socks5/network_darwin.go b/internal/server/socks5/network_darwin.go index 4f6ea69e..bed35848 100644 --- a/internal/server/socks5/network_darwin.go +++ b/internal/server/socks5/network_darwin.go @@ -33,7 +33,7 @@ func setSystemProxy(logger zerolog.Logger, port uint16) (func() error, error) { return "SOCKS5 127.0.0.1:%s; DIRECT"; }`, portStr) - pacURL, pacListener, err := netutil.RunPACServer(pacContent) + pacURL, pacServer, err := netutil.RunPACServer(pacContent) if err != nil { return nil, fmt.Errorf("error creating pac server: %w", err) } @@ -41,18 +41,18 @@ func setSystemProxy(logger zerolog.Logger, port uint16) (func() error, error) { // Enable Auto Proxy Configuration // networksetup -setautoproxyurl if err := networkSetup("-setautoproxyurl", network, pacURL); err != nil { - _ = pacListener.Close() + _ = pacServer.Close() return nil, fmt.Errorf("setting autoproxyurl: %w", err) } // networksetup -setproxyautodiscovery if err := networkSetup("-setproxyautodiscovery", network, "on"); err != nil { - _ = pacListener.Close() + _ = pacServer.Close() return nil, fmt.Errorf("setting proxyautodiscovery: %w", err) } unset := func() error { - _ = pacListener.Close() + _ = pacServer.Close() if err := networkSetup("-setautoproxystate", network, "off"); err != nil { return fmt.Errorf("unsetting autoproxystate: %w", err) From 5b59a8d2f95893eaa7158528c2391c522385ce55 Mon Sep 17 00:00:00 2001 From: xvzc Date: Sat, 21 Mar 2026 22:09:09 +0900 Subject: [PATCH 38/39] refactor(socks5): extract inbound udp relay logic into a method --- internal/server/socks5/udp_associate.go | 59 +++++++++++++++---------- 1 file changed, 36 insertions(+), 23 deletions(-) diff --git a/internal/server/socks5/udp_associate.go b/internal/server/socks5/udp_associate.go index b76dbfdc..a8987f06 100644 --- a/internal/server/socks5/udp_associate.go +++ b/internal/server/socks5/udp_associate.go @@ -181,32 +181,13 @@ func (h *UdpAssociateHandler) Handle( // Using rConn.Read() (via IdleTimeoutConn) properly extends the idle deadline // on each inbound packet, preventing premature timeout on asymmetric flows. // dstAddr is already *net.UDPAddr (resolved above), same as rConn.RemoteAddr(). - go func(rConn *netutil.IdleTimeoutConn, lAddr *net.UDPAddr, remoteAddr *net.UDPAddr) { - respBuf := make([]byte, 65535) - for { - // Read via IdleTimeoutConn so each inbound packet extends the deadline. - n, err := rConn.Read(respBuf) - if err != nil { - // Connection closed or network issues - return - } - - // Inbound: Target -> Proxy -> Client - // Wrap with SOCKS5 Header - header := createUDPHeaderFromAddr(remoteAddr) - response := append(header, respBuf[:n]...) - - if _, err := lUDPConn.WriteToUDP(response, lAddr); err != nil { - // If we can't write back to the client, it might be gone or network issue. - // Exit this goroutine to avoid busy looping. - logger.Warn().Err(err).Msg("failed to write udp to client") - return - } - } - }( + go h.relayInboundUDP( + logger, + lUDPConn, rConn, srcAddr, dstAddr, + key, ) // Write payload to target @@ -216,6 +197,38 @@ func (h *UdpAssociateHandler) Handle( } } +func (h *UdpAssociateHandler) relayInboundUDP( + logger zerolog.Logger, + lUDPConn *net.UDPConn, + rConn *netutil.IdleTimeoutConn, + clientAddr *net.UDPAddr, + targetAddr *net.UDPAddr, + key netutil.NATKey, +) { + respBuf := make([]byte, 65535) + for { + // Read via IdleTimeoutConn so each inbound packet extends the deadline. + n, err := rConn.Read(respBuf) + if err != nil { + // Connection closed or network issues + h.pool.Evict(key) + return + } + + // Inbound: Target -> Proxy -> Client + // Wrap with SOCKS5 Header + header := createUDPHeaderFromAddr(targetAddr) + response := append(header, respBuf[:n]...) + + if _, err := lUDPConn.WriteToUDP(response, clientAddr); err != nil { + // If we can't write back to the client, it might be gone or network issue. + // Exit this goroutine to avoid busy looping. + logger.Warn().Err(err).Msg("failed to write udp to client") + return + } + } +} + func parseUDPHeader(b []byte) (string, []byte, error) { if len(b) < 4 { return "", nil, fmt.Errorf("header too short") From 73d362c505f7d21a097c142c1559dcf191ba1195 Mon Sep 17 00:00:00 2001 From: xvzc Date: Sat, 21 Mar 2026 22:15:59 +0900 Subject: [PATCH 39/39] style(socks5): make relayInboundUDP inline --- internal/server/socks5/udp_associate.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/internal/server/socks5/udp_associate.go b/internal/server/socks5/udp_associate.go index a8987f06..2061c875 100644 --- a/internal/server/socks5/udp_associate.go +++ b/internal/server/socks5/udp_associate.go @@ -181,14 +181,7 @@ func (h *UdpAssociateHandler) Handle( // Using rConn.Read() (via IdleTimeoutConn) properly extends the idle deadline // on each inbound packet, preventing premature timeout on asymmetric flows. // dstAddr is already *net.UDPAddr (resolved above), same as rConn.RemoteAddr(). - go h.relayInboundUDP( - logger, - lUDPConn, - rConn, - srcAddr, - dstAddr, - key, - ) + go h.relayInboundUDP(logger, lUDPConn, rConn, srcAddr, dstAddr, key) // Write payload to target if _, err := rConn.Write(payload); err != nil {