diff --git a/cmd/spoofdpi/main.go b/cmd/spoofdpi/main.go index e26874a7..62d364ce 100644 --- a/cmd/spoofdpi/main.go +++ b/cmd/spoofdpi/main.go @@ -16,11 +16,13 @@ import ( "github.com/xvzc/SpoofDPI/internal/dns" "github.com/xvzc/SpoofDPI/internal/logging" "github.com/xvzc/SpoofDPI/internal/matcher" + "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/packet" - "github.com/xvzc/SpoofDPI/internal/proxy" - "github.com/xvzc/SpoofDPI/internal/proxy/http" + "github.com/xvzc/SpoofDPI/internal/server" + "github.com/xvzc/SpoofDPI/internal/server/http" // Add http import + "github.com/xvzc/SpoofDPI/internal/server/socks5" + "github.com/xvzc/SpoofDPI/internal/server/tun" "github.com/xvzc/SpoofDPI/internal/session" - "github.com/xvzc/SpoofDPI/internal/system" ) // Version and commit are set at build time. @@ -32,21 +34,25 @@ var ( func main() { cmd := config.CreateCommand(runApp, version, commit, build) - ctx := session.WithNewTraceID(context.Background()) - if err := cmd.Run(ctx, os.Args); err != nil { + appctx, cancel := signal.NotifyContext( + session.WithNewTraceID(context.Background()), + syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGHUP, + ) + defer cancel() + if err := cmd.Run(appctx, os.Args); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } } -func runApp(ctx context.Context, configDir string, cfg *config.Config) { - if !*cfg.General.Silent { +func runApp(appctx context.Context, configDir string, cfg *config.Config) { + if !*cfg.App.Silent { printBanner() } - logging.SetGlobalLogger(ctx, *cfg.General.LogLevel) + logging.SetGlobalLogger(appctx, *cfg.App.LogLevel) - logger := log.Logger.With().Ctx(ctx).Logger() + logger := log.Logger.With().Ctx(appctx).Logger() logger.Info().Str("version", version).Msg("started spoofdpi") if configDir != "" { logger.Info(). @@ -54,35 +60,20 @@ func runApp(ctx context.Context, configDir string, cfg *config.Config) { Msgf("config file loaded") } - // set system-wide proxy configuration. - if !*cfg.General.SetSystemProxy { - logger.Info().Msg("use `--system-proxy` to automatically set system proxy") - } - resolver := createResolver(logger, cfg) - p, err := createProxy(logger, cfg, resolver) + + srv, err := createServer(appctx, logger, cfg, resolver) if err != nil { - logger.Fatal(). - Err(err). - Msg("failed to create proxy") + logger.Fatal().Err(err).Msg("failed to create server") } - // start app - wait := make(chan struct{}) // wait for setup logs to be printed - go p.ListenAndServe(ctx, wait) - - // set system-wide proxy configuration. - if *cfg.General.SetSystemProxy { - port := cfg.Server.ListenAddr.Port - if err := system.SetProxy(logger, uint16(port)); err != nil { - logger.Fatal().Err(err).Msg("failed to enable system proxy") + // Start server + ready := make(chan struct{}) + go func() { + if err := srv.ListenAndServe(appctx, ready); err != nil { + logger.Fatal().Err(err).Msgf("failed to start server: %T", srv) } - defer func() { - if err := system.UnsetProxy(logger); err != nil { - logger.Fatal().Err(err).Msg("failed to disable system proxy") - } - }() - } + }() logger.Info().Msg("dns info") logger.Info().Msgf(" query type '%s'", cfg.DNS.QType.String()) @@ -103,42 +94,60 @@ func runApp(ctx context.Context, configDir string, cfg *config.Config) { Uint8("count", uint8(*cfg.HTTPS.FakeCount)). Msg(" fake") - logger.Info(). - Bool("auto", *cfg.Policy.Auto). - Msgf("policy") - - if *cfg.Server.Timeout > 0 { + if *cfg.Conn.DNSTimeout > 0 { + logger.Info(). + Str("value", fmt.Sprintf("%dms", cfg.Conn.DNSTimeout.Milliseconds())). + Msgf("dns connection timeout") + } + if *cfg.Conn.TCPTimeout > 0 { + logger.Info(). + Str("value", fmt.Sprintf("%dms", cfg.Conn.TCPTimeout.Milliseconds())). + Msgf("tcp connection timeout") + } + if *cfg.Conn.UDPIdleTimeout > 0 { logger.Info(). - Str("value", fmt.Sprintf("%dms", cfg.Server.Timeout)). - Msgf("connection timeout") + Str("value", fmt.Sprintf("%dms", cfg.Conn.UDPIdleTimeout.Milliseconds())). + Msgf("udp idle timeout") } - wait <- struct{}{} + logger.Info().Msgf("app-mode; %s", cfg.App.Mode.String()) - sigs := make(chan os.Signal, 1) - done := make(chan bool, 1) + logger.Info().Msgf("server started on %s", srv.Addr()) - signal.Notify( - sigs, - syscall.SIGINT, - syscall.SIGTERM, - syscall.SIGQUIT, - syscall.SIGHUP) + <-ready - go func() { - <-sigs - done <- true - }() + // System Proxy Config + if *cfg.App.AutoConfigureNetwork { + unset, err := srv.SetNetworkConfig() + if err != nil { + logger.Fatal().Err(err).Msg("failed to set system network config") + } + if unset != nil { + defer func() { + if err := unset(); err != nil { + logger.Error().Err(err).Msg("failed to unset system network config") + } + }() + } + } - <-done + <-appctx.Done() } func createResolver(logger zerolog.Logger, cfg *config.Config) dns.Resolver { // create a TTL cache for storing DNS records. - udpResolver := dns.NewUDPResolver(logging.WithScope(logger, "dns"), cfg.DNS.Clone()) + udpResolver := dns.NewUDPResolver( + logging.WithScope(logger, "dns"), + cfg.DNS.Clone(), + cfg.Conn.Clone(), + ) - dohResolver := dns.NewHTTPSResolver(logging.WithScope(logger, "dns"), cfg.DNS.Clone()) + dohResolver := dns.NewHTTPSResolver( + logging.WithScope(logger, "dns"), + cfg.DNS.Clone(), + cfg.Conn.Clone(), + ) sysResolver := dns.NewSystemResolver( logging.WithScope(logger, "dns"), @@ -147,7 +156,7 @@ func createResolver(logger zerolog.Logger, cfg *config.Config) dns.Resolver { cacheResolver := dns.NewCacheResolver( logging.WithScope(logger, "dns"), - cache.NewTTLCache( + cache.NewTTLCache[string]( cache.TTLCacheAttrs{ NumOfShards: 64, CleanupInterval: time.Duration(3 * time.Minute), @@ -169,14 +178,14 @@ func createResolver(logger zerolog.Logger, cfg *config.Config) dns.Resolver { func createPacketObjects( logger zerolog.Logger, cfg *config.Config, -) (packet.Sniffer, packet.Writer, error) { +) (packet.Sniffer, packet.Writer, packet.Sniffer, packet.Writer, error) { // create a network detector for passive discovery networkDetector := packet.NewNetworkDetector( logging.WithScope(logger, "pkt"), ) if err := networkDetector.Start(context.Background()); err != nil { - return nil, nil, fmt.Errorf("error starting network detector: %w", err) + return nil, nil, nil, nil, fmt.Errorf("error starting network detector: %w", err) } // Wait for gateway MAC with timeout @@ -185,15 +194,27 @@ func createPacketObjects( gatewayMAC, err := networkDetector.WaitForGatewayMAC(ctx) if err != nil { - return nil, nil, fmt.Errorf("failed to detect gateway (timeout): %w", err) + return nil, nil, nil, nil, fmt.Errorf( + "failed to detect gateway (timeout): %w", + err, + ) } iface := networkDetector.GetInterface() // create a pcap handle for packet capturing. - handle, err := packet.NewHandle(iface) + tcpHandle, err := packet.NewHandle(iface) if err != nil { - return nil, nil, fmt.Errorf( + return nil, nil, nil, nil, fmt.Errorf( + "error opening pcap handle on interface %s: %w", + iface.Name, + err, + ) + } + + udpHandle, err := packet.NewHandle(iface) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf( "error opening pcap handle on interface %s: %w", iface.Name, err, @@ -204,34 +225,58 @@ func createPacketObjects( logger.Info().Str("name", iface.Name). Str("mac", iface.HardwareAddr.String()). Msg(" interface") + + gatewayMACStr := gatewayMAC.String() + if gatewayMACStr == "" { + gatewayMACStr = "none" + } logger.Info(). - Str("mac", gatewayMAC.String()). + Str("mac", gatewayMACStr). Msg(" gateway (passive detection)") - hopCache := cache.NewLRUCache(4096) - sniffer := packet.NewTCPSniffer( + hopCache := cache.NewLRUCache[netutil.IPKey](4096, nil) + + // TCP Objects + tcpSniffer := packet.NewTCPSniffer( logging.WithScope(logger, "pkt"), hopCache, - handle, - uint8(*cfg.Server.DefaultTTL), + tcpHandle, + uint8(*cfg.Conn.DefaultFakeTTL), ) - sniffer.StartCapturing() + tcpSniffer.StartCapturing() - writer := packet.NewTCPWriter( + tcpWriter := packet.NewTCPWriter( logging.WithScope(logger, "pkt"), - handle, + tcpHandle, iface, gatewayMAC, ) - return sniffer, writer, nil + // UDP Objects + udpSniffer := packet.NewUDPSniffer( + logging.WithScope(logger, "pkt"), + hopCache, + udpHandle, + uint8(*cfg.Conn.DefaultFakeTTL), + ) + udpSniffer.StartCapturing() + + udpWriter := packet.NewUDPWriter( + logging.WithScope(logger, "pkt"), + udpHandle, + iface, + gatewayMAC, + ) + + return tcpSniffer, tcpWriter, udpSniffer, udpWriter, nil } -func createProxy( +func createServer( + appctx context.Context, logger zerolog.Logger, cfg *config.Config, resolver dns.Resolver, -) (proxy.ProxyServer, error) { +) (server.Server, error) { ruleMatcher := matcher.NewRuleMatcher( matcher.NewAddrMatcher(), matcher.NewDomainMatcher(), @@ -244,51 +289,138 @@ func createProxy( } } - // create an HTTP handler. - httpHandler := http.NewHTTPHandler(logging.WithScope(logger, "hnd")) - - var sniffer packet.Sniffer - var writer packet.Writer + var tcpSniffer packet.Sniffer + var tcpWriter packet.Writer + var udpSniffer packet.Sniffer + var udpWriter packet.Writer if cfg.ShouldEnablePcap() { var err error - sniffer, writer, err = createPacketObjects(logger, cfg) + tcpSniffer, tcpWriter, udpSniffer, udpWriter, err = createPacketObjects( + logger, + cfg, + ) if err != nil { return nil, err } } - httpsHandler := http.NewHTTPSHandler( - logging.WithScope(logger, "hnd"), - desync.NewTLSDesyncer( - writer, - sniffer, - &desync.TLSDesyncerAttrs{DefaultTTL: *cfg.Server.DefaultTTL}, - ), - sniffer, - cfg.HTTPS.Clone(), + desyncer := desync.NewTLSDesyncer( + tcpWriter, + tcpSniffer, ) - // if cfg.Server.EnableSocks5 != nil && *cfg.Server.EnableSocks5 { - // return socks5.NewSocks5Proxy( - // logging.WithScope(logger, "pxy"), - // resolver, - // httpsHandler, - // ruleMatcher, - // cfg.Server.Clone(), - // cfg.Policy.Clone(), - // ), nil - // } - - return http.NewHTTPProxy( - logging.WithScope(logger, "pxy"), - resolver, - httpHandler, - httpsHandler, - ruleMatcher, - cfg.Server.Clone(), - cfg.Policy.Clone(), - ), nil + switch *cfg.App.Mode { + case config.AppModeHTTP: + httpHandler := http.NewHTTPHandler(logging.WithScope(logger, "hnd")) + httpsHandler := http.NewHTTPSHandler( + logging.WithScope(logger, "hnd"), + desyncer, + tcpSniffer, + cfg.HTTPS.Clone(), + cfg.Conn.Clone(), + ) + + return http.NewHTTPProxy( + logging.WithScope(logger, "srv"), + resolver, + httpHandler, + httpsHandler, + ruleMatcher, + cfg.App.Clone(), + cfg.Conn.Clone(), + cfg.Policy.Clone(), + ), nil + case config.AppModeSOCKS5: + connectHandler := socks5.NewConnectHandler( + logging.WithScope(logger, "hnd"), + desyncer, + tcpSniffer, + cfg.App.Clone(), + cfg.Conn.Clone(), + cfg.HTTPS.Clone(), + ) + udpDesyncer := desync.NewUDPDesyncer( + logging.WithScope(logger, "dsn"), + udpWriter, + udpSniffer, + ) + udpPool := netutil.NewConnRegistry[netutil.NATKey](4096, 60*time.Second) + udpPool.RunCleanupLoop(appctx) + udpAssociateHandler := socks5.NewUdpAssociateHandler( + logging.WithScope(logger, "hnd"), + udpPool, + udpDesyncer, + cfg.UDP.Clone(), + ) + bindHandler := socks5.NewBindHandler(logging.WithScope(logger, "hnd")) + + return socks5.NewSOCKS5Proxy( + logging.WithScope(logger, "srv"), + resolver, + ruleMatcher, + connectHandler, + bindHandler, + udpAssociateHandler, + cfg.App.Clone(), + cfg.Conn.Clone(), + cfg.Policy.Clone(), + ), nil + case config.AppModeTUN: + // Find default interface and gateway before modifying routes + defaultIface, defaultGateway, err := netutil.GetDefaultInterfaceAndGateway() + if err != nil { + return nil, fmt.Errorf("failed to get default interface: %w", err) + } + logger.Info(). + Str("interface", defaultIface). + Str("gateway", defaultGateway). + Msg("determined default interface and gateway") + // s.defaultIface = defaultIface + // s.defaultGateway = defaultGateway + + // Update handlers with network info + // s.tcpHandler.SetNetworkInfo(defaultIface, defaultGateway) + // s.udpHandler.SetNetworkInfo(defaultIface, defaultGateway) + // + tcpHandler := tun.NewTCPHandler( + logging.WithScope(logger, "hnd"), + ruleMatcher, // For domain-based TLS matching + cfg.HTTPS.Clone(), + cfg.Conn.Clone(), + desyncer, + tcpSniffer, // For TTL tracking + defaultIface, + defaultGateway, + ) + + udpDesyncer := desync.NewUDPDesyncer( + logging.WithScope(logger, "hnd"), + udpWriter, + udpSniffer, + ) + + udpHandler := tun.NewUDPHandler( + logging.WithScope(logger, "hnd"), + udpDesyncer, + cfg.UDP.Clone(), + cfg.Conn.Clone(), + defaultIface, + defaultGateway, + ) + + return tun.NewTunServer( + logging.WithScope(logger, "srv"), + cfg, + ruleMatcher, // For IP-based matching in server.go + tcpHandler, + udpHandler, + defaultIface, + defaultGateway, + ), nil + default: + return nil, fmt.Errorf("unknown server mode: %s", *cfg.App.Mode) + } } func printBanner() { diff --git a/cmd/spoofdpi/main_test.go b/cmd/spoofdpi/main_test.go index 7076c373..5eb77e47 100644 --- a/cmd/spoofdpi/main_test.go +++ b/cmd/spoofdpi/main_test.go @@ -1,26 +1,32 @@ package main import ( + "context" "net" "testing" "time" "github.com/rs/zerolog" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" ) func TestCreateResolver(t *testing.T) { cfg := config.NewConfig() cfg.DNS = &config.DNSOptions{ - Mode: ptr.FromValue(config.DNSModeUDP), + Mode: lo.ToPtr(config.DNSModeUDP), Addr: &net.TCPAddr{IP: net.ParseIP("8.8.8.8"), Port: 53}, - HTTPSURL: ptr.FromValue("https://dns.google/dns-query"), - QType: ptr.FromValue(config.DNSQueryIPv4), - Cache: ptr.FromValue(true), + HTTPSURL: lo.ToPtr("https://dns.google/dns-query"), + QType: lo.ToPtr(config.DNSQueryIPv4), + Cache: lo.ToPtr(true), + } + cfg.Conn = &config.ConnOptions{ + DNSTimeout: lo.ToPtr(time.Duration(0)), + TCPTimeout: lo.ToPtr(time.Duration(0)), + UDPIdleTimeout: lo.ToPtr(time.Duration(0)), } logger := zerolog.Nop() @@ -30,44 +36,49 @@ func TestCreateResolver(t *testing.T) { } func TestCreateProxy_NoPcap(t *testing.T) { - // Setup configuration that doesn't require PCAP (root privileges) + // Setup configuration that dAppModeHTTP PCAP (root privileges) cfg := config.NewConfig() - // Server Config - cfg.Server = &config.ServerOptions{ + // App Config + cfg.App = &config.AppOptions{ + Mode: lo.ToPtr(config.AppModeHTTP), ListenAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, - DefaultTTL: ptr.FromValue(uint8(64)), - Timeout: ptr.FromValue(time.Duration(1 * time.Second)), + } + + // Conn Config + cfg.Conn = &config.ConnOptions{ + DefaultFakeTTL: lo.ToPtr(uint8(64)), + DNSTimeout: lo.ToPtr(time.Duration(0)), + TCPTimeout: lo.ToPtr(time.Duration(0)), + UDPIdleTimeout: lo.ToPtr(time.Duration(0)), } // HTTPS Config (Ensure FakeCount is 0 to disable PCAP) cfg.HTTPS = &config.HTTPSOptions{ - Disorder: ptr.FromValue(false), - FakeCount: ptr.FromValue(uint8(0)), + Disorder: lo.ToPtr(false), + FakeCount: lo.ToPtr(uint8(0)), FakePacket: proto.NewFakeTLSMessage([]byte{}), - SplitMode: ptr.FromValue(config.HTTPSSplitModeChunk), - ChunkSize: ptr.FromValue(uint8(10)), - Skip: ptr.FromValue(false), + SplitMode: lo.ToPtr(config.HTTPSSplitModeChunk), + ChunkSize: lo.ToPtr(uint8(10)), + Skip: lo.ToPtr(false), } // Policy Config - cfg.Policy = &config.PolicyOptions{ - Auto: ptr.FromValue(false), - } + cfg.Policy = &config.PolicyOptions{} // DNS Config cfg.DNS = &config.DNSOptions{ - Mode: ptr.FromValue(config.DNSModeUDP), + Mode: lo.ToPtr(config.DNSModeUDP), Addr: &net.TCPAddr{IP: net.ParseIP("8.8.8.8"), Port: 53}, - HTTPSURL: ptr.FromValue("https://dns.google/dns-query"), - QType: ptr.FromValue(config.DNSQueryIPv4), - Cache: ptr.FromValue(false), + HTTPSURL: lo.ToPtr("https://dns.google/dns-query"), + QType: lo.ToPtr(config.DNSQueryIPv4), + Cache: lo.ToPtr(false), } logger := zerolog.Nop() resolver := createResolver(logger, cfg) - p, err := createProxy(logger, cfg, resolver) + p, err := createServer(context.Background(), logger, cfg, resolver) require.NoError(t, err) assert.NotNil(t, p) } @@ -75,32 +86,38 @@ func TestCreateProxy_NoPcap(t *testing.T) { func TestCreateProxy_WithPolicy(t *testing.T) { cfg := config.NewConfig() - // Server Config - cfg.Server = &config.ServerOptions{ + // App Config + cfg.App = &config.AppOptions{ + Mode: lo.ToPtr(config.AppModeHTTP), ListenAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}, - DefaultTTL: ptr.FromValue(uint8(64)), - Timeout: ptr.FromValue(time.Duration(0)), + } + + // Conn Config + cfg.Conn = &config.ConnOptions{ + DefaultFakeTTL: lo.ToPtr(uint8(64)), + DNSTimeout: lo.ToPtr(time.Duration(0)), + TCPTimeout: lo.ToPtr(time.Duration(0)), + UDPIdleTimeout: lo.ToPtr(time.Duration(0)), } // HTTPS Config cfg.HTTPS = &config.HTTPSOptions{ - FakeCount: ptr.FromValue(uint8(0)), + FakeCount: lo.ToPtr(uint8(0)), } // Policy Config with one override cfg.Policy = &config.PolicyOptions{ - Auto: ptr.FromValue(false), Overrides: []config.Rule{ { - Name: ptr.FromValue("test-rule"), + Name: lo.ToPtr("test-rule"), Match: &config.MatchAttrs{ Domains: []string{"example.com"}, }, DNS: &config.DNSOptions{ - Mode: ptr.FromValue(config.DNSModeSystem), + Mode: lo.ToPtr(config.DNSModeSystem), }, HTTPS: &config.HTTPSOptions{ - Skip: ptr.FromValue(true), + Skip: lo.ToPtr(true), }, }, }, @@ -108,17 +125,17 @@ func TestCreateProxy_WithPolicy(t *testing.T) { // DNS Config cfg.DNS = &config.DNSOptions{ - Mode: ptr.FromValue(config.DNSModeUDP), + Mode: lo.ToPtr(config.DNSModeUDP), Addr: &net.TCPAddr{IP: net.ParseIP("8.8.8.8"), Port: 53}, - HTTPSURL: ptr.FromValue("https://dns.google/dns-query"), - QType: ptr.FromValue(config.DNSQueryIPv4), - Cache: ptr.FromValue(false), + HTTPSURL: lo.ToPtr("https://dns.google/dns-query"), + QType: lo.ToPtr(config.DNSQueryIPv4), + Cache: lo.ToPtr(false), } logger := zerolog.Nop() resolver := createResolver(logger, cfg) - p, err := createProxy(logger, cfg, resolver) + p, err := createServer(context.Background(), logger, cfg, resolver) require.NoError(t, err) assert.NotNil(t, p) } diff --git a/docs/user-guide/general.md b/docs/user-guide/app.md similarity index 61% rename from docs/user-guide/general.md rename to docs/user-guide/app.md index bb2bbf74..5c6a2932 100644 --- a/docs/user-guide/general.md +++ b/docs/user-guide/app.md @@ -1,6 +1,60 @@ -# General Configuration +# App Configuration -General settings for the application, including logging and system integration. +Application-level settings including mode, logging, and system integration. + +## `app-mode` + +`type: string` + +### Description + +Specifies the proxy mode. `(default: "http")` + +### Allowed Values + +- `http`: HTTP proxy mode +- `socks5`: SOCKS5 proxy mode +- `tun`: TUN interface mode (transparent proxy) + +### Usage + +**Command-Line Flag** +```console +$ spoofdpi --app-mode socks5 +``` + +**TOML Config** +```toml +[app] +mode = "socks5" +``` + +--- + +## `listen-addr` + +`type: ` + +### Description + +Specifies the IP address and port to listen on. `(default: 127.0.0.1:8080 for http, 127.0.0.1:1080 for socks5)` + +If you want to run SpoofDPI remotely (e.g., on a physically separated machine), set the IP part to `0.0.0.0`. Otherwise, it is recommended to leave this option as default for security. + +### Usage + +**Command-Line Flag** +```console +$ spoofdpi --listen-addr "0.0.0.0:8080" +``` + +**TOML Config** +```toml +[app] +listen-addr = "0.0.0.0:8080" +``` + +--- ## `log-level` @@ -21,7 +75,7 @@ $ spoofdpi --log-level trace **TOML Config** ```toml -[general] +[app] log-level = "trace" ``` @@ -44,13 +98,13 @@ $ spoofdpi --silent **TOML Config** ```toml -[general] +[app] silent = true ``` --- -## `system-proxy` +## `auto-configure-network` `type: boolean` @@ -65,13 +119,13 @@ Specifies whether to automatically set up the system-wide proxy configuration. ` **Command-Line Flag** ```console -$ spoofdpi --system-proxy +$ spoofdpi --auto-configure-network ``` **TOML Config** ```toml -[general] -system-proxy = true +[app] +auto-configure-network = true ``` --- diff --git a/docs/user-guide/connection.md b/docs/user-guide/connection.md new file mode 100644 index 00000000..71f7db06 --- /dev/null +++ b/docs/user-guide/connection.md @@ -0,0 +1,106 @@ +# Connection Configuration + +Settings for network connection timeouts and packet configuration. + +## `dns-timeout` + +`type: uint16` + +### Description + +Specifies the timeout (in milliseconds) for DNS connections. `(default: 5000, max: 65535)` + +A value of `0` means no timeout. + +### Usage + +**Command-Line Flag** +```console +$ spoofdpi --dns-timeout 3000 +``` + +**TOML Config** +```toml +[connection] +dns-timeout = 3000 +``` + +--- + +## `tcp-timeout` + +`type: uint16` + +### Description + +Specifies the timeout (in milliseconds) for TCP connections. `(default: 10000, max: 65535)` + +A value of `0` means no timeout. + +### Usage + +**Command-Line Flag** +```console +$ spoofdpi --tcp-timeout 5000 +``` + +**TOML Config** +```toml +[connection] +tcp-timeout = 5000 +``` + +--- + +## `udp-idle-timeout` + +`type: uint16` + +### Description + +Specifies the idle timeout (in milliseconds) for UDP connections. `(default: 25000, max: 65535)` + +The connection will be closed if there is no read/write activity for this duration. Each read or write operation resets the timeout. + +A value of `0` means no timeout. + +### Usage + +**Command-Line Flag** +```console +$ spoofdpi --udp-idle-timeout 30000 +``` + +**TOML Config** +```toml +[connection] +udp-idle-timeout = 30000 +``` + +--- + +## `default-fake-ttl` + +`type: uint8` + +### Description + +Specifies the default [Time To Live (TTL)](https://en.wikipedia.org/wiki/Time_to_live) value for fake packets. `(default: 8)` + +This value is used for fake packets sent during disorder strategies. A lower value ensures fake packets expire before reaching the destination, while the real packets arrive successfully. + +!!! note + The fake TTL should be less than the number of hops to the destination. + +### Usage + +**Command-Line Flag** +```console +$ spoofdpi --default-fake-ttl 10 +``` + +**TOML Config** +```toml +[connection] +default-fake-ttl = 10 +``` diff --git a/docs/user-guide/https.md b/docs/user-guide/https.md index c56386fb..4bcf51c9 100644 --- a/docs/user-guide/https.md +++ b/docs/user-guide/https.md @@ -16,6 +16,7 @@ Specifies the default packet fragmentation strategy to use for the Client Hello - `random`: Splits the packet at a random position. - `chunk`: Splits the packet into fixed-size chunks (controlled by `https-chunk-size`). - `first-byte`: Splits only the first byte of the packet. +- `custom`: Uses custom segment plans (TOML only). - `none`: Disables fragmentation. ### Usage @@ -39,7 +40,7 @@ split-mode = "sni" ### Description -Specifies the chunk size in bytes for packet fragmentation. `(default: 0, max: 255)` +Specifies the chunk size in bytes for packet fragmentation. `(default: 35, max: 255)` This value is only applied when `https-split-mode` is set to `chunk`. Try lower values if the default fails to bypass the DPI. @@ -129,10 +130,10 @@ The value should be a sequence of bytes representing a valid (or semi-valid) TLS ### Usage **Command-Line Flag** -Provide a comma-separated string of hexadecimal bytes (e.g., `16,03,01,...`). +Provide a comma-separated string of hexadecimal bytes (e.g., `0x16, 0x03, 0x01, ...`). ```console -$ spoofdpi --https-fake-packet "16,03,01,00,a1,..." +$ spoofdpi --https-fake-packet "0x16, 0x03, 0x01, 0x00, 0xa1, ..." ``` **TOML Config** @@ -165,3 +166,45 @@ $ spoofdpi --https-skip [https] skip = true ``` + +--- + +## `custom-segments` + +`type: array of segment plans` + +### Description + +Defines custom segmentation plans for fine-grained control over how the Client Hello packet is split. This option is only used when `split-mode` is set to `"custom"`. + +Each segment plan specifies where to split the packet relative to a reference point (`from`) with an offset (`at`). + +!!! note + This option is only available via the TOML config file. + +!!! important + When using `custom` split-mode, the global `disorder` option is **ignored**. Use the `lazy` field in each segment plan to control packet ordering instead. + +### Segment Plan Fields + +| Field | Type | Required | Description | +| :------ | :------ | :------- | :---------- | +| `from` | String | Yes | Reference point: `"head"` (start of packet) or `"sni"` (start of SNI extension). | +| `at` | Int | Yes | Byte offset from the reference point. Negative values count backwards. | +| `lazy` | Boolean | No | If `true`, delays sending this segment (equivalent to disorder). `(default: false)` | +| `noise` | Int | No | Adds random noise (in bytes) to the split position. `(default: 0)` | + +### Usage + +**TOML Config** + +```toml +[https] +split-mode = "custom" +custom-segments = [ + { from = "head", at = 2 }, # Split 2 bytes from start + { from = "sni", at = 0 }, # Split at SNI start + { from = "sni", at = -5, lazy = true }, # Split 5 bytes before SNI, delayed + { from = "head", at = 100, noise = 10 }, # Split at ~100 bytes with ±10 noise +] +``` diff --git a/docs/user-guide/overview.md b/docs/user-guide/overview.md index b6cac093..2122d9db 100644 --- a/docs/user-guide/overview.md +++ b/docs/user-guide/overview.md @@ -25,14 +25,15 @@ If a specific path is not provided via a `--config` flag, SpoofDPI will search f ## Options -The configuration is organized into five main categories. Click on each category to view detailed options. +The configuration is organized into six main categories. Click on each category to view detailed options. | Category | Description | | :--- | :--- | -| **[General](general.md)** | General application options (logging, system proxy, etc.). | -| **[Server](server.md)** | Server connection options (address, timeout). | +| **[App](app.md)** | Application-level options (mode, address, logging, etc.). | +| **[Connection](connection.md)** | Connection timeout and packet TTL settings. | | **[DNS](dns.md)** | DNS resolution options. | | **[HTTPS](https.md)** | HTTPS/TLS packet manipulation options. | +| **[UDP](udp.md)** | UDP packet manipulation options. | | **[Policy](policy.md)** | Rule-based routing and automatic bypass policies. | ## Example @@ -44,7 +45,7 @@ The following two methods will achieve the exact same configuration. All settings are passed directly via flags. ```console -$ spoofdpi --dns-addr "1.1.1.1:53" --dns-https-url "https://dns.google/dns-query" --dns-mode "https" +$ spoofdpi --app-mode socks5 --dns-mode https --https-disorder ``` ### Method 2: Using a TOML Config File @@ -52,10 +53,14 @@ $ spoofdpi --dns-addr "1.1.1.1:53" --dns-https-url "https://dns.google/dns-query Place the settings in your `spoofdpi.toml` file: ```toml +[app] +mode = "socks5" + [dns] - addr = "1.1.1.1:53" - https-url = "https://dns.google/dns-query" - mode = "https" +mode = "https" + +[https] +disorder = true ``` Then, run spoofdpi without those flags (it will automatically load the file if placed in a standard path): diff --git a/docs/user-guide/policy.md b/docs/user-guide/policy.md index 09111529..1b6a52cf 100644 --- a/docs/user-guide/policy.md +++ b/docs/user-guide/policy.md @@ -2,51 +2,24 @@ By defining rules within the Policy section, you can granularly control how SpoofDPI handles connections to specific domains or IP addresses. You can define per-domain bypass strategies, DNS settings, or simply block connections. -## `auto` - -`type: boolean` - -### Description - -Automatically detect blocked sites and add them to the bypass list. `(default: false)` - -When enabled, SpoofDPI attempts to detect if a connection is being blocked and temporarily applies bypass rules for that destination. These generated rules utilize the configuration defined in `[policy.template]`. - -### Usage - -**Command-Line Flag** -```console -$ spoofdpi --policy-auto -``` - -**TOML Config** -```toml -[policy] -auto = true -``` - ---- - ## `template` -The `[policy.template]` section defines the default behavior for rules automatically generated when `auto = true`. If you enable automatic detection, you should configure this template to ensure the generated rules effectively bypass the DPI. +The `[policy.template]` section defines a base rule configuration. This template can be cloned and customized when programmatically adding rules. !!! note The template configuration is only available via the TOML config file. ### Structure -The template uses the same `Rule` structure as overrides, but typically only the `https` and `dns` sections are relevant, as the `match` criteria are determined dynamically. +The template uses the same `Rule` structure as overrides, but typically only the `https` and `dns` sections are relevant. ### Example ```toml [policy] - auto = true - - # This configuration is applied to automatically detected blocked sites [policy.template] https = { fake-count = 7, disorder = true } + dns = { mode = "https" } ``` --- diff --git a/docs/user-guide/server.md b/docs/user-guide/server.md deleted file mode 100644 index 8bdb6590..00000000 --- a/docs/user-guide/server.md +++ /dev/null @@ -1,79 +0,0 @@ -# Server Configuration - -Settings related to the proxy server connection and listener. - -## `listen-addr` - -`type: ` - -### Description - -Specifies the IP address and port to listen on. `(default: 127.0.0.1:8080)` - -If you want to run SpoofDPI remotely (e.g., on a physically separated machine), set the IP part to `0.0.0.0`. Otherwise, it is recommended to leave this option as default for security. - -### Usage - -**Command-Line Flag** -```console -$ spoofdpi --listen-addr "0.0.0.0:8080" -``` - -**TOML Config** -```toml -[server] -listen-addr = "0.0.0.0:8080" -``` - ---- - -## `timeout` - -`type: uint16` - -### Description - -Specifies the timeout (in milliseconds) for every TCP connection. `(default: 0, max: 65535)` - -A value of `0` means no timeout. You can set this option if you know what you are doing, but in most cases, leaving this option unset is recommended. - -### Usage - -**Command-Line Flag** -```console -$ spoofdpi --timeout 5000 -``` - -**TOML Config** -```toml -[server] -timeout = 5000 -``` - ---- - -## `default-ttl` - -`type: uint8` - -### Description - -Specifies the default [Time To Live (TTL)](https://en.wikipedia.org/wiki/Time_to_live) value for outgoing packets. `(default: 64)` - -This value is used to restore the TTL to its default state after applying disorder strategies. Changing this option is generally not required. - -!!! note - The default TTL value for macOS and Linux is usually `64`. - -### Usage - -**Command-Line Flag** -```console -$ spoofdpi --default-ttl 128 -``` - -**TOML Config** -```toml -[server] -default-ttl = 128 -``` diff --git a/docs/user-guide/udp.md b/docs/user-guide/udp.md new file mode 100644 index 00000000..c5f776c9 --- /dev/null +++ b/docs/user-guide/udp.md @@ -0,0 +1,58 @@ +# UDP Configuration + +Settings for UDP packet manipulation and bypass techniques. + +## `fake-count` + +`type: int` + +### Description + +Specifies the number of fake packets to be sent before actual UDP packets. `(default: 0)` + +Sending fake packets can trick DPI systems into inspecting invalid traffic, allowing real packets to pass through. + +!!! note + This feature requires root privileges and packet capture capabilities. + +### Usage + +**Command-Line Flag** +```console +$ spoofdpi --udp-fake-count 5 +``` + +**TOML Config** +```toml +[udp] +fake-count = 5 +``` + +--- + +## `fake-packet` + +`type: byte array` + +### Description + +Customizes the content of the fake packets used by `udp-fake-count`. `(default: 64 bytes of zeros)` + +The value should be a sequence of bytes representing the fake packet data. + +### Usage + +**Command-Line Flag** +Provide a comma-separated string of hexadecimal bytes (e.g., `0x00, 0x01, 0x02, ...`). + +```console +$ spoofdpi --udp-fake-packet "0x00, 0x01, 0x02, 0x03, 0x04" +``` + +**TOML Config** +Provide an array of integers (bytes). + +```toml +[udp] +fake-packet = [0x00, 0x01, 0x02, 0x03, 0x04] +``` diff --git a/go.mod b/go.mod index c6243528..33d82d80 100644 --- a/go.mod +++ b/go.mod @@ -9,21 +9,27 @@ require ( github.com/google/gopacket v1.1.19 github.com/miekg/dns v1.1.61 github.com/rs/zerolog v1.33.0 + github.com/samber/lo v1.52.0 + github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stretchr/testify v1.11.1 github.com/urfave/cli/v3 v3.6.1 + golang.org/x/net v0.44.0 golang.org/x/sys v0.36.0 + gvisor.dev/gvisor v0.0.0-20251220000015-517913d17844 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/btree v1.1.2 // indirect github.com/kr/pretty v0.1.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/mod v0.27.0 // indirect - golang.org/x/net v0.43.0 // indirect + golang.org/x/mod v0.28.0 // indirect golang.org/x/sync v0.17.0 // indirect - golang.org/x/tools v0.36.0 // indirect + golang.org/x/text v0.29.0 // indirect + golang.org/x/time v0.12.0 // indirect + golang.org/x/tools v0.37.0 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 1c71dc55..002d7467 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,10 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= +github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= @@ -26,6 +28,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw= +github.com/samber/lo v1.52.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0= +github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= +github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/urfave/cli/v3 v3.6.1 h1:j8Qq8NyUawj/7rTYdBGrxcH7A/j7/G8Q5LhWEW4G3Mo= @@ -34,12 +40,12 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= -golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= +golang.org/x/mod v0.28.0 h1:gQBtGhjxykdjY9YhZpSlZIsbnaE2+PgjfLWUQTnoZ1U= +golang.org/x/mod v0.28.0/go.mod h1:yfB/L0NOf/kmEbXjzCPOx1iK1fRutOydrCMsqRhEBxI= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= -golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I= +golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= @@ -51,12 +57,18 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= +golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= -golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= +golang.org/x/tools v0.37.0 h1:DVSRzp7FwePZW356yEAChSdNcQo6Nsp+fex1SUW09lE= +golang.org/x/tools v0.37.0/go.mod h1:MBN5QPQtLMHVdvsbtarmTNukZDdgwdwlO5qGacAzF0w= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gvisor.dev/gvisor v0.0.0-20251220000015-517913d17844 h1:7SkRScnij3eBOP12JnH9KnIfAcpPFMWy9KxNHjOXGTM= +gvisor.dev/gvisor v0.0.0-20251220000015-517913d17844/go.mod h1:W1ZgZ/Dh85TgSZWH67l2jKVpDE5bjIaut7rjwwOiHzQ= diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 965e204f..1d816b89 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -33,9 +33,15 @@ func (o *options) WithSkipExisting(skipExisting bool) *options { // Cache is the unified interface for all cache implementations. // The Set method accepts a variadic list of options. -type Cache interface { - // Get retrieves a value from the cache. - Get(key string) (any, bool) - // Set adds a value to the cache, applying any provided options. - Set(key string, value any, opts *options) bool +type Cache[K comparable] interface { + // Fetch retrieves a value from the cache. + Fetch(key K) (any, bool) + // Store adds a value to the cache, applying any provided options. + Store(key K, value any, opts *options) bool + Evict(key K) + Has(key K) bool + // ForEach iterates over the cache items. + ForEach(f func(key K, value any) error) error + // Size returns the number of items in the cache. + Size() int } diff --git a/internal/cache/lru_cache.go b/internal/cache/lru_cache.go index b58cfa95..e7323c7d 100644 --- a/internal/cache/lru_cache.go +++ b/internal/cache/lru_cache.go @@ -5,66 +5,73 @@ import ( "sync" ) -var _ Cache = (*LRUCache)(nil) +var _ Cache[string] = (*LRUCache[string])(nil) // lruEntry represents the value stored in the cache and the linked list node. -type lruEntry struct { - key string +type lruEntry[K comparable] struct { + key K value any // expiry time.Time field removed } // LRUCache is a concurrent, fixed-size cache with an LRU eviction policy. -type LRUCache struct { +type LRUCache[K comparable] struct { capacity int - mu sync.RWMutex + mu sync.Mutex // list is a doubly linked list used for tracking access order. // Front is Most Recently Used (MRU), Back is Least Recently Used (LRU). list *list.List // cache maps the key to the list element (*list.Element) which holds the lruEntry. - cache map[string]*list.Element + cache map[K]*list.Element + + onInvalidate func(key K, value any) } -// NewLRUCache creates a new LRU Cache instance with the given capacity. +// NewLRUCache creates a new LRU Cache instance. // Capacity must be greater than zero. -func NewLRUCache(capacity int) Cache { +func NewLRUCache[K comparable]( + capacity int, + onInvalidate func(key K, value any), +) Cache[K] { if capacity <= 0 { // Default to a sensible minimum capacity if input is invalid capacity = 100 } - return &LRUCache{ - capacity: capacity, - list: list.New(), - cache: make(map[string]*list.Element, capacity), + return &LRUCache[K]{ + capacity: capacity, + list: list.New(), + cache: make(map[K]*list.Element, capacity), + onInvalidate: onInvalidate, } } -// isExpired function removed - // evictOldest removes the least recently used item from the cache. -func (c *LRUCache) evictOldest() { +func (c *LRUCache[K]) evictOldest() { // Element is at the back of the list (LRU) tail := c.list.Back() if tail != nil { - c.removeElement(tail) + c.removeByElement(tail) } } -// removeElement removes a specific list element from both the list and the map. -func (c *LRUCache) removeElement(e *list.Element) { +// removeByElement removes a specific list element from both the list and the map. +func (c *LRUCache[K]) removeByElement(e *list.Element) { c.list.Remove(e) - entry := e.Value.(*lruEntry) + entry := e.Value.(*lruEntry[K]) delete(c.cache, entry.key) + if c.onInvalidate != nil { + c.onInvalidate(entry.key, entry.value) + } } -// Get retrieves a value from the cache. +// Fetch retrieves a value from the cache. // If found, the item is promoted to Most Recently Used (MRU). -func (c *LRUCache) Get(key string) (any, bool) { - // Use RLock for concurrent safe read operations - c.mu.RLock() - defer c.mu.RUnlock() +func (c *LRUCache[K]) Fetch(key K) (any, bool) { + // Use Lock since MoveToFront modifies the linked list + c.mu.Lock() + defer c.mu.Unlock() // Check if key exists in the map if element, ok := c.cache[key]; ok { @@ -73,15 +80,15 @@ func (c *LRUCache) Get(key string) (any, bool) { // Promote to Most Recently Used (MRU) c.list.MoveToFront(element) - entry := element.Value.(*lruEntry) + entry := element.Value.(*lruEntry[K]) return entry.value, true } return nil, false } -// Set adds a value to the cache, applying any provided options. -func (c *LRUCache) Set(key string, value any, opts *options) bool { +// Store adds a value to the cache, applying any provided options. +func (c *LRUCache[K]) Store(key K, value any, opts *options) bool { // Use Write Lock for modification operations c.mu.Lock() defer c.mu.Unlock() @@ -99,8 +106,12 @@ func (c *LRUCache) Set(key string, value any, opts *options) bool { return false } - if ok { - entry := element.Value.(*lruEntry) + if ok { // Key already exists + entry := element.Value.(*lruEntry[K]) + if c.onInvalidate != nil { + // Invoke onInvalidate to ensure associated resources are properly released. + c.onInvalidate(entry.key, entry.value) + } entry.value = value c.list.MoveToFront(element) @@ -108,7 +119,7 @@ func (c *LRUCache) Set(key string, value any, opts *options) bool { } // Key is new: Create a new entry - entry := &lruEntry{ + entry := &lruEntry[K]{ key: key, value: value, } @@ -124,3 +135,44 @@ func (c *LRUCache) Set(key string, value any, opts *options) bool { return true } + +// ForEach iterates over the cache items. +func (c *LRUCache[K]) ForEach(f func(key K, value any) error) error { + c.mu.Lock() + defer c.mu.Unlock() + + var next *list.Element + for e := c.list.Front(); e != nil; e = next { + next = e.Next() + entry := e.Value.(*lruEntry[K]) + if err := f(entry.key, entry.value); err != nil { + return err + } + } + return nil +} + +// Evict removes an item from the cache. +func (c *LRUCache[K]) Evict(key K) { + c.mu.Lock() + defer c.mu.Unlock() + + if element, ok := c.cache[key]; ok { + c.removeByElement(element) + } +} + +// Has checks if an item exists in the cache without moving its MRU status. +func (c *LRUCache[K]) Has(key K) bool { + c.mu.Lock() + defer c.mu.Unlock() + _, ok := c.cache[key] + return ok +} + +// Size returns the number of items in the cache. +func (c *LRUCache[K]) Size() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.list.Len() +} diff --git a/internal/cache/ttl_cache.go b/internal/cache/ttl_cache.go index 70d7b10c..b4a5678d 100644 --- a/internal/cache/ttl_cache.go +++ b/internal/cache/ttl_cache.go @@ -7,16 +7,16 @@ import ( "time" ) -var _ Cache = (*TTLCache)(nil) +var _ Cache[string] = (*TTLCache[string])(nil) // ttlCacheItem represents a single cached item using generics. -type ttlCacheItem struct { +type ttlCacheItem[K comparable] struct { value any // The cached data of type T. expiresAt time.Time // The time when the item expires. } // isExpired checks if the item has expired. -func (i ttlCacheItem) isExpired() bool { +func (i ttlCacheItem[K]) isExpired() bool { if i.expiresAt.IsZero() { // zero time means no expiration. return false @@ -25,40 +25,43 @@ func (i ttlCacheItem) isExpired() bool { } // ttlCacheShard represents a single, thread-safe shard of the cache. -type ttlCacheShard struct { - items map[string]ttlCacheItem // items holds the cache data for this shard. +type ttlCacheShard[K comparable] struct { + items map[K]ttlCacheItem[K] // items holds the cache data for this shard. mu sync.RWMutex } type TTLCacheAttrs struct { NumOfShards uint8 CleanupInterval time.Duration + HashFunc func(key any) uint64 } // TTLCache is a high-performance, sharded, generic TTL cache. -type TTLCache struct { - shards []*ttlCacheShard // A slice of cache shards. +type TTLCache[K comparable] struct { + shards []*ttlCacheShard[K] // A slice of cache shards. + hashFunc func(key any) uint64 } // NewTTLCache creates a new sharded TTL cache with a background janitor goroutine. // numShards specifies the number of shards to create and must be greater than 0. // cleanupInterval specifies how often the janitor should run. -func NewTTLCache( +func NewTTLCache[K comparable]( attrs TTLCacheAttrs, -) *TTLCache { +) *TTLCache[K] { if attrs.NumOfShards == 0 { panic( fmt.Errorf("number of shards must be greater than 0, got %d", attrs.NumOfShards), ) } - c := &TTLCache{ - shards: make([]*ttlCacheShard, attrs.NumOfShards), + c := &TTLCache[K]{ + shards: make([]*ttlCacheShard[K], attrs.NumOfShards), + hashFunc: attrs.HashFunc, } for i := range attrs.NumOfShards { - c.shards[i] = &ttlCacheShard{ - items: make(map[string]ttlCacheItem), + c.shards[i] = &ttlCacheShard[K]{ + items: make(map[K]ttlCacheItem[K]), } } @@ -70,7 +73,7 @@ func NewTTLCache( } // janitor runs the cleanup goroutine, calling ForceCleanup at the specified interval. -func (c *TTLCache) janitor(interval time.Duration) { +func (c *TTLCache[K]) janitor(interval time.Duration) { ticker := time.NewTicker(interval) defer ticker.Stop() for range ticker.C { @@ -79,9 +82,22 @@ func (c *TTLCache) janitor(interval time.Duration) { } // getShard maps a key to its corresponding cache shard using a hash function. -func (c *TTLCache) getShard(key string) *ttlCacheShard { +func (c *TTLCache[K]) getShard(key K) *ttlCacheShard[K] { + if c.hashFunc != nil { + hash := c.hashFunc(key) + return c.shards[hash%uint64(len(c.shards))] + } + hasher := fnv.New64a() - hasher.Write([]byte(key)) + // Optimally hash the memory without string allocation + switch v := any(key).(type) { + case string: + hasher.Write([]byte(v)) + case []byte: + hasher.Write(v) + default: + _, _ = fmt.Fprint(hasher, key) + } hash := hasher.Sum64() return c.shards[hash%uint64(len(c.shards))] } @@ -89,9 +105,9 @@ func (c *TTLCache) getShard(key string) *ttlCacheShard { // ┌─────────────┐ // │ PUBLIC APIs │ // └─────────────┘ -// Set adds an item to the cache, replacing any existing item. +// Store adds an item to the cache, replacing any existing item. // If ttl is 0 or negative, the item will never expire (passive-only). -func (c *TTLCache) Set(key string, value any, opts *options) bool { +func (c *TTLCache[K]) Store(key K, value any, opts *options) bool { shard := c.getShard(key) shard.mu.Lock() @@ -115,7 +131,7 @@ func (c *TTLCache) Set(key string, value any, opts *options) bool { } expiresAt := time.Now().Add(opts.ttl) - newItem := ttlCacheItem{ + newItem := ttlCacheItem[K]{ value: value, expiresAt: expiresAt, } @@ -124,10 +140,10 @@ func (c *TTLCache) Set(key string, value any, opts *options) bool { return true } -// Get retrieves an item from the cache. +// Fetch retrieves an item from the cache. // It returns the item (of type T) and true if found and not expired. // Otherwise, it returns the zero value of T and false. -func (c *TTLCache) Get(key string) (any, bool) { +func (c *TTLCache[K]) Fetch(key K) (any, bool) { shard := c.getShard(key) shard.mu.RLock() i, ok := shard.items[key] @@ -157,17 +173,31 @@ func (c *TTLCache) Get(key string) (any, bool) { return i.value, true } -// Delete removes an item from the cache. -func (c *TTLCache) Delete(key string) { +// Evict removes an item from the cache. +func (c *TTLCache[K]) Evict(key K) { shard := c.getShard(key) shard.mu.Lock() delete(shard.items, key) shard.mu.Unlock() } +// Has checks if an item exists in the cache and is not expired. +func (c *TTLCache[K]) Has(key K) bool { + shard := c.getShard(key) + shard.mu.RLock() + i, ok := shard.items[key] + shard.mu.RUnlock() + + if !ok { + return false + } + + return !i.isExpired() +} + // ForceCleanup actively scans all shards and deletes expired items. // This is called periodically by the janitor but can also be called manually. -func (c *TTLCache) ForceCleanup() { +func (c *TTLCache[K]) ForceCleanup() { now := time.Now() for _, shard := range c.shards { shard.mu.Lock() @@ -179,3 +209,29 @@ func (c *TTLCache) ForceCleanup() { shard.mu.Unlock() } } + +// ForEach iterates over the cache items. +func (c *TTLCache[K]) ForEach(f func(key K, value any) error) error { + for _, shard := range c.shards { + shard.mu.RLock() + for key, i := range shard.items { // Pre-allocate values to avoid holding RLock unnecessarily? For now, keep simple. + if err := f(key, i.value); err != nil { + shard.mu.RUnlock() + return err + } + } + shard.mu.RUnlock() + } + return nil +} + +// Size returns the total number of items across all shards. +func (c *TTLCache[K]) Size() int { + total := 0 + for _, shard := range c.shards { + shard.mu.RLock() + total += len(shard.items) + shard.mu.RUnlock() + } + return total +} diff --git a/internal/config/cli.go b/internal/config/cli.go index c776e513..182830f5 100644 --- a/internal/config/cli.go +++ b/internal/config/cli.go @@ -5,14 +5,15 @@ import ( "fmt" "io" "math" + "net" "os" "path" "strings" "time" + "github.com/samber/lo" "github.com/urfave/cli/v3" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" ) func CreateCommand( @@ -35,10 +36,24 @@ func CreateCommand( return err }, Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "app-mode", + Usage: fmt.Sprintf(`<"http"|"socks5"|"tun"> + Specifies the proxy mode. (default: %q)`, + defaultCfg.App.Mode.String(), + ), + OnlyOnce: true, + Validator: checkAppMode, + Action: func(ctx context.Context, cmd *cli.Command, v string) error { + argsCfg.App.Mode = lo.ToPtr(MustParseServerModeType(v)) + return nil + }, + }, + &cli.BoolFlag{ Name: "clean", Usage: ` - if set, all configuration files will be ignored (default: %v)`, + if set, all configuration files will be ignored (default: false)`, OnlyOnce: true, }, @@ -53,15 +68,15 @@ func CreateCommand( }, &cli.Int64Flag{ - Name: "default-ttl", + Name: "default-fake-ttl", Usage: fmt.Sprintf(` - Default TTL value for manipulated packets. (default: %v)`, - *defaultCfg.Server.DefaultTTL, + Default TTL value for fake packets. (default: %v)`, + *defaultCfg.Conn.DefaultFakeTTL, ), OnlyOnce: true, Validator: checkUint8NonZero, Action: func(ctx context.Context, cmd *cli.Command, v int64) error { - argsCfg.Server.DefaultTTL = ptr.FromValue(uint8(v)) + argsCfg.Conn.DefaultFakeTTL = lo.ToPtr(uint8(v)) return nil }, }, @@ -75,7 +90,7 @@ func CreateCommand( OnlyOnce: true, Validator: checkHostPort, Action: func(ctx context.Context, cmd *cli.Command, v string) error { - argsCfg.DNS.Addr = ptr.FromValue(MustParseTCPAddr(v)) + argsCfg.DNS.Addr = lo.ToPtr(MustParseTCPAddr(v)) return nil }, }, @@ -84,12 +99,12 @@ func CreateCommand( Name: "dns-cache", Usage: fmt.Sprintf(` If set, DNS records will be cached. (default: %v)`, - defaultCfg.DNS.Cache, + *defaultCfg.DNS.Cache, ), Value: false, OnlyOnce: true, Action: func(ctx context.Context, cmd *cli.Command, v bool) error { - argsCfg.DNS.Cache = ptr.FromValue(v) + argsCfg.DNS.Cache = lo.ToPtr(v) return nil }, }, @@ -105,7 +120,7 @@ func CreateCommand( OnlyOnce: true, Validator: checkDNSMode, Action: func(ctx context.Context, cmd *cli.Command, v string) error { - argsCfg.DNS.Mode = ptr.FromValue(MustParseDNSModeType(v)) + argsCfg.DNS.Mode = lo.ToPtr(MustParseDNSModeType(v)) return nil }, }, @@ -121,7 +136,7 @@ func CreateCommand( OnlyOnce: true, Validator: checkHTTPSEndpoint, Action: func(ctx context.Context, cmd *cli.Command, v string) error { - argsCfg.DNS.HTTPSURL = ptr.FromValue(v) + argsCfg.DNS.HTTPSURL = lo.ToPtr(v) return nil }, }, @@ -137,7 +152,26 @@ func CreateCommand( OnlyOnce: true, Validator: checkDNSQueryType, Action: func(ctx context.Context, cmd *cli.Command, v string) error { - argsCfg.DNS.QType = ptr.FromValue(MustParseDNSQueryType(v)) + argsCfg.DNS.QType = lo.ToPtr(MustParseDNSQueryType(v)) + return nil + }, + }, + + &cli.Int64Flag{ + Name: "dns-timeout", + Usage: fmt.Sprintf(` + Timeout for dns connection in milliseconds. + No effect when the value is 0 (default: %v, max: %v)`, + defaultCfg.Conn.DNSTimeout.Milliseconds(), + math.MaxUint16, + ), + Value: 0, + OnlyOnce: true, + Validator: checkUint16, + Action: func(ctx context.Context, cmd *cli.Command, v int64) error { + argsCfg.Conn.DNSTimeout = lo.ToPtr( + time.Duration(v * int64(time.Millisecond)), + ) return nil }, }, @@ -147,13 +181,13 @@ func CreateCommand( Usage: fmt.Sprintf(` Number of fake packets to be sent before the Client Hello. Requires 'https-chunk-size' > 0 for fragmentation. (default: %v)`, - defaultCfg.HTTPS.FakeCount, + *defaultCfg.HTTPS.FakeCount, ), Value: 0, OnlyOnce: true, Validator: checkUint8, Action: func(ctx context.Context, cmd *cli.Command, v int64) error { - argsCfg.HTTPS.FakeCount = ptr.FromValue(uint8(v)) + argsCfg.HTTPS.FakeCount = lo.ToPtr(uint8(v)) return nil }, }, @@ -176,18 +210,18 @@ func CreateCommand( Name: "https-disorder", Usage: fmt.Sprintf(` If set, sends fragmented Client Hello packets out-of-order. (default: %v)`, - defaultCfg.HTTPS.Disorder, + *defaultCfg.HTTPS.Disorder, ), OnlyOnce: true, Action: func(ctx context.Context, cmd *cli.Command, v bool) error { - argsCfg.HTTPS.Disorder = ptr.FromValue(v) + argsCfg.HTTPS.Disorder = lo.ToPtr(v) return nil }, }, &cli.StringFlag{ Name: "https-split-mode", - Usage: fmt.Sprintf(`<"sni"|"random"|"chunk"|"sni"|"none"> + Usage: fmt.Sprintf(`<"sni"|"random"|"chunk"|"sni"|"custom"|"none"> Specifies the default packet fragmentation strategy to use. (default: %q)`, defaultCfg.HTTPS.SplitMode.String(), ), @@ -195,7 +229,7 @@ func CreateCommand( OnlyOnce: true, Validator: checkHTTPSSplitMode, Action: func(ctx context.Context, cmd *cli.Command, v string) error { - argsCfg.HTTPS.SplitMode = ptr.FromValue(mustParseHTTPSSplitModeType(v)) + argsCfg.HTTPS.SplitMode = lo.ToPtr(mustParseHTTPSSplitModeType(v)) return nil }, }, @@ -205,11 +239,11 @@ func CreateCommand( Usage: fmt.Sprintf(` If set, HTTPS traffic will be processed without any DPI bypass techniques. (default: %v)`, - defaultCfg.HTTPS.Skip, + *defaultCfg.HTTPS.Skip, ), OnlyOnce: true, Action: func(ctx context.Context, cmd *cli.Command, v bool) error { - argsCfg.HTTPS.Skip = ptr.FromValue(v) + argsCfg.HTTPS.Skip = lo.ToPtr(v) return nil }, }, @@ -222,29 +256,76 @@ func CreateCommand( disables fragmentation (to avoid division-by-zero errors), you should set 'https-split-default' to 'none' to disable the feature cleanly. (default: %v, max: %v)`, - defaultCfg.HTTPS.ChunkSize, + *defaultCfg.HTTPS.ChunkSize, math.MaxUint8, ), - Value: 0, OnlyOnce: true, Validator: checkUint8NonZero, Action: func(ctx context.Context, cmd *cli.Command, v int64) error { - argsCfg.HTTPS.ChunkSize = ptr.FromValue(uint8(v)) + argsCfg.HTTPS.ChunkSize = lo.ToPtr(uint8(v)) + return nil + }, + }, + + &cli.Int64Flag{ + Name: "udp-fake-count", + Usage: fmt.Sprintf(` + Number of fake packets to be sent. (default: %v)`, + *defaultCfg.UDP.FakeCount, + ), + Value: 0, + OnlyOnce: true, + Validator: int64Range(0, math.MaxInt), + Action: func(ctx context.Context, cmd *cli.Command, v int64) error { + argsCfg.UDP.FakeCount = lo.ToPtr(int(v)) return nil }, }, &cli.StringFlag{ - Name: "listen-addr", + Name: "udp-fake-packet", + Usage: ` + Comma-separated hexadecimal byte array used for fake packet. + (default: built-in fake packet)`, + Value: MustParseHexCSV(defaultCfg.UDP.FakePacket), + OnlyOnce: true, + Validator: checkHexBytesStr, + Action: func(ctx context.Context, cmd *cli.Command, v string) error { + argsCfg.UDP.FakePacket = MustParseBytes(v) + return nil + }, + }, + + &cli.Int64Flag{ + Name: "udp-idle-timeout", Usage: fmt.Sprintf(` - IP address to listen on (default: %v)`, - defaultCfg.Server.ListenAddr.String(), + Idle timeout for udp connection in milliseconds. + No effect when the value is 0 (default: %v, max: %v)`, + defaultCfg.Conn.UDPIdleTimeout.Milliseconds(), + math.MaxUint16, ), - Value: "127.0.0.1:8080", + Value: 0, + OnlyOnce: true, + Validator: checkUint16, + Action: func(ctx context.Context, cmd *cli.Command, v int64) error { + argsCfg.Conn.UDPIdleTimeout = lo.ToPtr( + time.Duration(v * int64(time.Millisecond)), + ) + return nil + }, + }, + + &cli.StringFlag{ + Name: "listen-addr", + Usage: ` + IP address to listen on (default: 127.0.0.1:8080 for http, or 127.0.0.1:1080 for socks5)`, OnlyOnce: true, Validator: checkHostPort, Action: func(ctx context.Context, cmd *cli.Command, v string) error { - argsCfg.Server.ListenAddr = ptr.FromValue(MustParseTCPAddr(v)) + if v == "" { + return nil + } + argsCfg.App.ListenAddr = lo.ToPtr(MustParseTCPAddr(v)) return nil }, }, @@ -256,20 +337,7 @@ func CreateCommand( OnlyOnce: true, Validator: checkLogLevel, Action: func(ctx context.Context, cmd *cli.Command, v string) error { - argsCfg.General.LogLevel = ptr.FromValue(MustParseLogLevel(v)) - return nil - }, - }, - - &cli.BoolFlag{ - Name: "policy-auto", - Usage: fmt.Sprintf(` - Automatically detect the blocked sites and add policies (default: %v)`, - *defaultCfg.Policy.Auto, - ), - OnlyOnce: true, - Action: func(ctx context.Context, cmd *cli.Command, v bool) error { - argsCfg.Policy.Auto = ptr.FromValue(v) + argsCfg.App.LogLevel = lo.ToPtr(MustParseLogLevel(v)) return nil }, }, @@ -278,41 +346,41 @@ func CreateCommand( Name: "silent", Usage: fmt.Sprintf(` Do not show the banner at start up (default: %v)`, - defaultCfg.General.Silent, + *defaultCfg.App.Silent, ), OnlyOnce: true, Action: func(ctx context.Context, cmd *cli.Command, v bool) error { - argsCfg.General.Silent = ptr.FromValue(v) + argsCfg.App.Silent = lo.ToPtr(v) return nil }, }, &cli.BoolFlag{ - Name: "system-proxy", + Name: "auto-configure-network", Usage: fmt.Sprintf(` Automatically set system-wide proxy configuration (default: %v)`, - defaultCfg.General.SetSystemProxy, + *defaultCfg.App.AutoConfigureNetwork, ), OnlyOnce: true, Action: func(ctx context.Context, cmd *cli.Command, v bool) error { - argsCfg.General.SetSystemProxy = ptr.FromValue(v) + argsCfg.App.AutoConfigureNetwork = lo.ToPtr(v) return nil }, }, &cli.Int64Flag{ - Name: "timeout", + Name: "tcp-timeout", Usage: fmt.Sprintf(` Timeout for tcp connection in milliseconds. No effect when the value is 0 (default: %v, max: %v)`, - defaultCfg.Server.Timeout, + defaultCfg.Conn.TCPTimeout.Milliseconds(), math.MaxUint16, ), Value: 0, OnlyOnce: true, Validator: checkUint16, Action: func(ctx context.Context, cmd *cli.Command, v int64) error { - argsCfg.Server.Timeout = ptr.FromValue( + argsCfg.Conn.TCPTimeout = lo.ToPtr( time.Duration(v * int64(time.Millisecond)), ) return nil @@ -361,6 +429,17 @@ func CreateCommand( finalCfg := defaultCfg.Merge(tomlCfg.Merge(argsCfg)) + if finalCfg.App.ListenAddr == nil { + port := 8080 + if *finalCfg.App.Mode == AppModeSOCKS5 { + port = 1080 + } + finalCfg.App.ListenAddr = &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: port, + } + } + runFunc(ctx, strings.Replace(configDir, os.Getenv("HOME"), "~", 1), finalCfg) return nil }, diff --git a/internal/config/cli_test.go b/internal/config/cli_test.go index 9f0ee18c..d1e8ff4b 100644 --- a/internal/config/cli_test.go +++ b/internal/config/cli_test.go @@ -24,12 +24,14 @@ func TestCreateCommand_Flags(t *testing.T) { args: []string{"spoofdpi", "--clean"}, assert: func(t *testing.T, cfg *Config) { // Verify defaults are preserved - assert.Equal(t, zerolog.InfoLevel, *cfg.General.LogLevel) - assert.False(t, *cfg.General.Silent) - assert.False(t, *cfg.General.SetSystemProxy) - assert.Equal(t, "127.0.0.1:8080", cfg.Server.ListenAddr.String()) - assert.Equal(t, uint8(64), *cfg.Server.DefaultTTL) - assert.Equal(t, time.Duration(0), *cfg.Server.Timeout) + assert.Equal(t, zerolog.InfoLevel, *cfg.App.LogLevel) + assert.False(t, *cfg.App.Silent) + assert.False(t, *cfg.App.AutoConfigureNetwork) + assert.Equal(t, "127.0.0.1:8080", cfg.App.ListenAddr.String()) + assert.Equal(t, uint8(8), *cfg.Conn.DefaultFakeTTL) + assert.Equal(t, int64(5000), cfg.Conn.DNSTimeout.Milliseconds()) + assert.Equal(t, int64(10000), cfg.Conn.TCPTimeout.Milliseconds()) + assert.Equal(t, int64(25000), cfg.Conn.UDPIdleTimeout.Milliseconds()) assert.Equal(t, "8.8.8.8:53", cfg.DNS.Addr.String()) assert.Equal(t, DNSModeUDP, *cfg.DNS.Mode) assert.Equal(t, "https://dns.google/dns-query", *cfg.DNS.HTTPSURL) @@ -38,9 +40,10 @@ func TestCreateCommand_Flags(t *testing.T) { assert.Equal(t, uint8(0), *cfg.HTTPS.FakeCount) assert.False(t, *cfg.HTTPS.Disorder) assert.Equal(t, HTTPSSplitModeSNI, *cfg.HTTPS.SplitMode) - assert.Equal(t, uint8(0), *cfg.HTTPS.ChunkSize) + assert.Equal(t, uint8(35), *cfg.HTTPS.ChunkSize) assert.False(t, *cfg.HTTPS.Skip) - assert.False(t, *cfg.Policy.Auto) + assert.Equal(t, 0, *cfg.UDP.FakeCount) + assert.Equal(t, 64, len(cfg.UDP.FakePacket)) }, }, { @@ -50,10 +53,12 @@ func TestCreateCommand_Flags(t *testing.T) { "--clean", // Ensure no config file interferes "--log-level", "debug", "--silent", - "--system-proxy", + "--auto-configure-network", "--listen-addr", "127.0.0.1:9090", - "--default-ttl", "128", - "--timeout", "5000", + "--default-fake-ttl", "128", + "--dns-timeout", "5000", + "--tcp-timeout", "5000", + "--udp-idle-timeout", "5000", "--dns-addr", "1.1.1.1:53", "--dns-mode", "https", "--dns-https-url", "https://cloudflare-dns.com/dns-query", @@ -65,18 +70,21 @@ func TestCreateCommand_Flags(t *testing.T) { "--https-split-mode", "chunk", "--https-chunk-size", "50", "--https-skip", - "--policy-auto", + "--udp-fake-count", "5", + "--udp-fake-packet", "0x01, 0x02", }, assert: func(t *testing.T, cfg *Config) { // General - assert.Equal(t, zerolog.DebugLevel, *cfg.General.LogLevel) - assert.True(t, *cfg.General.Silent) - assert.True(t, *cfg.General.SetSystemProxy) + assert.Equal(t, zerolog.DebugLevel, *cfg.App.LogLevel) + assert.True(t, *cfg.App.Silent) + assert.True(t, *cfg.App.AutoConfigureNetwork) // Server - assert.Equal(t, "127.0.0.1:9090", cfg.Server.ListenAddr.String()) - assert.Equal(t, uint8(128), *cfg.Server.DefaultTTL) - assert.Equal(t, 5000*time.Millisecond, *cfg.Server.Timeout) + assert.Equal(t, "127.0.0.1:9090", cfg.App.ListenAddr.String()) + assert.Equal(t, uint8(128), *cfg.Conn.DefaultFakeTTL) + assert.Equal(t, 5000*time.Millisecond, *cfg.Conn.DNSTimeout) + assert.Equal(t, 5000*time.Millisecond, *cfg.Conn.TCPTimeout) + assert.Equal(t, 5000*time.Millisecond, *cfg.Conn.UDPIdleTimeout) // DNS assert.Equal(t, "1.1.1.1:53", cfg.DNS.Addr.String()) @@ -93,8 +101,9 @@ func TestCreateCommand_Flags(t *testing.T) { assert.Equal(t, uint8(50), *cfg.HTTPS.ChunkSize) assert.True(t, *cfg.HTTPS.Skip) - // Policy - assert.True(t, *cfg.Policy.Auto) + // UDP + assert.Equal(t, 5, *cfg.UDP.FakeCount) + assert.Equal(t, []byte{0x01, 0x02}, cfg.UDP.FakePacket) }, }, { @@ -108,7 +117,7 @@ func TestCreateCommand_Flags(t *testing.T) { "--https-split-mode", "random", }, assert: func(t *testing.T, cfg *Config) { - assert.Equal(t, zerolog.ErrorLevel, *cfg.General.LogLevel) + assert.Equal(t, zerolog.ErrorLevel, *cfg.App.LogLevel) assert.Equal(t, DNSModeSystem, *cfg.DNS.Mode) assert.Equal(t, DNSQueryAll, *cfg.DNS.QType) assert.Equal(t, HTTPSSplitModeRandom, *cfg.HTTPS.SplitMode) @@ -122,9 +131,21 @@ func TestCreateCommand_Flags(t *testing.T) { "--listen-addr", "[::1]:1080", }, assert: func(t *testing.T, cfg *Config) { - assert.Equal(t, "[::1]:1080", cfg.Server.ListenAddr.String()) + assert.Equal(t, "[::1]:1080", cfg.App.ListenAddr.String()) ip := net.ParseIP("::1") - assert.True(t, cfg.Server.ListenAddr.IP.Equal(ip)) + assert.True(t, cfg.App.ListenAddr.IP.Equal(ip)) + }, + }, + { + name: "socks5 default port", + args: []string{ + "spoofdpi", + "--clean", + "--app-mode", "socks5", + }, + assert: func(t *testing.T, cfg *Config) { + assert.Equal(t, "127.0.0.1:1080", cfg.App.ListenAddr.String()) + assert.Equal(t, AppModeSOCKS5, *cfg.App.Mode) }, }, } @@ -153,15 +174,17 @@ func TestCreateCommand_Flags(t *testing.T) { func TestCreateCommand_OverrideTOML(t *testing.T) { tomlContent := ` -[general] +[app] log-level = "debug" silent = true system-proxy = true -[server] +[connection] listen-addr = "127.0.0.1:8080" - timeout = 1000 - default-ttl = 100 + dns-timeout = 1000 + tcp-timeout = 1000 + udp-idle-timeout = 1000 + default-fake-ttl = 100 [dns] addr = "8.8.8.8:53" @@ -179,7 +202,6 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { skip = true [policy] - auto = true [[policy.overrides]] name = "test-rule" priority = 100 @@ -224,10 +246,12 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { "--config", configPath, "--log-level", "error", "--silent=false", - "--system-proxy=false", + "--auto-configure-network=false", "--listen-addr", "127.0.0.1:9090", - "--timeout", "2000", - "--default-ttl", "200", + "--dns-timeout", "2000", + "--tcp-timeout", "2000", + "--udp-idle-timeout", "2000", + "--default-fake-ttl", "200", "--dns-addr", "1.1.1.1:53", "--dns-cache=false", "--dns-mode", "udp", @@ -239,7 +263,8 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { "--https-split-mode", "sni", "--https-chunk-size", "10", "--https-skip=false", - "--policy-auto=false", + "--udp-fake-count", "20", + "--udp-fake-packet", "0xcc,0xdd", } err = cmd.Run(context.Background(), args) @@ -248,14 +273,16 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { // Verify Overrides // General - assert.Equal(t, zerolog.ErrorLevel, *capturedCfg.General.LogLevel) - assert.False(t, *capturedCfg.General.Silent) - assert.False(t, *capturedCfg.General.SetSystemProxy) + assert.Equal(t, zerolog.ErrorLevel, *capturedCfg.App.LogLevel) + assert.False(t, *capturedCfg.App.Silent) + assert.False(t, *capturedCfg.App.AutoConfigureNetwork) // Server - assert.Equal(t, "127.0.0.1:9090", capturedCfg.Server.ListenAddr.String()) - assert.Equal(t, 2000*time.Millisecond, *capturedCfg.Server.Timeout) - assert.Equal(t, uint8(200), *capturedCfg.Server.DefaultTTL) + assert.Equal(t, "127.0.0.1:9090", capturedCfg.App.ListenAddr.String()) + assert.Equal(t, 2000*time.Millisecond, *capturedCfg.Conn.DNSTimeout) + assert.Equal(t, 2000*time.Millisecond, *capturedCfg.Conn.TCPTimeout) + assert.Equal(t, 2000*time.Millisecond, *capturedCfg.Conn.UDPIdleTimeout) + assert.Equal(t, uint8(200), *capturedCfg.Conn.DefaultFakeTTL) // DNS assert.Equal(t, "1.1.1.1:53", capturedCfg.DNS.Addr.String()) @@ -272,8 +299,10 @@ func TestCreateCommand_OverrideTOML(t *testing.T) { assert.Equal(t, uint8(10), *capturedCfg.HTTPS.ChunkSize) assert.False(t, *capturedCfg.HTTPS.Skip) - // Policy - assert.False(t, *capturedCfg.Policy.Auto) + // UDP + assert.Equal(t, 20, *capturedCfg.UDP.FakeCount) + assert.Equal(t, []byte{0xcc, 0xdd}, capturedCfg.UDP.FakePacket) + assert.Equal(t, []byte{0xcc, 0xdd}, capturedCfg.UDP.FakePacket) // Verify TOML-only fields are preserved require.Len(t, capturedCfg.Policy.Overrides, 1) diff --git a/internal/config/config.go b/internal/config/config.go index 757a2877..081108e7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,8 +6,8 @@ import ( "time" "github.com/rs/zerolog" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" ) type merger[T any] interface { @@ -22,11 +22,12 @@ type cloner[T any] interface { var _ merger[*Config] = (*Config)(nil) type Config struct { - General *GeneralOptions `toml:"general"` - Server *ServerOptions `toml:"server"` - DNS *DNSOptions `toml:"dns"` - HTTPS *HTTPSOptions `toml:"https"` - Policy *PolicyOptions `toml:"policy"` + App *AppOptions `toml:"app"` + Conn *ConnOptions `toml:"connection"` + DNS *DNSOptions `toml:"dns"` + HTTPS *HTTPSOptions `toml:"https"` + UDP *UDPOptions `toml:"udp"` + Policy *PolicyOptions `toml:"policy"` } func (c *Config) UnmarshalTOML(data any) (err error) { @@ -35,10 +36,11 @@ func (c *Config) UnmarshalTOML(data any) (err error) { return fmt.Errorf("non-table type config file") } - c.General = findStructFrom[GeneralOptions](m, "general", &err) - c.Server = findStructFrom[ServerOptions](m, "server", &err) + c.App = findStructFrom[AppOptions](m, "app", &err) + c.Conn = findStructFrom[ConnOptions](m, "connection", &err) c.DNS = findStructFrom[DNSOptions](m, "dns", &err) c.HTTPS = findStructFrom[HTTPSOptions](m, "https", &err) + c.UDP = findStructFrom[UDPOptions](m, "udp", &err) c.Policy = findStructFrom[PolicyOptions](m, "policy", &err) return @@ -46,11 +48,12 @@ func (c *Config) UnmarshalTOML(data any) (err error) { func NewConfig() *Config { return &Config{ - General: &GeneralOptions{}, - Server: &ServerOptions{}, - DNS: &DNSOptions{}, - HTTPS: &HTTPSOptions{}, - Policy: &PolicyOptions{}, + App: &AppOptions{}, + Conn: &ConnOptions{}, + DNS: &DNSOptions{}, + HTTPS: &HTTPSOptions{}, + UDP: &UDPOptions{}, + Policy: &PolicyOptions{}, } } @@ -60,11 +63,12 @@ func (c *Config) Clone() *Config { } return &Config{ - General: c.General.Clone(), - Server: c.Server.Clone(), - DNS: c.DNS.Clone(), - HTTPS: c.HTTPS.Clone(), - Policy: c.Policy.Clone(), + App: c.App.Clone(), + Conn: c.Conn.Clone(), + DNS: c.DNS.Clone(), + HTTPS: c.HTTPS.Clone(), + UDP: c.UDP.Clone(), + Policy: c.Policy.Clone(), } } @@ -78,11 +82,12 @@ func (origin *Config) Merge(overrides *Config) *Config { } return &Config{ - General: origin.General.Merge(overrides.General), - Server: origin.Server.Merge(overrides.Server), - DNS: origin.DNS.Merge(overrides.DNS), - HTTPS: origin.HTTPS.Merge(overrides.HTTPS), - Policy: origin.Policy.Merge(overrides.Policy), + App: origin.App.Merge(overrides.App), + Conn: origin.Conn.Merge(overrides.Conn), + DNS: origin.DNS.Merge(overrides.DNS), + HTTPS: origin.HTTPS.Merge(overrides.HTTPS), + UDP: origin.UDP.Merge(overrides.UDP), + Policy: origin.Policy.Merge(overrides.Policy), } } @@ -91,13 +96,20 @@ func (c *Config) ShouldEnablePcap() bool { return true } + if c.UDP != nil && c.UDP.FakeCount != nil && *c.UDP.FakeCount > 0 { + return true + } + if c.Policy == nil { return false } if c.Policy.Template != nil { template := c.Policy.Template - if template.HTTPS != nil && ptr.FromPtr(template.HTTPS.FakeCount) > 0 { + if template.HTTPS != nil && lo.FromPtr(template.HTTPS.FakeCount) > 0 { + return true + } + if template.UDP != nil && lo.FromPtr(template.UDP.FakeCount) > 0 { return true } } @@ -105,15 +117,11 @@ func (c *Config) ShouldEnablePcap() bool { if c.Policy.Overrides != nil { rules := c.Policy.Overrides for _, r := range rules { - if r.HTTPS == nil { - continue - } - - if r.HTTPS.FakeCount == nil { - continue + if r.HTTPS != nil && r.HTTPS.FakeCount != nil && *r.HTTPS.FakeCount > 0 { + return true } - if *r.HTTPS.FakeCount > 0 { + if r.UDP != nil && r.UDP.FakeCount != nil && *r.UDP.FakeCount > 0 { return true } } @@ -124,33 +132,40 @@ func (c *Config) ShouldEnablePcap() bool { func getDefault() *Config { //exhaustruct:enforce return &Config{ - General: &GeneralOptions{ - LogLevel: ptr.FromValue(zerolog.InfoLevel), - Silent: ptr.FromValue(false), - SetSystemProxy: ptr.FromValue(false), + App: &AppOptions{ + LogLevel: lo.ToPtr(zerolog.InfoLevel), + Silent: lo.ToPtr(false), + AutoConfigureNetwork: lo.ToPtr(false), + Mode: lo.ToPtr(AppModeHTTP), + ListenAddr: nil, }, - Server: &ServerOptions{ - DefaultTTL: ptr.FromValue(uint8(64)), - ListenAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080, Zone: ""}, - Timeout: ptr.FromValue(time.Duration(0)), + Conn: &ConnOptions{ + DefaultFakeTTL: lo.ToPtr(uint8(8)), + DNSTimeout: lo.ToPtr(time.Duration(5000) * time.Millisecond), + TCPTimeout: lo.ToPtr(time.Duration(10000) * time.Millisecond), + UDPIdleTimeout: lo.ToPtr(time.Duration(25000) * time.Millisecond), }, DNS: &DNSOptions{ - Mode: ptr.FromValue(DNSModeUDP), + Mode: lo.ToPtr(DNSModeUDP), Addr: &net.TCPAddr{IP: net.ParseIP("8.8.8.8"), Port: 53, Zone: ""}, - HTTPSURL: ptr.FromValue("https://dns.google/dns-query"), - QType: ptr.FromValue(DNSQueryIPv4), - Cache: ptr.FromValue(false), + HTTPSURL: lo.ToPtr("https://dns.google/dns-query"), + QType: lo.ToPtr(DNSQueryIPv4), + Cache: lo.ToPtr(false), }, HTTPS: &HTTPSOptions{ - Disorder: ptr.FromValue(false), - FakeCount: ptr.FromValue(uint8(0)), - FakePacket: proto.NewFakeTLSMessage([]byte(FakeClientHello)), - SplitMode: ptr.FromValue(HTTPSSplitModeSNI), - ChunkSize: ptr.FromValue(uint8(0)), - Skip: ptr.FromValue(false), + Disorder: lo.ToPtr(false), + FakeCount: lo.ToPtr(uint8(0)), + FakePacket: proto.NewFakeTLSMessage([]byte(FakeClientHello)), + SplitMode: lo.ToPtr(HTTPSSplitModeSNI), + ChunkSize: lo.ToPtr(uint8(35)), + CustomSegmentPlans: []SegmentPlan{}, + Skip: lo.ToPtr(false), + }, + UDP: &UDPOptions{ + FakeCount: lo.ToPtr(0), + FakePacket: make([]byte, 64), }, Policy: &PolicyOptions{ - Auto: ptr.FromValue(false), Template: &Rule{}, Overrides: []Rule{}, }, diff --git a/internal/config/config_test.go b/internal/config/config_test.go index c148caaf..10b55b35 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -4,8 +4,8 @@ import ( "net" "testing" + "github.com/samber/lo" "github.com/stretchr/testify/assert" - "github.com/xvzc/SpoofDPI/internal/ptr" ) func TestConfig_UnmarshalTOML(t *testing.T) { @@ -18,7 +18,7 @@ func TestConfig_UnmarshalTOML(t *testing.T) { { name: "valid config", input: map[string]any{ - "server": map[string]any{ + "app": map[string]any{ "listen-addr": "127.0.0.1:9090", }, "dns": map[string]any{ @@ -40,7 +40,7 @@ func TestConfig_UnmarshalTOML(t *testing.T) { }, wantErr: false, assert: func(t *testing.T, c Config) { - assert.Equal(t, "127.0.0.1:9090", c.Server.ListenAddr.String()) + assert.Equal(t, "127.0.0.1:9090", c.App.ListenAddr.String()) assert.Equal(t, "1.1.1.1:53", c.DNS.Addr.String()) if assert.Len(t, c.Policy.Overrides, 1) { assert.Equal(t, "test", *c.Policy.Overrides[0].Name) @@ -55,7 +55,7 @@ func TestConfig_UnmarshalTOML(t *testing.T) { { name: "validation error", input: map[string]any{ - "server": map[string]any{ + "app": map[string]any{ "listen-addr": "invalid-addr", }, }, @@ -89,7 +89,7 @@ func TestConfig_ShouldEnablePcap(t *testing.T) { name: "global fake count > 0", config: Config{ HTTPS: &HTTPSOptions{ - FakeCount: ptr.FromValue(uint8(1)), + FakeCount: lo.ToPtr(uint8(1)), }, }, expect: true, @@ -98,13 +98,13 @@ func TestConfig_ShouldEnablePcap(t *testing.T) { name: "rule fake count > 0", config: Config{ HTTPS: &HTTPSOptions{ - FakeCount: ptr.FromValue(uint8(0)), + FakeCount: lo.ToPtr(uint8(0)), }, Policy: &PolicyOptions{ Overrides: []Rule{ { HTTPS: &HTTPSOptions{ - FakeCount: ptr.FromValue(uint8(1)), + FakeCount: lo.ToPtr(uint8(1)), }, }, }, @@ -116,13 +116,13 @@ func TestConfig_ShouldEnablePcap(t *testing.T) { name: "none", config: Config{ HTTPS: &HTTPSOptions{ - FakeCount: ptr.FromValue(uint8(0)), + FakeCount: lo.ToPtr(uint8(0)), }, Policy: &PolicyOptions{ Overrides: []Rule{ { HTTPS: &HTTPSOptions{ - FakeCount: ptr.FromValue(uint8(0)), + FakeCount: lo.ToPtr(uint8(0)), }, }, }, @@ -149,17 +149,29 @@ func TestConfig_Merge(t *testing.T) { { name: "keep toml if arg is nil", tomlCfg: &Config{ - Server: &ServerOptions{ + App: &AppOptions{ ListenAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}, }, }, argsCfg: &Config{ - Server: &ServerOptions{ + App: &AppOptions{ ListenAddr: nil, }, }, assert: func(t *testing.T, merged *Config) { - assert.Equal(t, "127.0.0.1:8080", merged.Server.ListenAddr.String()) + assert.Equal(t, "127.0.0.1:8080", merged.App.ListenAddr.String()) + }, + }, + { + name: "default udp fake packet", + tomlCfg: &Config{}, + argsCfg: &Config{}, + assert: func(t *testing.T, merged *Config) { + defaultCfg := getDefault() + assert.Equal(t, 64, len(defaultCfg.UDP.FakePacket)) + for _, b := range defaultCfg.UDP.FakePacket { + assert.Equal(t, byte(0), b) + } }, }, } @@ -187,17 +199,17 @@ func TestConfig_Clone(t *testing.T) { { name: "non-nil receiver", input: &Config{ - General: &GeneralOptions{}, - Server: &ServerOptions{}, - DNS: &DNSOptions{}, - HTTPS: &HTTPSOptions{}, - Policy: &PolicyOptions{}, + App: &AppOptions{}, + Conn: &ConnOptions{}, + DNS: &DNSOptions{}, + HTTPS: &HTTPSOptions{}, + Policy: &PolicyOptions{}, }, assert: func(t *testing.T, input *Config, output *Config) { assert.NotNil(t, output) assert.NotSame(t, input, output) - assert.NotSame(t, input.General, output.General) - assert.NotSame(t, input.Server, output.Server) + assert.NotSame(t, input.App, output.App) + assert.NotSame(t, input.Conn, output.Conn) assert.NotSame(t, input.DNS, output.DNS) assert.NotSame(t, input.HTTPS, output.HTTPS) assert.NotSame(t, input.Policy, output.Policy) diff --git a/internal/config/parse.go b/internal/config/parse.go index 7d2da219..23905376 100644 --- a/internal/config/parse.go +++ b/internal/config/parse.go @@ -108,6 +108,19 @@ func MustParseLogLevel(s string) zerolog.Level { return level } +func MustParseServerModeType(s string) AppModeType { + switch s { + case "http": + return AppModeHTTP + case "socks5": + return AppModeSOCKS5 + case "tun": + return AppModeTUN + default: + panic(fmt.Sprintf("cannot parse %q to ServerModeType", s)) + } +} + func MustParseDNSModeType(s string) DNSModeType { switch s { case "udp": @@ -146,11 +159,24 @@ func mustParseHTTPSSplitModeType(s string) HTTPSSplitModeType { return HTTPSSplitModeFirstByte case "none": return HTTPSSplitModeNone + case "custom": + return HTTPSSplitModeCustom default: panic(fmt.Sprintf("cannot parse %q to HTTPSSplitModeType", s)) } } +func mustParseSegmentFromType(s string) SegmentFromType { + switch s { + case "sni": + return SegmentFromSNI + case "head": + return SegmentFromHead + default: + panic(fmt.Sprintf("cannot parse %q to SegmentFromType", s)) + } +} + type Integer interface { ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr diff --git a/internal/config/segment_test.go b/internal/config/segment_test.go new file mode 100644 index 00000000..d912156b --- /dev/null +++ b/internal/config/segment_test.go @@ -0,0 +1 @@ +package config diff --git a/internal/config/toml.go b/internal/config/toml.go index d2c5c498..9ada3adb 100644 --- a/internal/config/toml.go +++ b/internal/config/toml.go @@ -5,7 +5,7 @@ import ( "os" "github.com/BurntSushi/toml" - "github.com/xvzc/SpoofDPI/internal/ptr" + "github.com/samber/lo" ) func fromTomlFile(dir string) (*Config, error) { @@ -65,7 +65,7 @@ func findFrom[T any]( return nil } - return ptr.FromValue(val) + return lo.ToPtr(val) } func findStructFrom[T any, PT interface { diff --git a/internal/config/toml_test.go b/internal/config/toml_test.go index e97b1d84..fb05cfbb 100644 --- a/internal/config/toml_test.go +++ b/internal/config/toml_test.go @@ -414,15 +414,17 @@ func TestFindSliceFrom(t *testing.T) { func TestFromTomlFile(t *testing.T) { t.Run("full valid config", func(t *testing.T) { tomlContent := ` - [general] - log-level = "debug" - silent = true - system-proxy = true - - [server] - listen-addr = "127.0.0.1:8080" - timeout = 1000 - default-ttl = 100 + [app] + log-level = "debug" + silent = true + auto-configure-network = true + mode = "socks5" + listen-addr = "127.0.0.1:8080" + [connection] + dns-timeout = 1000 + tcp-timeout = 1000 + udp-idle-timeout = 1000 + default-fake-ttl = 100 [dns] addr = "8.8.8.8:53" @@ -440,7 +442,6 @@ func TestFromTomlFile(t *testing.T) { skip = true [policy] - auto = true [[policy.overrides]] name = "test-rule" priority = 100 @@ -482,18 +483,20 @@ func TestFromTomlFile(t *testing.T) { return } - assert.Equal(t, "127.0.0.1:8080", cfg.Server.ListenAddr.String()) - assert.Equal(t, time.Duration(1000*time.Millisecond), *cfg.Server.Timeout) - assert.Equal(t, zerolog.DebugLevel, *cfg.General.LogLevel) - assert.True(t, *cfg.General.Silent) - assert.True(t, *cfg.General.SetSystemProxy) + assert.Equal(t, "127.0.0.1:8080", cfg.App.ListenAddr.String()) + assert.Equal(t, time.Duration(1000*time.Millisecond), *cfg.Conn.DNSTimeout) + assert.Equal(t, time.Duration(1000*time.Millisecond), *cfg.Conn.TCPTimeout) + assert.Equal(t, time.Duration(1000*time.Millisecond), *cfg.Conn.UDPIdleTimeout) + assert.Equal(t, zerolog.DebugLevel, *cfg.App.LogLevel) + assert.True(t, *cfg.App.Silent) + assert.True(t, *cfg.App.AutoConfigureNetwork) + assert.Equal(t, AppModeSOCKS5, *cfg.App.Mode) assert.Equal(t, "8.8.8.8:53", cfg.DNS.Addr.String()) assert.True(t, *cfg.DNS.Cache) assert.Equal(t, DNSModeHTTPS, *cfg.DNS.Mode) assert.Equal(t, "https://1.1.1.1/dns-query", *cfg.DNS.HTTPSURL) assert.Equal(t, DNSQueryIPv4, *cfg.DNS.QType) - assert.Equal(t, uint8(100), *cfg.Server.DefaultTTL) - assert.True(t, *cfg.Policy.Auto) + assert.Equal(t, uint8(100), *cfg.Conn.DefaultFakeTTL) assert.True(t, *cfg.HTTPS.Disorder) assert.Equal(t, uint8(5), *cfg.HTTPS.FakeCount) assert.Equal(t, []byte{0x01, 0x02, 0x03}, cfg.HTTPS.FakePacket.Raw()) diff --git a/internal/config/types.go b/internal/config/types.go index 5b613c20..3d3759b3 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -1,62 +1,106 @@ package config import ( + "encoding/json" "fmt" + "math" "net" + "slices" "strings" "time" "github.com/rs/zerolog" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" ) +type primitive interface { + ~bool | ~string | + ~int | ~int8 | ~int16 | ~int32 | ~int64 | + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | + ~float32 | ~float64 | + ~complex64 | ~complex128 +} + +func clonePrimitive[T primitive](x *T) *T { + if x == nil { + return nil + } + return lo.ToPtr(lo.FromPtr(x)) +} + // ┌─────────────────┐ // │ GENERAL OPTIONS │ // └─────────────────┘ -var _ merger[*GeneralOptions] = (*GeneralOptions)(nil) +var _ merger[*AppOptions] = (*AppOptions)(nil) -var availableLogLevels = []string{"info", "warn", "trace", "error", "debug"} +var availableLogLevelValues = []string{ + "info", + "warn", + "trace", + "error", + "debug", + "disabled", +} -type GeneralOptions struct { - LogLevel *zerolog.Level `toml:"log-level"` - Silent *bool `toml:"silent"` - SetSystemProxy *bool `toml:"system-proxy"` +type AppOptions struct { + LogLevel *zerolog.Level `toml:"log-level"` + Silent *bool `toml:"silent"` + AutoConfigureNetwork *bool `toml:"auto-configure-network"` + Mode *AppModeType `toml:"mode"` + ListenAddr *net.TCPAddr `toml:"listen-addr"` } -func (o *GeneralOptions) UnmarshalTOML(data any) (err error) { +func (o *AppOptions) UnmarshalTOML(data any) (err error) { m, ok := data.(map[string]any) if !ok { return fmt.Errorf("non-table type general config") } o.Silent = findFrom(m, "silent", parseBoolFn(), &err) - o.SetSystemProxy = findFrom(m, "system-proxy", parseBoolFn(), &err) + o.AutoConfigureNetwork = findFrom(m, "auto-configure-network", parseBoolFn(), &err) if p := findFrom(m, "log-level", parseStringFn(checkLogLevel), &err); isOk(p, err) { - o.LogLevel = ptr.FromValue(MustParseLogLevel(*p)) + o.LogLevel = lo.ToPtr(MustParseLogLevel(*p)) + } + if p := findFrom(m, "mode", parseStringFn(checkAppMode), &err); isOk(p, err) { + o.Mode = lo.ToPtr(MustParseServerModeType(*p)) + } + if p := findFrom(m, "listen-addr", parseStringFn(checkHostPort), &err); isOk(p, err) { + o.ListenAddr = lo.ToPtr(MustParseTCPAddr(*p)) } return err } -func (o *GeneralOptions) Clone() *GeneralOptions { +func (o *AppOptions) Clone() *AppOptions { if o == nil { return nil } var newLevel *zerolog.Level if o.LogLevel != nil { - newLevel = ptr.FromValue(MustParseLogLevel(strings.ToLower(o.LogLevel.String()))) + newLevel = lo.ToPtr(MustParseLogLevel(strings.ToLower(o.LogLevel.String()))) + } + + var newAddr *net.TCPAddr + if o.ListenAddr != nil { + newAddr = &net.TCPAddr{ + IP: append(net.IP(nil), o.ListenAddr.IP...), + Port: o.ListenAddr.Port, + Zone: o.ListenAddr.Zone, + } } - return &GeneralOptions{ - LogLevel: newLevel, - Silent: ptr.Clone(o.Silent), - SetSystemProxy: ptr.Clone(o.SetSystemProxy), + return &AppOptions{ + LogLevel: newLevel, + Silent: clonePrimitive(o.Silent), + AutoConfigureNetwork: clonePrimitive(o.AutoConfigureNetwork), + Mode: clonePrimitive(o.Mode), + ListenAddr: newAddr, } } -func (origin *GeneralOptions) Merge(overrides *GeneralOptions) *GeneralOptions { +func (origin *AppOptions) Merge(overrides *AppOptions) *AppOptions { if overrides == nil { return origin.Clone() } @@ -65,65 +109,108 @@ func (origin *GeneralOptions) Merge(overrides *GeneralOptions) *GeneralOptions { return overrides.Clone() } - return &GeneralOptions{ - LogLevel: ptr.CloneOr(overrides.LogLevel, origin.LogLevel), - Silent: ptr.CloneOr(overrides.Silent, origin.Silent), - SetSystemProxy: ptr.CloneOr(overrides.SetSystemProxy, origin.SetSystemProxy), + return &AppOptions{ + LogLevel: lo.CoalesceOrEmpty(overrides.LogLevel, origin.LogLevel), + Silent: lo.CoalesceOrEmpty(overrides.Silent, origin.Silent), + AutoConfigureNetwork: lo.CoalesceOrEmpty( + overrides.AutoConfigureNetwork, + origin.AutoConfigureNetwork, + ), + Mode: lo.CoalesceOrEmpty(overrides.Mode, origin.Mode), + ListenAddr: lo.CoalesceOrEmpty(overrides.ListenAddr, origin.ListenAddr), } } -// ┌────────────────┐ -// │ SERVER OPTIONS │ -// └────────────────┘ -var _ merger[*ServerOptions] = (*ServerOptions)(nil) +// ┌──────────────────────┐ +// │ CONNECTION OPTIONS │ +// └──────────────────────┘ +var _ merger[*ConnOptions] = (*ConnOptions)(nil) -type ServerOptions struct { - DefaultTTL *uint8 `toml:"default-ttl"` - ListenAddr *net.TCPAddr `toml:"listen-addr"` - Timeout *time.Duration `toml:"timeout"` -} +type AppModeType int -func (o *ServerOptions) UnmarshalTOML(data any) (err error) { - v, ok := data.(map[string]any) - if !ok { - return fmt.Errorf("non-table type server config") - } +const ( + AppModeHTTP AppModeType = iota + AppModeSOCKS5 + AppModeTUN +) - o.DefaultTTL = findFrom(v, "default-ttl", parseIntFn[uint8](checkUint8NonZero), &err) +var availableAppModeValues = []string{"http", "socks5", "tun"} - if p := findFrom(v, "listen-addr", parseStringFn(checkHostPort), &err); isOk(p, err) { - o.ListenAddr = ptr.FromValue(MustParseTCPAddr(*p)) - } +func (t AppModeType) String() string { + return availableAppModeValues[t] +} + +type ConnOptions struct { + DefaultFakeTTL *uint8 `toml:"default-fake-ttl"` + DNSTimeout *time.Duration `toml:"dns-timeout"` + TCPTimeout *time.Duration `toml:"tcp-timeout"` + UDPIdleTimeout *time.Duration `toml:"udp-idle-timeout"` +} - if p := findFrom(v, "timeout", parseIntFn[uint16](checkUint16), &err); isOk(p, err) { - o.Timeout = ptr.FromValue(time.Duration(*p) * time.Millisecond) +func (o *ConnOptions) UnmarshalTOML(data any) (err error) { + v, ok := data.(map[string]any) + if !ok { + return fmt.Errorf("non-table type connection config") + } + + o.DefaultFakeTTL = findFrom( + v, + "default-fake-ttl", + parseIntFn[uint8](checkUint8NonZero), + &err, + ) + + if p := findFrom( + v, + "dns-timeout", + parseIntFn[uint16](checkUint16), + &err, + ); isOk( + p, + err, + ) { + o.DNSTimeout = lo.ToPtr(time.Duration(*p) * time.Millisecond) + } + if p := findFrom( + v, + "tcp-timeout", + parseIntFn[uint16](checkUint16), + &err, + ); isOk( + p, + err, + ) { + o.TCPTimeout = lo.ToPtr(time.Duration(*p) * time.Millisecond) + } + if p := findFrom( + v, + "udp-idle-timeout", + parseIntFn[uint16](checkUint16), + &err, + ); isOk( + p, + err, + ) { + o.UDPIdleTimeout = lo.ToPtr(time.Duration(*p) * time.Millisecond) } return err } -func (o *ServerOptions) Clone() *ServerOptions { +func (o *ConnOptions) Clone() *ConnOptions { if o == nil { return nil } - var newAddr *net.TCPAddr - if o.ListenAddr != nil { - newAddr = &net.TCPAddr{ - IP: append(net.IP(nil), o.ListenAddr.IP...), - Port: o.ListenAddr.Port, - Zone: o.ListenAddr.Zone, - } - } - - return &ServerOptions{ - DefaultTTL: ptr.Clone(o.DefaultTTL), - ListenAddr: newAddr, - Timeout: ptr.Clone(o.Timeout), + return &ConnOptions{ + DefaultFakeTTL: clonePrimitive(o.DefaultFakeTTL), + DNSTimeout: clonePrimitive(o.DNSTimeout), + TCPTimeout: clonePrimitive(o.TCPTimeout), + UDPIdleTimeout: clonePrimitive(o.UDPIdleTimeout), } } -func (origin *ServerOptions) Merge(overrides *ServerOptions) *ServerOptions { +func (origin *ConnOptions) Merge(overrides *ConnOptions) *ConnOptions { if overrides == nil { return origin.Clone() } @@ -132,10 +219,11 @@ func (origin *ServerOptions) Merge(overrides *ServerOptions) *ServerOptions { return overrides.Clone() } - return &ServerOptions{ - DefaultTTL: ptr.CloneOr(overrides.DefaultTTL, origin.DefaultTTL), - ListenAddr: ptr.CloneOr(overrides.ListenAddr, origin.ListenAddr), - Timeout: ptr.CloneOr(overrides.Timeout, origin.Timeout), + return &ConnOptions{ + DefaultFakeTTL: lo.CoalesceOrEmpty(overrides.DefaultFakeTTL, origin.DefaultFakeTTL), + DNSTimeout: lo.CoalesceOrEmpty(overrides.DNSTimeout, origin.DNSTimeout), + TCPTimeout: lo.CoalesceOrEmpty(overrides.TCPTimeout, origin.TCPTimeout), + UDPIdleTimeout: lo.CoalesceOrEmpty(overrides.UDPIdleTimeout, origin.UDPIdleTimeout), } } @@ -150,8 +238,8 @@ type ( ) var ( - availableDNSModes = []string{"udp", "https", "system"} - availableDNSQueries = []string{"ipv4", "ipv6", "all"} + availableDNSModeValues = []string{"udp", "https", "system"} + availableDNSQueryValues = []string{"ipv4", "ipv6", "all"} ) const ( @@ -167,11 +255,11 @@ const ( ) func (t DNSModeType) String() string { - return availableDNSModes[t] + return availableDNSModeValues[t] } func (t DNSQueryType) String() string { - return availableDNSQueries[t] + return availableDNSQueryValues[t] } type DNSOptions struct { @@ -189,17 +277,17 @@ func (o *DNSOptions) UnmarshalTOML(data any) (err error) { } if p := findFrom(m, "mode", parseStringFn(checkDNSMode), &err); isOk(p, err) { - o.Mode = ptr.FromValue(MustParseDNSModeType(*p)) + o.Mode = lo.ToPtr(MustParseDNSModeType(*p)) } if p := findFrom(m, "addr", parseStringFn(checkHostPort), &err); isOk(p, err) { - o.Addr = ptr.FromValue(MustParseTCPAddr(*p)) + o.Addr = lo.ToPtr(MustParseTCPAddr(*p)) } o.HTTPSURL = findFrom(m, "https-url", parseStringFn(checkHTTPSEndpoint), &err) if p := findFrom(m, "qtype", parseStringFn(checkDNSQueryType), &err); isOk(p, err) { - o.QType = ptr.FromValue(MustParseDNSQueryType(*p)) + o.QType = lo.ToPtr(MustParseDNSQueryType(*p)) } o.Cache = findFrom(m, "cache", parseBoolFn(), &err) @@ -222,11 +310,11 @@ func (o *DNSOptions) Clone() *DNSOptions { } return &DNSOptions{ - Mode: ptr.Clone(o.Mode), + Mode: clonePrimitive(o.Mode), Addr: newAddr, - HTTPSURL: ptr.Clone(o.HTTPSURL), - QType: ptr.Clone(o.QType), - Cache: ptr.Clone(o.Cache), + HTTPSURL: clonePrimitive(o.HTTPSURL), + QType: clonePrimitive(o.QType), + Cache: clonePrimitive(o.Cache), } } @@ -240,11 +328,11 @@ func (origin *DNSOptions) Merge(overrides *DNSOptions) *DNSOptions { } return &DNSOptions{ - Mode: ptr.CloneOr(overrides.Mode, origin.Mode), - Addr: ptr.CloneOr(overrides.Addr, origin.Addr), - HTTPSURL: ptr.CloneOr(overrides.HTTPSURL, origin.HTTPSURL), - QType: ptr.CloneOr(overrides.QType, origin.QType), - Cache: ptr.CloneOr(overrides.Cache, origin.Cache), + Mode: lo.CoalesceOrEmpty(overrides.Mode, origin.Mode), + Addr: lo.CoalesceOrEmpty(overrides.Addr, origin.Addr), + HTTPSURL: lo.CoalesceOrEmpty(overrides.HTTPSURL, origin.HTTPSURL), + QType: lo.CoalesceOrEmpty(overrides.QType, origin.QType), + Cache: lo.CoalesceOrEmpty(overrides.Cache, origin.Cache), } } @@ -294,27 +382,100 @@ const FakeClientHello = "" + type HTTPSSplitModeType int -var availableHTTPSModes = []string{"sni", "random", "chunk", "first-byte", "none"} +var availableHTTPSModeValues = []string{ + "sni", + "random", + "chunk", + "first-byte", + "custom", + "none", +} const ( HTTPSSplitModeSNI HTTPSSplitModeType = iota HTTPSSplitModeRandom HTTPSSplitModeChunk HTTPSSplitModeFirstByte + HTTPSSplitModeCustom HTTPSSplitModeNone ) func (k HTTPSSplitModeType) String() string { - return availableHTTPSModes[k] + return availableHTTPSModeValues[k] +} + +type SegmentFromType int + +var availableSegmentFromValues = []string{"head", "sni"} + +const ( + SegmentFromHead SegmentFromType = iota + SegmentFromSNI +) + +func (s SegmentFromType) String() string { + return availableSegmentFromValues[s] +} + +type SegmentPlan struct { + From SegmentFromType `toml:"from"` + At int `toml:"at"` + Lazy bool `toml:"lazy"` + Noise int `toml:"noise"` +} + +func (s *SegmentPlan) UnmarshalTOML(data any) (err error) { + m, ok := data.(map[string]any) + if !ok { + return fmt.Errorf("segment must be table type") + } + + if _, ok := m["from"]; !ok { + return fmt.Errorf("field 'from' is required") + } + if p := findFrom(m, "from", parseStringFn(checkSegmentFrom), &err); isOk(p, err) { + s.From = mustParseSegmentFromType(*p) + } + + if _, ok := m["at"]; !ok { + return fmt.Errorf("field 'at' is required") + } + if p := findFrom(m, "at", parseIntFn[int](nil), &err); isOk(p, err) { + s.At = *p + } + + if p := findFrom(m, "lazy", parseBoolFn(), &err); isOk(p, err) { + s.Lazy = *p + } + + if p := findFrom(m, "noise", parseIntFn[int](nil), &err); isOk(p, err) { + s.Noise = *p + } + + return err +} + +func (s *SegmentPlan) Clone() *SegmentPlan { + if s == nil { + return nil + } + + return &SegmentPlan{ + From: s.From, + At: s.At, + Lazy: s.Lazy, + Noise: s.Noise, + } } type HTTPSOptions struct { - Disorder *bool `toml:"disorder" json:"ds,omitempty"` - FakeCount *uint8 `toml:"fake-count" json:"fc,omitempty"` - FakePacket *proto.TLSMessage `toml:"fake-packet" json:"fp,omitempty"` - SplitMode *HTTPSSplitModeType `toml:"split-mode" json:"sm,omitempty"` - ChunkSize *uint8 `toml:"chunk-size" json:"cs,omitempty"` - Skip *bool `toml:"skip" json:"sk,omitempty"` + Disorder *bool `toml:"disorder" json:"ds,omitempty"` + FakeCount *uint8 `toml:"fake-count" json:"fc,omitempty"` + FakePacket *proto.TLSMessage `toml:"fake-packet" json:"fp,omitempty"` + SplitMode *HTTPSSplitModeType `toml:"split-mode" json:"sm,omitempty"` + ChunkSize *uint8 `toml:"chunk-size" json:"cs,omitempty"` + Skip *bool `toml:"skip" json:"sk,omitempty"` + CustomSegmentPlans []SegmentPlan `toml:"custom-segments" json:"cseg,omitempty"` } func (o *HTTPSOptions) UnmarshalTOML(data any) (err error) { @@ -333,16 +494,22 @@ func (o *HTTPSOptions) UnmarshalTOML(data any) (err error) { splitModeParser := parseStringFn(checkHTTPSSplitMode) if p := findFrom(m, "split-mode", splitModeParser, &err); isOk(p, err) { - o.SplitMode = ptr.FromValue(mustParseHTTPSSplitModeType(*p)) + o.SplitMode = lo.ToPtr(mustParseHTTPSSplitModeType(*p)) } o.ChunkSize = findFrom(m, "chunk-size", parseIntFn[uint8](checkUint8NonZero), &err) o.Skip = findFrom(m, "skip", parseBoolFn(), &err) if o.Skip == nil { - o.Skip = ptr.FromValue(false) + o.Skip = lo.ToPtr(false) + } + + o.CustomSegmentPlans = findStructSliceFrom[SegmentPlan](m, "custom-segments", &err) + if err == nil && o.SplitMode != nil && *o.SplitMode == HTTPSSplitModeCustom && + len(o.CustomSegmentPlans) == 0 { + err = fmt.Errorf("custom-segments must be provided when split-mode is 'custom'") } - return nil + return err } func (o *HTTPSOptions) Clone() *HTTPSOptions { @@ -355,32 +522,100 @@ func (o *HTTPSOptions) Clone() *HTTPSOptions { fakePacket = proto.NewFakeTLSMessage(o.FakePacket.Raw()) } + var customSegmentPlans []SegmentPlan + if o.CustomSegmentPlans != nil { + customSegmentPlans = make([]SegmentPlan, 0, len(o.CustomSegmentPlans)) + for _, s := range o.CustomSegmentPlans { + customSegmentPlans = append(customSegmentPlans, *s.Clone()) + } + } + return &HTTPSOptions{ - Disorder: ptr.Clone(o.Disorder), - FakeCount: ptr.Clone(o.FakeCount), - FakePacket: fakePacket, - SplitMode: ptr.Clone(o.SplitMode), - ChunkSize: ptr.Clone(o.ChunkSize), - Skip: ptr.Clone(o.Skip), + Disorder: clonePrimitive(o.Disorder), + FakeCount: clonePrimitive(o.FakeCount), + FakePacket: fakePacket, + SplitMode: clonePrimitive(o.SplitMode), + ChunkSize: clonePrimitive(o.ChunkSize), + Skip: clonePrimitive(o.Skip), + CustomSegmentPlans: customSegmentPlans, } } func (origin *HTTPSOptions) Merge(overrides *HTTPSOptions) *HTTPSOptions { if overrides == nil { - return origin + return origin.Clone() } if origin == nil { - return overrides + return overrides.Clone() } return &HTTPSOptions{ - Disorder: ptr.CloneOr(overrides.Disorder, origin.Disorder), - FakeCount: ptr.CloneOr(overrides.FakeCount, origin.FakeCount), - FakePacket: ptr.CloneOr(overrides.FakePacket, origin.FakePacket), - SplitMode: ptr.CloneOr(overrides.SplitMode, origin.SplitMode), - ChunkSize: ptr.CloneOr(overrides.ChunkSize, origin.ChunkSize), - Skip: ptr.CloneOr(overrides.Skip, origin.Skip), + Disorder: lo.CoalesceOrEmpty(overrides.Disorder, origin.Disorder), + FakeCount: lo.CoalesceOrEmpty(overrides.FakeCount, origin.FakeCount), + FakePacket: lo.CoalesceOrEmpty(overrides.FakePacket, origin.FakePacket), + SplitMode: lo.CoalesceOrEmpty(overrides.SplitMode, origin.SplitMode), + ChunkSize: lo.CoalesceOrEmpty(overrides.ChunkSize, origin.ChunkSize), + Skip: lo.CoalesceOrEmpty(overrides.Skip, origin.Skip), + CustomSegmentPlans: lo.CoalesceSliceOrEmpty( + origin.CustomSegmentPlans, + overrides.CustomSegmentPlans, + ), + } +} + +// ┌─────────────┐ +// │ UDP OPTIONS │ +// └─────────────┘ +var _ merger[*UDPOptions] = (*UDPOptions)(nil) + +type UDPOptions struct { + FakeCount *int `toml:"fake-count" json:"fc,omitempty"` + FakePacket []byte `toml:"fake-packet" json:"fp,omitempty"` +} + +func (o *UDPOptions) UnmarshalTOML(data any) (err error) { + m, ok := data.(map[string]any) + if !ok { + return fmt.Errorf("'udp' must be table type") + } + + o.FakeCount = findFrom( + m, "fake-count", parseIntFn[int](int64Range(0, math.MaxInt64)), &err, + ) + o.FakePacket = findSliceFrom(m, "fake-packet", parseByteFn(nil), &err) + + return err +} + +func (o *UDPOptions) Clone() *UDPOptions { + if o == nil { + return nil + } + + return &UDPOptions{ + FakeCount: clonePrimitive(o.FakeCount), + FakePacket: append([]byte(nil), o.FakePacket...), + } +} + +func (origin *UDPOptions) Merge(overrides *UDPOptions) *UDPOptions { + if overrides == nil { + return origin.Clone() + } + + if origin == nil { + return overrides.Clone() + } + + fakePacket := origin.FakePacket + if len(overrides.FakePacket) > 0 { + fakePacket = overrides.FakePacket + } + + return &UDPOptions{ + FakeCount: lo.CoalesceOrEmpty(overrides.FakeCount, origin.FakeCount), + FakePacket: fakePacket, } } @@ -394,7 +629,6 @@ var ( ) type PolicyOptions struct { - Auto *bool `toml:"auto"` Template *Rule `toml:"template"` Overrides []Rule `toml:"overries"` } @@ -405,7 +639,6 @@ func (o *PolicyOptions) UnmarshalTOML(data any) (err error) { return fmt.Errorf("non-table type policy config") } - o.Auto = findFrom(m, "auto", parseBoolFn(), &err) o.Template = findStructFrom[Rule](m, "template", &err) o.Overrides = findStructSliceFrom[Rule](m, "overrides", &err) @@ -423,7 +656,6 @@ func (o *PolicyOptions) Clone() *PolicyOptions { } return &PolicyOptions{ - Auto: ptr.Clone(o.Auto), Template: o.Template.Clone(), Overrides: overrides, } @@ -438,20 +670,10 @@ func (origin *PolicyOptions) Merge(overrides *PolicyOptions) *PolicyOptions { return overrides.Clone() } - overridesCopy := overrides.Clone() - - merged := origin.Clone() - merged.Auto = ptr.CloneOr(overrides.Auto, origin.Auto) - - if overridesCopy.Template != nil { - merged.Template = overridesCopy.Template - } - - if overridesCopy.Overrides != nil { - merged.Overrides = append(merged.Overrides, overridesCopy.Overrides...) + return &PolicyOptions{ + Template: lo.CoalesceOrEmpty(overrides.Template.Clone(), origin.Template.Clone()), + Overrides: lo.CoalesceSliceOrEmpty(overrides.Overrides, origin.Overrides), } - - return merged } type AddrMatch struct { @@ -467,12 +689,12 @@ func (a *AddrMatch) UnmarshalTOML(data any) (err error) { } if p := findFrom(v, "cidr", parseStringFn(checkCIDR), &err); isOk(p, err) { - a.CIDR = ptr.FromValue(MustParseCIDR(*p)) + a.CIDR = lo.ToPtr(MustParseCIDR(*p)) } if p := findFrom(v, "port", parseStringFn(checkPortRange), &err); isOk(p, err) { portFrom, portTo := MustParsePortRange(*p) - a.PortFrom, a.PortTo = ptr.FromValue(portFrom), ptr.FromValue(portTo) + a.PortFrom, a.PortTo = lo.ToPtr(portFrom), lo.ToPtr(portTo) } return err @@ -482,10 +704,19 @@ func (a *AddrMatch) Clone() *AddrMatch { if a == nil { return nil } + + var cidr *net.IPNet + if a.CIDR != nil { + cidr = &net.IPNet{ + IP: slices.Clone(a.CIDR.IP), + Mask: slices.Clone(a.CIDR.Mask), + } + } + return &AddrMatch{ - CIDR: ptr.Clone(a.CIDR), - PortFrom: ptr.Clone(a.PortFrom), - PortTo: ptr.Clone(a.PortTo), + CIDR: cidr, + PortFrom: clonePrimitive(a.PortFrom), + PortTo: clonePrimitive(a.PortTo), } } @@ -521,7 +752,7 @@ func (a *MatchAttrs) Clone() *MatchAttrs { } return &MatchAttrs{ - Domains: ptr.CloneSlice(a.Domains), + Domains: lo.CoalesceSliceOrEmpty(a.Domains), Addrs: addrs, } } @@ -533,6 +764,8 @@ type Rule struct { Match *MatchAttrs `toml:"match" json:"mt,omitempty"` DNS *DNSOptions `toml:"dns-override" json:"D,omitempty"` HTTPS *HTTPSOptions `toml:"https-override" json:"H,omitempty"` + UDP *UDPOptions `toml:"udp-override" json:"U,omitempty"` + Conn *ConnOptions `toml:"conn-override" json:"C,omitempty"` } func (r *Rule) UnmarshalTOML(data any) (err error) { @@ -547,6 +780,8 @@ func (r *Rule) UnmarshalTOML(data any) (err error) { r.Match = findStructFrom[MatchAttrs](m, "match", &err) r.DNS = findStructFrom[DNSOptions](m, "dns", &err) r.HTTPS = findStructFrom[HTTPSOptions](m, "https", &err) + r.UDP = findStructFrom[UDPOptions](m, "udp", &err) + r.Conn = findStructFrom[ConnOptions](m, "connection", &err) // if err == nil { // err = checkRule(*r) @@ -560,11 +795,36 @@ func (r *Rule) Clone() *Rule { return nil } return &Rule{ - Name: ptr.Clone(r.Name), - Priority: ptr.Clone(r.Priority), - Block: ptr.Clone(r.Block), - Match: ptr.Clone(r.Match), + Name: clonePrimitive(r.Name), + Priority: clonePrimitive(r.Priority), + Block: clonePrimitive(r.Block), + Match: r.Match.Clone(), DNS: r.DNS.Clone(), HTTPS: r.HTTPS.Clone(), + UDP: r.UDP.Clone(), + Conn: r.Conn.Clone(), } } + +func (r *Rule) JSON() []byte { + data := map[string]any{ + "name": r.Name, + "priority": r.Priority, + } + + if r.Match == nil { + data["match"] = nil + } else { + m := map[string]any{} + if r.Match.Addrs != nil { + m["addr"] = fmt.Sprintf("%v items", len(r.Match.Addrs)) + } + if r.Match.Domains != nil { + m["domain"] = fmt.Sprintf("%v items", len(r.Match.Domains)) + } + data["match"] = m + } + + bytes, _ := json.Marshal(data) + return bytes +} diff --git a/internal/config/types_test.go b/internal/config/types_test.go index 5e2a1498..0a021a0f 100644 --- a/internal/config/types_test.go +++ b/internal/config/types_test.go @@ -5,34 +5,37 @@ import ( "testing" "time" + "github.com/BurntSushi/toml" "github.com/rs/zerolog" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" ) // ┌─────────────────┐ // │ GENERAL OPTIONS │ // └─────────────────┘ -func TestGeneralOptions_UnmarshalTOML(t *testing.T) { +func TestAppOptions_UnmarshalTOML(t *testing.T) { tcs := []struct { name string input any wantErr bool - assert func(t *testing.T, o GeneralOptions) + assert func(t *testing.T, o AppOptions) }{ { name: "valid general options", input: map[string]any{ - "log-level": "debug", - "silent": true, - "system-proxy": true, + "log-level": "debug", + "silent": true, + "auto-configure-network": true, + "mode": "socks5", }, wantErr: false, - assert: func(t *testing.T, o GeneralOptions) { + assert: func(t *testing.T, o AppOptions) { assert.Equal(t, zerolog.DebugLevel, *o.LogLevel) assert.True(t, *o.Silent) - assert.True(t, *o.SetSystemProxy) + assert.True(t, *o.AutoConfigureNetwork) + assert.Equal(t, AppModeSOCKS5, *o.Mode) }, }, { @@ -44,7 +47,7 @@ func TestGeneralOptions_UnmarshalTOML(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - var o GeneralOptions + var o AppOptions err := o.UnmarshalTOML(tc.input) if tc.wantErr { assert.Error(t, err) @@ -58,26 +61,26 @@ func TestGeneralOptions_UnmarshalTOML(t *testing.T) { } } -func TestGeneralOptions_Clone(t *testing.T) { +func TestAppOptions_Clone(t *testing.T) { tcs := []struct { name string - input *GeneralOptions - assert func(t *testing.T, input *GeneralOptions, output *GeneralOptions) + input *AppOptions + assert func(t *testing.T, input *AppOptions, output *AppOptions) }{ { name: "nil receiver", input: nil, - assert: func(t *testing.T, input *GeneralOptions, output *GeneralOptions) { + assert: func(t *testing.T, input *AppOptions, output *AppOptions) { assert.Nil(t, output) }, }, { name: "non-nil receiver", - input: &GeneralOptions{ - LogLevel: ptr.FromValue(zerolog.DebugLevel), - Silent: ptr.FromValue(true), + input: &AppOptions{ + LogLevel: lo.ToPtr(zerolog.DebugLevel), + Silent: lo.ToPtr(true), }, - assert: func(t *testing.T, input *GeneralOptions, output *GeneralOptions) { + assert: func(t *testing.T, input *AppOptions, output *AppOptions) { assert.NotNil(t, output) assert.Equal(t, zerolog.DebugLevel, *output.LogLevel) assert.True(t, *output.Silent) @@ -94,39 +97,39 @@ func TestGeneralOptions_Clone(t *testing.T) { } } -func TestGeneralOptions_Merge(t *testing.T) { +func TestAppOptions_Merge(t *testing.T) { tcs := []struct { name string - base *GeneralOptions - override *GeneralOptions - assert func(t *testing.T, output *GeneralOptions) + base *AppOptions + override *AppOptions + assert func(t *testing.T, output *AppOptions) }{ { name: "nil receiver", base: nil, - override: &GeneralOptions{Silent: ptr.FromValue(true)}, - assert: func(t *testing.T, output *GeneralOptions) { + override: &AppOptions{Silent: lo.ToPtr(true)}, + assert: func(t *testing.T, output *AppOptions) { assert.True(t, *output.Silent) }, }, { name: "nil override", - base: &GeneralOptions{Silent: ptr.FromValue(false)}, + base: &AppOptions{Silent: lo.ToPtr(false)}, override: nil, - assert: func(t *testing.T, output *GeneralOptions) { + assert: func(t *testing.T, output *AppOptions) { assert.False(t, *output.Silent) }, }, { name: "merge values", - base: &GeneralOptions{ - Silent: ptr.FromValue(false), - LogLevel: ptr.FromValue(zerolog.InfoLevel), + base: &AppOptions{ + Silent: lo.ToPtr(false), + LogLevel: lo.ToPtr(zerolog.InfoLevel), }, - override: &GeneralOptions{ - Silent: ptr.FromValue(true), + override: &AppOptions{ + Silent: lo.ToPtr(true), }, - assert: func(t *testing.T, output *GeneralOptions) { + assert: func(t *testing.T, output *AppOptions) { assert.True(t, *output.Silent) assert.Equal(t, zerolog.InfoLevel, *output.LogLevel) }, @@ -144,25 +147,27 @@ func TestGeneralOptions_Merge(t *testing.T) { // ┌────────────────┐ // │ SERVER OPTIONS │ // └────────────────┘ -func TestServerOptions_UnmarshalTOML(t *testing.T) { +func TestConnOptions_UnmarshalTOML(t *testing.T) { tcs := []struct { name string input any wantErr bool - assert func(t *testing.T, o ServerOptions) + assert func(t *testing.T, o ConnOptions) }{ { name: "valid server options", input: map[string]any{ - "default-ttl": int64(64), - "listen-addr": "127.0.0.1:8080", - "timeout": int64(1000), + "default-fake-ttl": int64(64), + "dns-timeout": int64(1000), + "tcp-timeout": int64(1000), + "udp-idle-timeout": int64(1000), }, wantErr: false, - assert: func(t *testing.T, o ServerOptions) { - assert.Equal(t, uint8(64), *o.DefaultTTL) - assert.Equal(t, "127.0.0.1:8080", o.ListenAddr.String()) - assert.Equal(t, 1000*time.Millisecond, *o.Timeout) + assert: func(t *testing.T, o ConnOptions) { + assert.Equal(t, uint8(64), *o.DefaultFakeTTL) + assert.Equal(t, 1000*time.Millisecond, *o.DNSTimeout) + assert.Equal(t, 1000*time.Millisecond, *o.TCPTimeout) + assert.Equal(t, 1000*time.Millisecond, *o.UDPIdleTimeout) }, }, { @@ -174,7 +179,7 @@ func TestServerOptions_UnmarshalTOML(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - var o ServerOptions + var o ConnOptions err := o.UnmarshalTOML(tc.input) if tc.wantErr { assert.Error(t, err) @@ -188,33 +193,34 @@ func TestServerOptions_UnmarshalTOML(t *testing.T) { } } -func TestServerOptions_Clone(t *testing.T) { +func TestConnOptions_Clone(t *testing.T) { tcs := []struct { name string - input *ServerOptions - assert func(t *testing.T, input *ServerOptions, output *ServerOptions) + input *ConnOptions + assert func(t *testing.T, input *ConnOptions, output *ConnOptions) }{ { name: "nil receiver", input: nil, - assert: func(t *testing.T, input *ServerOptions, output *ServerOptions) { + assert: func(t *testing.T, input *ConnOptions, output *ConnOptions) { assert.Nil(t, output) }, }, { name: "non-nil receiver", - input: &ServerOptions{ - DefaultTTL: ptr.FromValue(uint8(64)), - ListenAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}, + input: &ConnOptions{ + DefaultFakeTTL: lo.ToPtr(uint8(64)), + DNSTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), + TCPTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), + UDPIdleTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), }, - assert: func(t *testing.T, input *ServerOptions, output *ServerOptions) { + assert: func(t *testing.T, input *ConnOptions, output *ConnOptions) { assert.NotNil(t, output) - assert.Equal(t, uint8(64), *output.DefaultTTL) - assert.Equal(t, "127.0.0.1:8080", output.ListenAddr.String()) + assert.Equal(t, uint8(64), *output.DefaultFakeTTL) + assert.Equal(t, 1000*time.Millisecond, *output.DNSTimeout) + assert.Equal(t, 1000*time.Millisecond, *output.TCPTimeout) + assert.Equal(t, 1000*time.Millisecond, *output.UDPIdleTimeout) assert.NotSame(t, input, output) - if output.ListenAddr != nil { - assert.NotSame(t, input.ListenAddr, output.ListenAddr) - } }, }, } @@ -227,41 +233,45 @@ func TestServerOptions_Clone(t *testing.T) { } } -func TestServerOptions_Merge(t *testing.T) { +func TestConnOptions_Merge(t *testing.T) { tcs := []struct { name string - base *ServerOptions - override *ServerOptions - assert func(t *testing.T, output *ServerOptions) + base *ConnOptions + override *ConnOptions + assert func(t *testing.T, output *ConnOptions) }{ { name: "nil receiver", base: nil, - override: &ServerOptions{DefaultTTL: ptr.FromValue(uint8(64))}, - assert: func(t *testing.T, output *ServerOptions) { - assert.Equal(t, uint8(64), *output.DefaultTTL) + override: &ConnOptions{DefaultFakeTTL: lo.ToPtr(uint8(64))}, + assert: func(t *testing.T, output *ConnOptions) { + assert.Equal(t, uint8(64), *output.DefaultFakeTTL) }, }, { name: "nil override", - base: &ServerOptions{DefaultTTL: ptr.FromValue(uint8(128))}, + base: &ConnOptions{DefaultFakeTTL: lo.ToPtr(uint8(128))}, override: nil, - assert: func(t *testing.T, output *ServerOptions) { - assert.Equal(t, uint8(128), *output.DefaultTTL) + assert: func(t *testing.T, output *ConnOptions) { + assert.Equal(t, uint8(128), *output.DefaultFakeTTL) }, }, { name: "merge values", - base: &ServerOptions{ - DefaultTTL: ptr.FromValue(uint8(64)), - ListenAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}, + base: &ConnOptions{ + DefaultFakeTTL: lo.ToPtr(uint8(64)), + DNSTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), + TCPTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), + UDPIdleTimeout: lo.ToPtr(time.Duration(1000) * time.Millisecond), }, - override: &ServerOptions{ - DefaultTTL: ptr.FromValue(uint8(128)), + override: &ConnOptions{ + DefaultFakeTTL: lo.ToPtr(uint8(128)), }, - assert: func(t *testing.T, output *ServerOptions) { - assert.Equal(t, uint8(128), *output.DefaultTTL) - assert.Equal(t, "127.0.0.1:8080", output.ListenAddr.String()) + assert: func(t *testing.T, output *ConnOptions) { + assert.Equal(t, uint8(128), *output.DefaultFakeTTL) + assert.Equal(t, 1000*time.Millisecond, *output.DNSTimeout) + assert.Equal(t, 1000*time.Millisecond, *output.TCPTimeout) + assert.Equal(t, 1000*time.Millisecond, *output.UDPIdleTimeout) }, }, } @@ -341,7 +351,7 @@ func TestDNSOptions_Clone(t *testing.T) { { name: "non-nil receiver", input: &DNSOptions{ - Mode: ptr.FromValue(DNSModeHTTPS), + Mode: lo.ToPtr(DNSModeHTTPS), Addr: &net.TCPAddr{IP: net.ParseIP("1.1.1.1"), Port: 53}, }, assert: func(t *testing.T, input *DNSOptions, output *DNSOptions) { @@ -373,14 +383,14 @@ func TestDNSOptions_Merge(t *testing.T) { { name: "nil receiver", base: nil, - override: &DNSOptions{Mode: ptr.FromValue(DNSModeHTTPS)}, + override: &DNSOptions{Mode: lo.ToPtr(DNSModeHTTPS)}, assert: func(t *testing.T, output *DNSOptions) { assert.Equal(t, DNSModeHTTPS, *output.Mode) }, }, { name: "nil override", - base: &DNSOptions{Mode: ptr.FromValue(DNSModeUDP)}, + base: &DNSOptions{Mode: lo.ToPtr(DNSModeUDP)}, override: nil, assert: func(t *testing.T, output *DNSOptions) { assert.Equal(t, DNSModeUDP, *output.Mode) @@ -389,12 +399,12 @@ func TestDNSOptions_Merge(t *testing.T) { { name: "merge values", base: &DNSOptions{ - Mode: ptr.FromValue(DNSModeUDP), + Mode: lo.ToPtr(DNSModeUDP), Addr: &net.TCPAddr{IP: net.ParseIP("8.8.8.8"), Port: 53}, }, override: &DNSOptions{ - Mode: ptr.FromValue(DNSModeUDP), - HTTPSURL: ptr.FromValue("https://dns.google/test"), + Mode: lo.ToPtr(DNSModeUDP), + HTTPSURL: lo.ToPtr("https://dns.google/test"), }, assert: func(t *testing.T, output *DNSOptions) { assert.Equal(t, DNSModeUDP, *output.Mode) @@ -481,7 +491,7 @@ func TestHTTPSOptions_Clone(t *testing.T) { { name: "non-nil receiver", input: &HTTPSOptions{ - Disorder: ptr.FromValue(true), + Disorder: lo.ToPtr(true), FakePacket: proto.NewFakeTLSMessage([]byte{0x01}), }, assert: func(t *testing.T, input *HTTPSOptions, output *HTTPSOptions) { @@ -514,14 +524,14 @@ func TestHTTPSOptions_Merge(t *testing.T) { { name: "nil receiver", base: nil, - override: &HTTPSOptions{Disorder: ptr.FromValue(true)}, + override: &HTTPSOptions{Disorder: lo.ToPtr(true)}, assert: func(t *testing.T, output *HTTPSOptions) { assert.True(t, *output.Disorder) }, }, { name: "nil override", - base: &HTTPSOptions{Disorder: ptr.FromValue(false)}, + base: &HTTPSOptions{Disorder: lo.ToPtr(false)}, override: nil, assert: func(t *testing.T, output *HTTPSOptions) { assert.False(t, *output.Disorder) @@ -530,12 +540,12 @@ func TestHTTPSOptions_Merge(t *testing.T) { { name: "merge values", base: &HTTPSOptions{ - Disorder: ptr.FromValue(false), - ChunkSize: ptr.FromValue(uint8(10)), + Disorder: lo.ToPtr(false), + ChunkSize: lo.ToPtr(uint8(10)), FakePacket: proto.NewFakeTLSMessage([]byte{0x01}), }, override: &HTTPSOptions{ - Disorder: ptr.FromValue(true), + Disorder: lo.ToPtr(true), FakePacket: proto.NewFakeTLSMessage([]byte{0x02}), }, assert: func(t *testing.T, output *HTTPSOptions) { @@ -567,7 +577,6 @@ func TestPolicyOptions_UnmarshalTOML(t *testing.T) { { name: "valid policy options", input: map[string]any{ - "auto": true, "overrides": []map[string]any{ { "name": "rule1", @@ -579,7 +588,6 @@ func TestPolicyOptions_UnmarshalTOML(t *testing.T) { }, wantErr: false, assert: func(t *testing.T, o PolicyOptions) { - assert.True(t, *o.Auto) assert.Len(t, o.Overrides, 1) assert.Equal(t, "rule1", *o.Overrides[0].Name) }, @@ -623,17 +631,15 @@ func TestPolicyOptions_Clone(t *testing.T) { { name: "non-nil receiver", input: &PolicyOptions{ - Auto: ptr.FromValue(true), Overrides: []Rule{ { - Name: ptr.FromValue("rule1"), + Name: lo.ToPtr("rule1"), Match: &MatchAttrs{Domains: []string{"example.com"}}, }, }, }, assert: func(t *testing.T, input *PolicyOptions, output *PolicyOptions) { assert.NotNil(t, output) - assert.True(t, *output.Auto) assert.Len(t, output.Overrides, 1) assert.NotSame(t, input, output) // Deep copy check for slice @@ -660,34 +666,30 @@ func TestPolicyOptions_Merge(t *testing.T) { { name: "nil receiver", base: nil, - override: &PolicyOptions{Auto: ptr.FromValue(true)}, + override: &PolicyOptions{Overrides: []Rule{{Name: lo.ToPtr("rule1")}}}, assert: func(t *testing.T, output *PolicyOptions) { - assert.True(t, *output.Auto) + assert.Len(t, output.Overrides, 1) }, }, { name: "nil override", - base: &PolicyOptions{Auto: ptr.FromValue(false)}, + base: &PolicyOptions{Overrides: []Rule{{Name: lo.ToPtr("rule1")}}}, override: nil, assert: func(t *testing.T, output *PolicyOptions) { - assert.False(t, *output.Auto) + assert.Len(t, output.Overrides, 1) }, }, { name: "merge values", base: &PolicyOptions{ - Auto: ptr.FromValue(false), - Overrides: []Rule{{Name: ptr.FromValue("rule1")}}, + Overrides: []Rule{{Name: lo.ToPtr("rule1")}}, }, override: &PolicyOptions{ - Auto: ptr.FromValue(true), - Overrides: []Rule{{Name: ptr.FromValue("rule2")}}, + Overrides: []Rule{{Name: lo.ToPtr("rule2")}}, }, assert: func(t *testing.T, output *PolicyOptions) { - assert.True(t, *output.Auto) - assert.Len(t, output.Overrides, 2) - assert.Equal(t, "rule1", *output.Overrides[0].Name) - assert.Equal(t, "rule2", *output.Overrides[1].Name) + assert.Len(t, output.Overrides, 1) + assert.Equal(t, "rule2", *output.Overrides[0].Name) }, }, } @@ -846,6 +848,20 @@ func TestRule_UnmarshalTOML(t *testing.T) { assert.True(t, *r.Block) }, }, + { + name: "valid rule with connection options", + input: map[string]any{ + "name": "rule2", + "connection": map[string]any{ + "tcp-timeout": int64(500), + }, + }, + wantErr: false, + assert: func(t *testing.T, r Rule) { + assert.Equal(t, "rule2", *r.Name) + assert.Equal(t, time.Duration(500*time.Millisecond), *r.Conn.TCPTimeout) + }, + }, { name: "invalid type", input: "invalid", @@ -885,7 +901,7 @@ func TestRule_Clone(t *testing.T) { { name: "non-nil receiver", input: &Rule{ - Name: ptr.FromValue("rule1"), + Name: lo.ToPtr("rule1"), Match: &MatchAttrs{Domains: []string{"example.com"}}, }, assert: func(t *testing.T, input *Rule, output *Rule) { @@ -903,3 +919,92 @@ func TestRule_Clone(t *testing.T) { }) } } + +func TestSegmentPlan_UnmarshalTOML(t *testing.T) { + t.Run("valid segment head", func(t *testing.T) { + input := ` +from = "head" +at = 10 +lazy = true +noise = 1 +` + var s SegmentPlan + err := toml.Unmarshal([]byte(input), &s) + assert.NoError(t, err) + assert.Equal(t, SegmentFromHead, s.From) + assert.Equal(t, 10, s.At) + assert.True(t, s.Lazy) + assert.Equal(t, 1, s.Noise) + }) + + t.Run("valid segment sni", func(t *testing.T) { + input := ` +from = "sni" +at = -5 +` + var s SegmentPlan + err := toml.Unmarshal([]byte(input), &s) + assert.NoError(t, err) + assert.Equal(t, SegmentFromSNI, s.From) + assert.Equal(t, -5, s.At) + }) + + t.Run("missing required field from", func(t *testing.T) { + input := ` +at = 5 +` + var s SegmentPlan + err := toml.Unmarshal([]byte(input), &s) + assert.Error(t, err) + assert.Contains(t, err.Error(), "field 'from' is required") + }) + + t.Run("missing required field at", func(t *testing.T) { + input := ` +from = "head" +` + var s SegmentPlan + err := toml.Unmarshal([]byte(input), &s) + assert.Error(t, err) + assert.Contains(t, err.Error(), "field 'at' is required") + }) + + t.Run("invalid from value", func(t *testing.T) { + input := ` +from = "invalid" +at = 5 +` + var s SegmentPlan + err := toml.Unmarshal([]byte(input), &s) + assert.Error(t, err) + }) +} + +func TestHTTPSOptions_CustomSegmentPlans(t *testing.T) { + t.Run("valid custom config", func(t *testing.T) { + input := ` +split-mode = "custom" +custom-segments = [ + { from = "head", at = 2 }, + { from = "sni", at = 0 } +] +` + var opts HTTPSOptions + err := toml.Unmarshal([]byte(input), &opts) + assert.NoError(t, err) + assert.Equal(t, HTTPSSplitModeCustom, *opts.SplitMode) + assert.Len(t, opts.CustomSegmentPlans, 2) + assert.Equal(t, SegmentFromHead, opts.CustomSegmentPlans[0].From) + assert.Equal(t, 2, opts.CustomSegmentPlans[0].At) + }) + + t.Run("missing custom segments", func(t *testing.T) { + input := ` +split-mode = "custom" +` + var opts HTTPSOptions + err := toml.Unmarshal([]byte(input), &opts) + assert.Error(t, err) + assert.Contains(t, err.Error(), "custom-segments must be provided") + }) +} diff --git a/internal/config/validate.go b/internal/config/validate.go index c45c71a0..43643a0e 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -51,10 +51,12 @@ var ( checkUint8 = int64Range(0, math.MaxUint8) checkUint16 = int64Range(0, math.MaxUint16) checkUint8NonZero = int64Range(1, math.MaxUint8) - checkDNSMode = checkOneOf(availableDNSModes...) - checkDNSQueryType = checkOneOf(availableDNSQueries...) - checkHTTPSSplitMode = checkOneOf(availableHTTPSModes...) - checkLogLevel = checkOneOf(availableLogLevels...) + checkAppMode = checkOneOf(availableAppModeValues...) + checkDNSMode = checkOneOf(availableDNSModeValues...) + checkDNSQueryType = checkOneOf(availableDNSQueryValues...) + checkHTTPSSplitMode = checkOneOf(availableHTTPSModeValues...) + checkLogLevel = checkOneOf(availableLogLevelValues...) + checkSegmentFrom = checkOneOf(availableSegmentFromValues...) ) func checkDomainPattern(v string) error { diff --git a/internal/config/validate_test.go b/internal/config/validate_test.go index 3c9c39dc..fa9aa89d 100644 --- a/internal/config/validate_test.go +++ b/internal/config/validate_test.go @@ -3,8 +3,8 @@ package config import ( "testing" + "github.com/samber/lo" "github.com/stretchr/testify/assert" - "github.com/xvzc/SpoofDPI/internal/ptr" ) func TestCheckDomainPattern(t *testing.T) { @@ -179,9 +179,9 @@ func TestCheckMatchAttrs(t *testing.T) { Domains: []string{"www.google.com"}, Addrs: []AddrMatch{ { - CIDR: ptr.FromValue(MustParseCIDR("192.168.0.1/24")), - PortFrom: ptr.FromValue(uint16(80)), - PortTo: ptr.FromValue(uint16(443)), + CIDR: lo.ToPtr(MustParseCIDR("192.168.0.1/24")), + PortFrom: lo.ToPtr(uint16(80)), + PortTo: lo.ToPtr(uint16(443)), }, }, }, @@ -199,9 +199,9 @@ func TestCheckMatchAttrs(t *testing.T) { input: MatchAttrs{ Addrs: []AddrMatch{ { - CIDR: ptr.FromValue(MustParseCIDR("10.0.0.0/8")), - PortFrom: ptr.FromValue(uint16(0)), - PortTo: ptr.FromValue(uint16(65535)), + CIDR: lo.ToPtr(MustParseCIDR("10.0.0.0/8")), + PortFrom: lo.ToPtr(uint16(0)), + PortTo: lo.ToPtr(uint16(65535)), }, }, }, @@ -217,7 +217,7 @@ func TestCheckMatchAttrs(t *testing.T) { input: MatchAttrs{ Addrs: []AddrMatch{ { - CIDR: ptr.FromValue(MustParseCIDR("10.0.0.0/8")), + CIDR: lo.ToPtr(MustParseCIDR("10.0.0.0/8")), }, }, }, @@ -228,8 +228,8 @@ func TestCheckMatchAttrs(t *testing.T) { input: MatchAttrs{ Addrs: []AddrMatch{ { - PortFrom: ptr.FromValue(uint16(80)), - PortTo: ptr.FromValue(uint16(443)), + PortFrom: lo.ToPtr(uint16(80)), + PortTo: lo.ToPtr(uint16(443)), }, }, }, @@ -262,7 +262,7 @@ func TestCheckRule(t *testing.T) { Domains: []string{"example.com"}, }, DNS: &DNSOptions{ - Mode: ptr.FromValue(DNSModeUDP), + Mode: lo.ToPtr(DNSModeUDP), }, }, wantErr: false, @@ -273,14 +273,14 @@ func TestCheckRule(t *testing.T) { Match: &MatchAttrs{ Addrs: []AddrMatch{ { - CIDR: ptr.FromValue(MustParseCIDR("192.168.1.0/24")), - PortFrom: ptr.FromValue(uint16(80)), - PortTo: ptr.FromValue(uint16(80)), + CIDR: lo.ToPtr(MustParseCIDR("192.168.1.0/24")), + PortFrom: lo.ToPtr(uint16(80)), + PortTo: lo.ToPtr(uint16(80)), }, }, }, HTTPS: &HTTPSOptions{ - Disorder: ptr.FromValue(true), + Disorder: lo.ToPtr(true), }, }, wantErr: false, @@ -305,3 +305,31 @@ func TestCheckRule(t *testing.T) { }) } } + +func TestCheckLogLevel(t *testing.T) { + tcs := []struct { + name string + input string + wantErr bool + }{ + {"valid info", "info", false}, + {"valid debug", "debug", false}, + {"valid warn", "warn", false}, + {"valid error", "error", false}, + {"valid trace", "trace", false}, + {"valid disabled", "disabled", false}, + {"invalid unknown", "unknown", true}, + {"invalid empty", "", true}, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + err := checkLogLevel(tc.input) + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/desync/tls.go b/internal/desync/tls.go index b136cd75..cd037ac5 100644 --- a/internal/desync/tls.go +++ b/internal/desync/tls.go @@ -8,16 +8,17 @@ import ( "net" "github.com/rs/zerolog" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/logging" "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/packet" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" ) -type TLSDesyncerAttrs struct { - DefaultTTL uint8 +type Segment struct { + Packet []byte + Lazy bool } // TLSDesyncer splits the data into chunks and optionally @@ -25,22 +26,19 @@ type TLSDesyncerAttrs struct { type TLSDesyncer struct { writer packet.Writer sniffer packet.Sniffer - attrs *TLSDesyncerAttrs } func NewTLSDesyncer( writer packet.Writer, sniffer packet.Sniffer, - attrs *TLSDesyncerAttrs, ) *TLSDesyncer { return &TLSDesyncer{ writer: writer, sniffer: sniffer, - attrs: attrs, } } -func (d *TLSDesyncer) Send( +func (d *TLSDesyncer) Desync( ctx context.Context, logger zerolog.Logger, conn net.Conn, @@ -49,13 +47,15 @@ func (d *TLSDesyncer) Send( ) (int, error) { logger = logging.WithLocalScope(ctx, logger, "tls_desync") - if ptr.FromPtr(httpsOpts.Skip) { + if lo.FromPtr(httpsOpts.Skip) { logger.Trace().Msg("skip desync for this request") - return d.sendSegments(conn, [][]byte{msg.Raw()}) + return d.sendSegments(conn, logger, []Segment{{Packet: msg.Raw()}}) } - if d.sniffer != nil && d.writer != nil && ptr.FromPtr(httpsOpts.FakeCount) > 0 { - oTTL := d.sniffer.GetOptimalTTL(conn.RemoteAddr().String()) + if d.sniffer != nil && d.writer != nil && lo.FromPtr(httpsOpts.FakeCount) > 0 { + oTTL := d.sniffer.GetOptimalTTL( + netutil.NewIPKey(conn.RemoteAddr().(*net.TCPAddr).IP), + ) n, err := d.sendFakePackets(ctx, logger, conn, oTTL, httpsOpts) if err != nil { logger.Warn().Err(err).Msg("failed to send fake packets") @@ -64,38 +64,16 @@ func (d *TLSDesyncer) Send( } } - segments := split(logger, httpsOpts, msg) + segments := split(logger, msg, httpsOpts) - if ptr.FromPtr(httpsOpts.Disorder) { - return d.sendSegmentsDisorder(conn, logger, segments, httpsOpts) - } - - return d.sendSegments(conn, segments) + return d.sendSegments(conn, logger, segments) } // sendSegments sends the segmented Client Hello sequentially. -func (d *TLSDesyncer) sendSegments(conn net.Conn, segments [][]byte) (int, error) { - total := 0 - for _, chunk := range segments { - n, err := conn.Write(chunk) - total += n - if err != nil { - return total, err - } - } - - return total, nil -} - -// sendSegmentsDisorder sends the segmented Client Hello out of order. -// Since performance is prioritized over strict randomness, -// a single 64-bit pattern is generated and reused cyclically -// for sequences exceeding 64 chunks. -func (d *TLSDesyncer) sendSegmentsDisorder( +func (d *TLSDesyncer) sendSegments( conn net.Conn, logger zerolog.Logger, - segments [][]byte, - opts *config.HTTPSOptions, + segments []Segment, ) (int, error) { var isIPv4 bool if tcpAddr, ok := conn.LocalAddr().(*net.TCPAddr); ok { @@ -110,30 +88,23 @@ func (d *TLSDesyncer) sendSegmentsDisorder( } } - defer setTTLWrap(d.attrs.DefaultTTL) // Restore the default TTL on return + defaultTTL := getDefaultTTL() - disorderBits := genPatternMask() - logger.Debug(). - Str("bits", fmt.Sprintf("%064b", disorderBits)). - Msgf("disorder ready") - curBit := uint64(1) total := 0 for _, chunk := range segments { - if !ttlErrored && disorderBits&curBit == curBit { + if !ttlErrored && chunk.Lazy { setTTLWrap(1) } - n, err := conn.Write(chunk) + n, err := conn.Write(chunk.Packet) + total += n if err != nil { return total, err } - total += n - if !ttlErrored && disorderBits&curBit == curBit { - setTTLWrap(d.attrs.DefaultTTL) + if !ttlErrored && chunk.Lazy { + setTTLWrap(defaultTTL) } - - curBit = bits.RotateLeft64(curBit, 1) } return total, nil @@ -141,12 +112,12 @@ func (d *TLSDesyncer) sendSegmentsDisorder( func split( logger zerolog.Logger, - attrs *config.HTTPSOptions, msg *proto.TLSMessage, -) [][]byte { - mode := *attrs.SplitMode + opts *config.HTTPSOptions, +) []Segment { + mode := *opts.SplitMode raw := msg.Raw() - var chunks [][]byte + var segments []Segment var err error switch mode { case config.HTTPSSplitModeSNI: @@ -155,38 +126,42 @@ func split( if err != nil { break } - chunks, err = splitSNI(raw, start, end) + segments, err = splitSNI(raw, start, end, *opts.Disorder) logger.Trace().Msgf("extracted SNI is '%s'", raw[start:end]) case config.HTTPSSplitModeRandom: mask := genPatternMask() - chunks, err = splitMask(raw, mask) + segments, err = splitMask(raw, mask, *opts.Disorder) case config.HTTPSSplitModeChunk: - chunks, err = splitChunks(raw, int(*attrs.ChunkSize)) + segments, err = splitChunks(raw, int(*opts.ChunkSize), *opts.Disorder) case config.HTTPSSplitModeFirstByte: - chunks, err = splitFirstByte(raw) + segments, err = splitFirstByte(raw, *opts.Disorder) + case config.HTTPSSplitModeCustom: + segments, err = applySegmentPlans(msg, opts.CustomSegmentPlans) case config.HTTPSSplitModeNone: - chunks = [][]byte{raw} + segments = []Segment{{Packet: raw}} default: logger.Debug().Msgf("unsupprted split mode '%s'. proceed without split", mode) - chunks = [][]byte{raw} + segments = []Segment{{Packet: raw}} } logger.Debug(). - Int("len", len(chunks)). + Int("len", len(segments)). Str("mode", mode.String()). Str("kind", msg.Kind()). + Bool("disorder", *opts.Disorder). Msg("segments ready") if err != nil { logger.Debug().Err(err). + Str("kind", msg.Kind()). Msgf("error processing split mode '%s', fallback to 'none'", mode) - chunks = [][]byte{raw} + segments = []Segment{{Packet: raw}} } - return chunks + return segments } -func splitChunks(raw []byte, size int) ([][]byte, error) { +func splitChunks(raw []byte, size int, disorder bool) ([]Segment, error) { lenRaw := len(raw) if lenRaw == 0 { @@ -198,26 +173,31 @@ func splitChunks(raw []byte, size int) ([][]byte, error) { } capacity := (lenRaw + size - 1) / size - chunks := make([][]byte, 0, capacity) + chunks := make([]Segment, 0, capacity) + curDisorder := true for len(raw) > 0 { n := min(len(raw), size) - chunks = append(chunks, raw[:n]) + chunks = append(chunks, Segment{Packet: raw[:n], Lazy: curDisorder && disorder}) raw = raw[n:] + curDisorder = !curDisorder } return chunks, nil } -func splitFirstByte(raw []byte) ([][]byte, error) { +func splitFirstByte(raw []byte, disorder bool) ([]Segment, error) { if len(raw) < 2 { return nil, fmt.Errorf("len(raw) is less than 2") } - return [][]byte{raw[:1], raw[1:]}, nil + return []Segment{ + {Packet: raw[:1], Lazy: disorder && true}, + {Packet: raw[1:], Lazy: false}, + }, nil } -func splitSNI(raw []byte, start, end int) ([][]byte, error) { +func splitSNI(raw []byte, start, end int, disorder bool) ([]Segment, error) { lenRaw := len(raw) if lenRaw == 0 { @@ -232,43 +212,66 @@ func splitSNI(raw []byte, start, end int) ([][]byte, error) { return nil, fmt.Errorf("invalid start, end pos (out of range)") } - segments := make([][]byte, 0, lenRaw) - segments = append(segments, raw[:start]) + curDisorder := true + segments := make([]Segment, 0, lenRaw) + segments = append(segments, Segment{Packet: raw[:start]}) for i := range end - start { - segments = append(segments, []byte{raw[start+i]}) + segments = append(segments, Segment{ + Packet: []byte{raw[start+i]}, + Lazy: curDisorder && disorder, + }) + curDisorder = !curDisorder } - segments = append(segments, raw[end:]) - return append([][]byte(nil), segments...), nil + segments = append(segments, Segment{ + Packet: raw[end:], + Lazy: curDisorder && disorder, + }) + + return segments, nil } -func splitMask(raw []byte, mask uint64) ([][]byte, error) { +func splitMask(raw []byte, mask uint64, disorder bool) ([]Segment, error) { lenRaw := len(raw) if lenRaw == 0 { return nil, fmt.Errorf("empty data") } - segments := make([][]byte, 0, lenRaw) + curDisorder := true + segments := make([]Segment, 0, lenRaw) start := 0 curBit := uint64(1) for i := range lenRaw { if mask&curBit == curBit { if i > start { - segments = append(segments, raw[start:i]) + segments = append(segments, Segment{ + Packet: raw[start:i], + Lazy: curDisorder && disorder, + }) + curDisorder = !curDisorder } - segments = append(segments, raw[i:i+1]) + + segments = append(segments, Segment{ + Packet: raw[i : i+1], + Lazy: curDisorder && disorder, + }) + start = i + 1 + curDisorder = !curDisorder } curBit = bits.RotateLeft64(curBit, 1) } if lenRaw > start { - segments = append(segments, raw[start:lenRaw]) + segments = append(segments, Segment{ + Packet: raw[start:lenRaw], + Lazy: curDisorder && disorder, + }) } - return append([][]byte(nil), segments...), nil + return segments, nil } func (d *TLSDesyncer) String() string { @@ -283,7 +286,7 @@ func (d *TLSDesyncer) sendFakePackets( opts *config.HTTPSOptions, ) (int, error) { var totalSent int - segments := split(logger, opts, opts.FakePacket) + segments := split(logger, opts.FakePacket, opts) for range *(opts.FakeCount) { for _, v := range segments { @@ -292,7 +295,7 @@ func (d *TLSDesyncer) sendFakePackets( conn.LocalAddr(), conn.RemoteAddr(), oTTL, - v, + v.Packet, ) if err != nil { return totalSent, err @@ -305,6 +308,67 @@ func (d *TLSDesyncer) sendFakePackets( return totalSent, nil } +func applySegmentPlans( + msg *proto.TLSMessage, + plans []config.SegmentPlan, +) ([]Segment, error) { + raw := msg.Raw() + sniStart, _, err := msg.ExtractSNIOffset() + if err != nil { + return nil, err + } + + var segments []Segment + prvAt := 0 + + for _, s := range plans { + base := 0 + switch s.From { + case config.SegmentFromSNI: + base = sniStart + case config.SegmentFromHead: + base = 0 + } + + curAt := base + s.At + + if s.Noise > 0 { + // Random integer in [-noise, noise] + noiseVal := rand.IntN(s.Noise*2+1) - s.Noise + curAt += noiseVal + } + + // Boundary checks + if curAt < 0 { + curAt = 0 + } + if curAt > len(raw) { + curAt = len(raw) + } + + // Handle overlap with previous split point + if curAt < prvAt { + curAt = prvAt + } + + segments = append(segments, Segment{ + Packet: raw[prvAt:curAt], + Lazy: s.Lazy, + }) + prvAt = curAt + } + + if prvAt < len(raw) { + segments = append(segments, Segment{Packet: raw[prvAt:]}) + } + + return segments, nil +} + +func getDefaultTTL() uint8 { + return 64 +} + // --- Helper Functions (Low-level Syscall) --- // genPatternMask generates a pseudo-random 64-bit mask used for determining diff --git a/internal/desync/tls_test.go b/internal/desync/tls_test.go index 0311e79f..83cbb1cb 100644 --- a/internal/desync/tls_test.go +++ b/internal/desync/tls_test.go @@ -1,359 +1,189 @@ package desync import ( - "fmt" - "slices" "testing" - "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" ) -func TestSplit(t *testing.T) { - logger := zerolog.Nop() +func TestApplySegmentPlans(t *testing.T) { + // Using the FakeClientHello from config/types.go - tcs := []struct { - name string - opts *config.HTTPSOptions - msg *proto.TLSMessage - assert func(t *testing.T, chunks [][]byte) - }{ - { - name: "none", - opts: &config.HTTPSOptions{SplitMode: ptr.FromValue(config.HTTPSSplitModeNone)}, - msg: proto.NewFakeTLSMessage([]byte("12345")), - assert: func(t *testing.T, chunks [][]byte) { - assert.Len(t, chunks, 1) - assert.Equal(t, []byte("12345"), chunks[0]) - }, - }, - { - name: "chunk size 3", - msg: proto.NewFakeTLSMessage([]byte("1234567890")), - opts: &config.HTTPSOptions{ - SplitMode: ptr.FromValue(config.HTTPSSplitModeChunk), - ChunkSize: ptr.FromValue(uint8(3)), - }, - assert: func(t *testing.T, chunks [][]byte) { - assert.Len(t, chunks, 4) - assert.Equal(t, []byte("123"), chunks[0]) - assert.Equal(t, []byte("0"), chunks[3]) + fakeRaw := []byte(config.FakeClientHello) + + msg := proto.NewFakeTLSMessage(fakeRaw) + + // We know where SNI is in FakeClientHello. + // Let's verify SNI offset first to write correct assertions. + sniStart, _, _ := msg.ExtractSNIOffset() + t.Run("cut head", func(t *testing.T) { + plans := []config.SegmentPlan{ + { + From: config.SegmentFromHead, + At: 5, }, - }, - { - name: "chunk size 0 (fallback)", - msg: proto.NewFakeTLSMessage([]byte("1234567890")), - opts: &config.HTTPSOptions{ - SplitMode: ptr.FromValue(config.HTTPSSplitModeChunk), - ChunkSize: ptr.FromValue(uint8(0)), + } + + chunks, err := applySegmentPlans(msg, plans) + assert.NoError(t, err) + assert.Len(t, chunks, 2) + assert.Equal(t, fakeRaw[:5], chunks[0].Packet) + assert.Equal(t, fakeRaw[5:], chunks[1].Packet) + assert.False(t, chunks[0].Lazy) + assert.False(t, chunks[1].Lazy) + }) + + t.Run("cut head lazy", func(t *testing.T) { + plans := []config.SegmentPlan{ + { + From: config.SegmentFromHead, + At: 5, + Lazy: true, }, - assert: func(t *testing.T, chunks [][]byte) { - assert.Len(t, chunks, 1) - assert.Equal(t, []byte("1234567890"), chunks[0]) + } + + chunks, err := applySegmentPlans(msg, plans) + assert.NoError(t, err) + assert.Len(t, chunks, 2) + assert.Equal(t, fakeRaw[:5], chunks[0].Packet) + assert.True(t, chunks[0].Lazy) + assert.Equal(t, fakeRaw[5:], chunks[1].Packet) + assert.False(t, chunks[1].Lazy) // Remainder is not lazy by default + }) + + t.Run("cut head multiple", func(t *testing.T) { + plans := []config.SegmentPlan{ + { + From: config.SegmentFromHead, + At: 5, }, - }, - { - name: "first-byte", - msg: proto.NewFakeTLSMessage([]byte("1234567890")), - opts: &config.HTTPSOptions{ - SplitMode: ptr.FromValue(config.HTTPSSplitModeFirstByte), + { + From: config.SegmentFromHead, + At: 10, }, - assert: func(t *testing.T, chunks [][]byte) { - assert.Len(t, chunks, 2) - assert.Equal(t, []byte("1"), chunks[0]) - assert.Equal(t, []byte("234567890"), chunks[1]) + } + + chunks, err := applySegmentPlans(msg, plans) + assert.NoError(t, err) + assert.Len(t, chunks, 3) + assert.Equal(t, fakeRaw[:5], chunks[0].Packet) + assert.Equal(t, fakeRaw[5:10], chunks[1].Packet) + assert.Equal(t, fakeRaw[10:], chunks[2].Packet) + }) + + t.Run("cut sni", func(t *testing.T) { + if sniStart == 0 { + t.Skip("SNI not found in fake packet") + } + + // Split at SNI start + plans := []config.SegmentPlan{ + { + From: config.SegmentFromSNI, + At: 0, }, - }, - { - name: "first-byte (fallback)", - msg: proto.NewFakeTLSMessage([]byte("1")), - opts: &config.HTTPSOptions{ - SplitMode: ptr.FromValue(config.HTTPSSplitModeFirstByte), + } + + chunks, err := applySegmentPlans(msg, plans) + assert.NoError(t, err) + + // Should be [0...sniStart], [sniStart...] + assert.Len(t, chunks, 2) + assert.Equal(t, fakeRaw[:sniStart], chunks[0].Packet) + assert.Equal(t, fakeRaw[sniStart:], chunks[1].Packet) + }) + + t.Run("cut sni offset", func(t *testing.T) { + if sniStart == 0 { + t.Skip("SNI not found in fake packet") + } + + offset := 5 + target := sniStart + offset + plans := []config.SegmentPlan{ + { + From: config.SegmentFromSNI, + At: offset, }, - assert: func(t *testing.T, chunks [][]byte) { - assert.Len(t, chunks, 1) - assert.Equal(t, []byte("1"), chunks[0]) + } + + chunks, err := applySegmentPlans(msg, plans) + assert.NoError(t, err) + assert.Len(t, chunks, 2) + assert.Equal(t, fakeRaw[:target], chunks[0].Packet) + assert.Equal(t, fakeRaw[target:], chunks[1].Packet) + }) + + t.Run("cut mixed head and sni", func(t *testing.T) { + if sniStart == 0 { + t.Skip("SNI not found in fake packet") + } + + // Split at 5 (head) and then at SNI start + plans := []config.SegmentPlan{ + { + From: config.SegmentFromHead, + At: 5, }, - }, - { - name: "valid sni", - msg: proto.NewFakeTLSMessage([]byte(config.FakeClientHello)), - opts: &config.HTTPSOptions{SplitMode: ptr.FromValue(config.HTTPSSplitModeSNI)}, - assert: func(t *testing.T, chunks [][]byte) { - assert.GreaterOrEqual(t, len(chunks), 1) - assert.Equal(t, string("www.w3.org"), string(slices.Concat(chunks[1:11]...))) + { + From: config.SegmentFromSNI, + At: 0, }, - }, - { - name: "sni (fallback)", - msg: proto.NewFakeTLSMessage([]byte("1234567890")), - opts: &config.HTTPSOptions{SplitMode: ptr.FromValue(config.HTTPSSplitModeSNI)}, - assert: func(t *testing.T, chunks [][]byte) { - // Fallback to no split on error - assert.Len(t, chunks, 1) - assert.Equal(t, []byte("1234567890"), chunks[0]) + } + + chunks, err := applySegmentPlans(msg, plans) + assert.NoError(t, err) + assert.Len(t, chunks, 3) + assert.Equal(t, fakeRaw[:5], chunks[0].Packet) + assert.Equal(t, fakeRaw[5:sniStart], chunks[1].Packet) + assert.Equal(t, fakeRaw[sniStart:], chunks[2].Packet) + }) + + t.Run("overlap ignored", func(t *testing.T) { + // Split at 10, then try to split at 5 (should be ignored/empty for that segment) + plans := []config.SegmentPlan{ + { + From: config.SegmentFromHead, + At: 10, }, - }, - { - name: "random", - msg: proto.NewFakeTLSMessage([]byte(config.FakeClientHello)), - opts: &config.HTTPSOptions{ - SplitMode: ptr.FromValue(config.HTTPSSplitModeRandom), + { + From: config.SegmentFromHead, + At: 5, }, - assert: func(t *testing.T, chunks [][]byte) { - assert.GreaterOrEqual(t, len(chunks), 1) - var joined []byte - for _, c := range chunks { - joined = append(joined, c...) - } - assert.Equal(t, []byte(config.FakeClientHello), joined) + } + + chunks, err := applySegmentPlans(msg, plans) + assert.NoError(t, err) + // chunks[0]: 0-10 + // chunks[1]: 10-10 (empty) + // chunks[2]: 10-end + assert.Len(t, chunks, 3) + assert.Equal(t, fakeRaw[:10], chunks[0].Packet) + assert.Empty(t, chunks[1].Packet) + assert.Equal(t, fakeRaw[10:], chunks[2].Packet) + }) + + t.Run("noise", func(t *testing.T) { + // We can't strictly test random values, but we can check if it runs without panic + + plans := []config.SegmentPlan{ + { + From: config.SegmentFromHead, + At: 10, + Noise: 5, }, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - chunks := split(logger, tc.opts, tc.msg) - tc.assert(t, chunks) - }) - } -} - -func TestSplitChunks(t *testing.T) { - tcs := []struct { - name string - raw []byte - size int - wantErr bool - expect [][]byte - }{ - { - name: "size 2", - raw: []byte("12345"), - size: 2, - wantErr: false, - expect: [][]byte{[]byte("12"), []byte("34"), []byte("5")}, - }, - { - name: "size larger than len", - raw: []byte("123"), - size: 5, - wantErr: false, - expect: [][]byte{[]byte("123")}, - }, - { - name: "size 0", - raw: []byte("12345"), - size: 0, - wantErr: true, - }, - { - name: "len(raw) is 0", - raw: []byte(""), - size: 3, - wantErr: true, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - chunks, err := splitChunks(tc.raw, tc.size) - if tc.wantErr { - assert.Error(t, err) - return - } - - assert.NoError(t, err) - assert.Equal(t, tc.expect, chunks) - }) - } -} - -func TestSplitFirstByte(t *testing.T) { - tcs := []struct { - name string - raw []byte - wantErr bool - expect [][]byte - }{ - { - name: "size 2", - raw: []byte("12"), - wantErr: false, - expect: [][]byte{[]byte("1"), []byte("2")}, - }, - { - name: "size 3", - raw: []byte("123"), - wantErr: false, - expect: [][]byte{[]byte("1"), []byte("23")}, - }, - { - name: "size 10", - raw: []byte("1234567890"), - wantErr: false, - expect: [][]byte{[]byte("1"), []byte("234567890")}, - }, - { - name: "len(data) is 0", - raw: []byte(""), - wantErr: true, - }, - { - name: "len(data) is 1", - raw: []byte(""), - wantErr: true, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - chunks, err := splitFirstByte(tc.raw) - if tc.wantErr { - assert.Error(t, err) - return - } - - assert.NoError(t, err) - assert.Equal(t, chunks, tc.expect) - }) - } -} - -func TestSplitSNI(t *testing.T) { - tcs := []struct { - name string - raw []byte - start int - end int - wantErr bool - expect [][]byte - }{ - { - name: "size 0", - raw: []byte("PREFIX_SNI_SUFFIX"), - start: 7, - end: 10, - wantErr: false, - expect: [][]byte{ - []byte("PREFIX_"), - []byte("S"), - []byte("N"), - []byte("I"), - []byte("_SUFFIX"), - }, - }, - { - name: "start out of range (start > len)", - raw: []byte("1"), - start: 3, - end: 3, - wantErr: true, - expect: [][]byte{[]byte("1")}, - }, - { - name: "start out of range (start < 0)", - raw: []byte("1"), - start: -1, - end: 5, - wantErr: true, - expect: [][]byte{[]byte("1")}, - }, - { - name: "end out of range (end > len)", - raw: []byte("1"), - start: 0, - end: 5, - wantErr: true, - expect: [][]byte{[]byte("1")}, - }, - { - name: "end out of range (end < 0)", - raw: []byte("1"), - start: -1, - end: -1, - wantErr: true, - expect: [][]byte{[]byte("1")}, - }, - { - name: "invalid start, end pos (start > end)", - raw: []byte("12345"), - start: 4, - end: 3, - wantErr: true, - expect: [][]byte{[]byte("1")}, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - chunks, err := splitSNI(tc.raw, tc.start, tc.end) - if tc.wantErr { - assert.Error(t, err) - return - } - - assert.NoError(t, err) - assert.Equal(t, chunks, tc.expect) - // Debug print if fail - if !assert.Equal(t, tc.expect, chunks) { - for i, c := range chunks { - fmt.Printf("chunk[%d]: %s\n", i, c) - } - } - }) - } -} - -func TestSplitMask(t *testing.T) { - tcs := []struct { - name string - raw []byte - mask uint64 - wantErr bool - expect [][]byte - }{ - { - name: "mask with some bits set (137 = 10001001)", - raw: []byte("12345678"), - mask: 137, // - wantErr: false, - expect: [][]byte{ - []byte("1"), - []byte("23"), - []byte("4"), - []byte("567"), - []byte("8"), - }, - }, - { - name: "mask 0 (no split)", - raw: []byte("123"), - mask: 0, - wantErr: false, - expect: [][]byte{[]byte("123")}, - }, - { - name: "len(data) is 0", - raw: []byte(""), - mask: 123, - wantErr: true, - }, - } - - for _, tc := range tcs { - t.Run(tc.name, func(t *testing.T) { - chunks, err := splitMask(tc.raw, tc.mask) - if tc.wantErr { - assert.Error(t, err) - return - } + } + for i := 0; i < 50; i++ { + chunks, err := applySegmentPlans(msg, plans) assert.NoError(t, err) - assert.Equal(t, chunks, tc.expect) - }) - } + assert.Len(t, chunks, 2) + // Split point should be between 10-5=5 and 10+5=15 + splitLen := len(chunks[0].Packet) + assert.GreaterOrEqual(t, splitLen, 5) + assert.LessOrEqual(t, splitLen, 15) + } + }) } diff --git a/internal/desync/udp.go b/internal/desync/udp.go new file mode 100644 index 00000000..93891369 --- /dev/null +++ b/internal/desync/udp.go @@ -0,0 +1,73 @@ +package desync + +import ( + "context" + "net" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/packet" +) + +type UDPDesyncer struct { + logger zerolog.Logger + writer packet.Writer + sniffer packet.Sniffer +} + +func NewUDPDesyncer( + logger zerolog.Logger, + writer packet.Writer, + sniffer packet.Sniffer, +) *UDPDesyncer { + return &UDPDesyncer{ + logger: logger, + writer: writer, + sniffer: sniffer, + } +} + +func (d *UDPDesyncer) Desync( + ctx context.Context, + lConn net.Conn, + rConn net.Conn, + opts *config.UDPOptions, +) (int, error) { + logger := logging.WithLocalScope(ctx, d.logger, "udp_desync") + + if d.sniffer == nil || d.writer == nil || opts == nil || + opts.FakeCount == nil || *opts.FakeCount <= 0 { + return 0, nil + } + + dstIP := rConn.RemoteAddr().(*net.UDPAddr).IP + oTTL := d.sniffer.GetOptimalTTL(netutil.NewIPKey(dstIP)) + + var totalSent int + for range *opts.FakeCount { + n, err := d.writer.WriteCraftedPacket( + ctx, + lConn.LocalAddr(), // Spoofing source: original local address (TUN) + rConn.RemoteAddr(), + oTTL, + opts.FakePacket, + ) + if err != nil { + logger.Warn().Err(err).Msg("failed to send fake packet") + continue + } + totalSent += n + } + + if totalSent > 0 { + logger.Debug(). + Int("count", *opts.FakeCount). + Int("bytes", totalSent). + Uint8("ttl", oTTL). + Msg("sent fake packets") + } + + return totalSent, nil +} diff --git a/internal/dns/addrselect/addrselect.go b/internal/dns/addrselect/addrselect.go index f81efbc1..702ab13c 100644 --- a/internal/dns/addrselect/addrselect.go +++ b/internal/dns/addrselect/addrselect.go @@ -12,26 +12,26 @@ import ( // Minimal RFC 6724 address selection. -func SortByRFC6724(addrs []net.IPAddr) { - if len(addrs) < 2 { +func SortByRFC6724(ips []net.IP) { + if len(ips) < 2 { return } - sortByRFC6724withSrcs(addrs, srcAddrs(addrs)) + sortByRFC6724withSrcs(ips, srcAddrs(ips)) } -func sortByRFC6724withSrcs(addrs []net.IPAddr, srcs []netip.Addr) { - if len(addrs) != len(srcs) { +func sortByRFC6724withSrcs(ips []net.IP, srcs []netip.Addr) { + if len(ips) != len(srcs) { panic("internal error") } - addrAttr := make([]ipAttr, len(addrs)) + addrAttr := make([]ipAttr, len(ips)) srcAttr := make([]ipAttr, len(srcs)) - for i, v := range addrs { - addrAttrIP, _ := netip.AddrFromSlice(v.IP) + for i, v := range ips { + addrAttrIP, _ := netip.AddrFromSlice(v) addrAttr[i] = ipAttrOf(addrAttrIP) srcAttr[i] = ipAttrOf(srcs[i]) } sort.Stable(&byRFC6724{ - addrs: addrs, + addrs: ips, addrAttr: addrAttr, srcs: srcs, srcAttr: srcAttr, @@ -41,12 +41,12 @@ func sortByRFC6724withSrcs(addrs []net.IPAddr, srcs []netip.Addr) { // srcAddrs tries to UDP-connect to each address to see if it has a // route. This does not send any packets. The destination port number // is irrelevant. -func srcAddrs(addrs []net.IPAddr) []netip.Addr { - srcs := make([]netip.Addr, len(addrs)) +func srcAddrs(ips []net.IP) []netip.Addr { + srcs := make([]netip.Addr, len(ips)) dst := net.UDPAddr{Port: 9} - for i := range addrs { - dst.IP = addrs[i].IP - dst.Zone = addrs[i].Zone + for i := range ips { + dst.IP = ips[i] + // dst.Zone = ips[i].Zone // Zone is not easily available in net.IP, skipping for now c, err := net.DialUDP("udp", nil, &dst) if err == nil { if src, ok := c.LocalAddr().(*net.UDPAddr); ok { @@ -79,7 +79,7 @@ func ipAttrOf(ip netip.Addr) ipAttr { } type byRFC6724 struct { - addrs []net.IPAddr // Addresses to sort. + addrs []net.IP // Addresses to sort. addrAttr []ipAttr srcs []netip.Addr // Or not a valid addr if unreachable. srcAttr []ipAttr @@ -99,8 +99,8 @@ func (s *byRFC6724) Swap(i, j int) { // // The algorithm and variable names are from RFC 6724 section 6. func (s *byRFC6724) Less(i, j int) bool { - DA := s.addrs[i].IP - DB := s.addrs[j].IP + DA := s.addrs[i] + DB := s.addrs[j] SourceDA := s.srcs[i] SourceDB := s.srcs[j] attrDA := &s.addrAttr[i] diff --git a/internal/dns/cache.go b/internal/dns/cache.go index 7185330c..8122ab3b 100644 --- a/internal/dns/cache.go +++ b/internal/dns/cache.go @@ -15,13 +15,13 @@ import ( type CacheResolver struct { logger zerolog.Logger - ttlCache cache.Cache // Owns the cache + ttlCache cache.Cache[string] // Owns the cache } // NewCacheResolver wraps a "worker" resolver with a cache. func NewCacheResolver( logger zerolog.Logger, - cache cache.Cache, + cache cache.Cache[string], ) *CacheResolver { return &CacheResolver{ logger: logger, @@ -53,8 +53,8 @@ func (cr *CacheResolver) Resolve( // the cache might return the wrong one. // For now, assuming simplistic cache key = domain, but awareness of potential issue. // Ideally: key = domain + qtypes + spec-related-things - if item, ok := cr.ttlCache.Get(domain); ok { - logger.Trace().Msgf("hit") + if item, ok := cr.ttlCache.Fetch(domain); ok { + logger.Debug().Str("domain", domain).Msgf("hit") return item.(*RecordSet).Clone(), nil } @@ -64,7 +64,9 @@ func (cr *CacheResolver) Resolve( // 2. [Cache Miss] // Delegate the actual network request to 'r.next' (the worker). - logger.Trace().Str("fallback", fallback.Info()[0].Name).Msgf("miss") + logger.Debug().Str("domain", domain).Str("fallback", fallback.Info()[0].Name). + Msgf("miss") + rSet, err := fallback.Resolve(ctx, domain, nil, rule) if err != nil { return nil, err @@ -73,12 +75,13 @@ func (cr *CacheResolver) Resolve( // 3. [Cache Write] // (Assuming the actual TTL is parsed from the DNS response) // realTTL := 5 * time.Second - logger.Trace(). + logger.Debug(). + Str("domain", domain). Int("len", len(rSet.Addrs)). Uint32("ttl", rSet.TTL). Msg("set") - _ = cr.ttlCache.Set( + _ = cr.ttlCache.Store( domain, rSet, cache.Options().WithTTL(time.Duration(rSet.TTL)*time.Second), diff --git a/internal/dns/https.go b/internal/dns/https.go index 51fd4a0c..b3e3937f 100644 --- a/internal/dns/https.go +++ b/internal/dns/https.go @@ -3,8 +3,9 @@ package dns import ( "bytes" "context" - "encoding/base64" + "crypto/tls" "fmt" + "io" "net" "net/http" "strings" @@ -14,36 +15,52 @@ import ( "github.com/rs/zerolog" "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/logging" + "golang.org/x/net/http2" ) var _ Resolver = (*HTTPSResolver)(nil) type HTTPSResolver struct { - logger zerolog.Logger - - client *http.Client - dnsOpts *config.DNSOptions + logger zerolog.Logger + client *http.Client + defaultDNSOpts *config.DNSOptions + defaultConnOpts *config.ConnOptions } func NewHTTPSResolver( logger zerolog.Logger, - dnsOpts *config.DNSOptions, + defaultDNSOpts *config.DNSOptions, + defaultConnOpts *config.ConnOptions, ) *HTTPSResolver { + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + }, + DialContext: (&net.Dialer{ + Timeout: *defaultConnOpts.DNSTimeout, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 9 * time.Second, + MaxIdleConnsPerHost: 100, + MaxIdleConns: 100, + ForceAttemptHTTP2: true, + } + + // Configure HTTP/2 transport explicitly + if err := http2.ConfigureTransport(tr); err != nil { + logger.Warn(). + Err(err). + Msg("failed to configure http2 expressly, falling back to default / http/1.1") + } + return &HTTPSResolver{ logger: logger, client: &http.Client{ - Timeout: 5 * time.Second, - Transport: &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: 3 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - TLSHandshakeTimeout: 5 * time.Second, - MaxIdleConnsPerHost: 100, - MaxIdleConns: 100, - }, + Transport: tr, + Timeout: *defaultConnOpts.DNSTimeout, }, - dnsOpts: dnsOpts, + defaultDNSOpts: defaultDNSOpts, + defaultConnOpts: defaultConnOpts, } } @@ -51,7 +68,7 @@ func (dr *HTTPSResolver) Info() []ResolverInfo { return []ResolverInfo{ { Name: "https", - Dst: *dr.dnsOpts.HTTPSURL, + Dst: *dr.defaultDNSOpts.HTTPSURL, }, } } @@ -62,7 +79,7 @@ func (dr *HTTPSResolver) Resolve( fallback Resolver, rule *config.Rule, ) (*RecordSet, error) { - opts := dr.dnsOpts.Clone() + opts := dr.defaultDNSOpts.Clone() if rule != nil { opts = opts.Merge(rule.DNS) } @@ -94,55 +111,75 @@ func (dr *HTTPSResolver) exchange( return nil, err } - url := fmt.Sprintf( - "%s?dns=%s", - upstream, - base64.RawURLEncoding.EncodeToString(pack), - // base64.RawStdEncoding.EncodeToString(pack), - ) - req, err := http.NewRequest("GET", url, nil) - if err != nil { - return nil, err + const maxRetries = 2 + var resp *http.Response + var reqErr error + + // Retry loop for transient network errors like unexpected EOF + for i := 0; i < maxRetries; i++ { + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + upstream, + bytes.NewReader(pack), + ) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/dns-message") + req.Header.Set("Accept", "application/dns-message") + + resp, reqErr = dr.client.Do(req) + if reqErr == nil { + break + } + + // Check if error is retryable + if i < maxRetries-1 && isRetryableError(reqErr) { + continue + } } - req = req.WithContext(ctx) - req.Header.Set("Accept", "application/dns-message") - - resp, err := dr.client.Do(req) - if err != nil { - return nil, err + if reqErr != nil { + return nil, reqErr } defer func() { _ = resp.Body.Close() }() - buf := bytes.Buffer{} - bodyLen, err := buf.ReadFrom(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { logger.Trace(). - Int64("len", bodyLen). + Int("len", len(body)). Int("status", resp.StatusCode). - Str("body", buf.String()). - Msg("status not ok") - return nil, fmt.Errorf("doh status code(%d)", resp.StatusCode) + Str("body", string(body)). + Msg("doh status not ok") + return nil, fmt.Errorf("status code(%d)", resp.StatusCode) } resultMsg := new(dns.Msg) - err = resultMsg.Unpack(buf.Bytes()) - if err != nil { + if err := resultMsg.Unpack(body); err != nil { return nil, err } - // Ignore Rcode 3 (NameNotFound) as it's not a critical error. if resultMsg.Rcode != dns.RcodeSuccess && resultMsg.Rcode != dns.RcodeNameError { logger.Trace(). Int("rcode", resultMsg.Rcode). Str("msg", resultMsg.String()). - Msg("rcode not ok") - return nil, fmt.Errorf("doh Rcode(%d)", resultMsg.Rcode) + Msg("doh rcode not ok") + return nil, fmt.Errorf("Rcode(%d)", resultMsg.Rcode) } return resultMsg, nil } + +// isRetryableError checks for common transient network errors +func isRetryableError(err error) bool { + msg := err.Error() + return strings.Contains(msg, "unexpected EOF") || + strings.Contains(msg, "connection reset") || + strings.Contains(msg, "broken pipe") +} diff --git a/internal/dns/resolver.go b/internal/dns/resolver.go index 2cc63b71..daab801b 100644 --- a/internal/dns/resolver.go +++ b/internal/dns/resolver.go @@ -2,6 +2,7 @@ package dns import ( "context" + "errors" "fmt" "math" "net" @@ -52,13 +53,13 @@ type MsgChan struct { } type RecordSet struct { - Addrs []net.IPAddr + Addrs []net.IP TTL uint32 } func (rs *RecordSet) Clone() *RecordSet { return &RecordSet{ - Addrs: append([]net.IPAddr(nil), rs.Addrs...), + Addrs: append([]net.IP(nil), rs.Addrs...), TTL: rs.TTL, } } @@ -149,8 +150,8 @@ func lookupAllTypes( return resCh } -func parseMsg(msg *dns.Msg) ([]net.IPAddr, uint32, bool) { - var addrs []net.IPAddr +func parseMsg(msg *dns.Msg) ([]net.IP, uint32, bool) { + var addrs []net.IP minTTL := uint32(math.MaxUint32) ok := false @@ -158,11 +159,11 @@ func parseMsg(msg *dns.Msg) ([]net.IPAddr, uint32, bool) { switch ipRecord := record.(type) { case *dns.A: ok = true - addrs = append(addrs, net.IPAddr{IP: ipRecord.A}) + addrs = append(addrs, ipRecord.A) minTTL = min(minTTL, record.Header().Ttl) case *dns.AAAA: ok = true - addrs = append(addrs, net.IPAddr{IP: ipRecord.AAAA}) + addrs = append(addrs, ipRecord.AAAA) minTTL = min(minTTL, record.Header().Ttl) } } @@ -175,7 +176,7 @@ func processMessages( resCh <-chan *MsgChan, ) (*RecordSet, error) { var errs []error - var addrs []net.IPAddr + var addrs []net.IP minTTL := uint32(math.MaxUint32) found := false @@ -224,7 +225,7 @@ loop: // Loop until the channel is closed or context is canceled // Only return errors if no addresses were found at all if len(errs) > 0 { - return nil, fmt.Errorf("failed to resolve with %d errors", len(errs)) + return nil, errors.Join(errs...) } return nil, fmt.Errorf("record not found") diff --git a/internal/dns/route.go b/internal/dns/route.go index 1166ed3d..5823c0a4 100644 --- a/internal/dns/route.go +++ b/internal/dns/route.go @@ -8,18 +8,18 @@ import ( "time" "github.com/rs/zerolog" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/logging" - "github.com/xvzc/SpoofDPI/internal/ptr" ) type RouteResolver struct { - logger zerolog.Logger - https Resolver - udp Resolver - system Resolver - cache Resolver - dnsOpts *config.DNSOptions + logger zerolog.Logger + https Resolver + udp Resolver + system Resolver + cache Resolver + defaultDNSOpts *config.DNSOptions } func NewRouteResolver( @@ -28,15 +28,15 @@ func NewRouteResolver( udp Resolver, sys Resolver, cache Resolver, - dnsOpts *config.DNSOptions, + defaultDNSOpts *config.DNSOptions, ) *RouteResolver { return &RouteResolver{ - logger: logger, - https: doh, - udp: udp, - system: sys, - cache: cache, - dnsOpts: dnsOpts, + logger: logger, + https: doh, + udp: udp, + system: sys, + cache: cache, + defaultDNSOpts: defaultDNSOpts, } } @@ -55,7 +55,7 @@ func (rr *RouteResolver) Resolve( fallback Resolver, rule *config.Rule, ) (*RecordSet, error) { - opts := rr.dnsOpts.Clone() + opts := rr.defaultDNSOpts.Clone() if rule != nil { opts = opts.Merge(rule.DNS) } @@ -64,7 +64,7 @@ func (rr *RouteResolver) Resolve( // 1. Check for IP address in domain if ip, err := parseIpAddr(domain); err == nil { - return &RecordSet{Addrs: []net.IPAddr{*ip}, TTL: 0}, nil + return &RecordSet{Addrs: []net.IP{ip}, TTL: 0}, nil } // 4. Handle ROUTE rule (or default) @@ -77,11 +77,10 @@ func (rr *RouteResolver) Resolve( } resolverInfo := resolver.Info()[0] - logger.Trace(). - Str("name", resolverInfo.Name). - Bool("cache", ptr.FromPtr(opts.Cache)). + logger.Debug().Str("mode", resolverInfo.Name).Bool("cache", lo.FromPtr(opts.Cache)). Msgf("ready to resolve") + t1 := time.Now() var rSet *RecordSet var err error if *opts.Mode != config.DNSModeSystem && *opts.Cache { @@ -90,7 +89,17 @@ func (rr *RouteResolver) Resolve( rSet, err = resolver.Resolve(ctx, domain, nil, rule) } - return rSet, err + if err != nil { + return nil, err + } + + logger.Debug(). + Str("domain", domain). + Int("len", len(rSet.Addrs)). + Str("took", fmt.Sprintf("%.3fms", float64(time.Since(t1).Microseconds())/1000.0)). + Msgf("dns lookup ok") + + return rSet, nil } func (rr *RouteResolver) route(attrs *config.DNSOptions) Resolver { @@ -106,15 +115,11 @@ func (rr *RouteResolver) route(attrs *config.DNSOptions) Resolver { } } -func parseIpAddr(addr string) (*net.IPAddr, error) { +func parseIpAddr(addr string) (net.IP, error) { ip := net.ParseIP(addr) if ip == nil { return nil, fmt.Errorf("%s is not an ip address", addr) } - ipAddr := &net.IPAddr{ - IP: ip, - } - - return ipAddr, nil + return ip, nil } diff --git a/internal/dns/system.go b/internal/dns/system.go index 63c2e72d..d6c318f0 100644 --- a/internal/dns/system.go +++ b/internal/dns/system.go @@ -9,27 +9,27 @@ import ( "github.com/xvzc/SpoofDPI/internal/config" ) -var _ Resolver = (*SysResolver)(nil) +var _ Resolver = (*SystemResolver)(nil) -type SysResolver struct { +type SystemResolver struct { logger zerolog.Logger *net.Resolver - dnsOpts *config.DNSOptions + defaultDNSOpts *config.DNSOptions } func NewSystemResolver( logger zerolog.Logger, - dnsOps *config.DNSOptions, -) *SysResolver { - return &SysResolver{ - logger: logger, - Resolver: &net.Resolver{PreferGo: true}, - dnsOpts: dnsOps, + defaultDNSOpts *config.DNSOptions, +) *SystemResolver { + return &SystemResolver{ + logger: logger, + Resolver: &net.Resolver{PreferGo: true}, + defaultDNSOpts: defaultDNSOpts, } } -func (sr *SysResolver) Info() []ResolverInfo { +func (sr *SystemResolver) Info() []ResolverInfo { return []ResolverInfo{ { Name: "system", @@ -38,29 +38,29 @@ func (sr *SysResolver) Info() []ResolverInfo { } } -func (sr *SysResolver) Resolve( +func (sr *SystemResolver) Resolve( ctx context.Context, domain string, fallback Resolver, rule *config.Rule, ) (*RecordSet, error) { - opts := sr.dnsOpts.Clone() + opts := sr.defaultDNSOpts.Clone() if rule != nil { opts = opts.Merge(rule.DNS) } - addrs, err := sr.LookupIPAddr(ctx, domain) + ips, err := sr.LookupIP(ctx, "ip", domain) if err != nil { return nil, err } return &RecordSet{ - Addrs: filtterAddrs(addrs, parseQueryTypes(*opts.QType)), + Addrs: filtterAddrs(ips, parseQueryTypes(*opts.QType)), TTL: 0, }, nil } -func filtterAddrs(addrs []net.IPAddr, qTypes []uint16) []net.IPAddr { +func filtterAddrs(ips []net.IP, qTypes []uint16) []net.IP { wantsA, wantsAAAA := false, false for _, qType := range qTypes { switch qType { @@ -76,31 +76,31 @@ func filtterAddrs(addrs []net.IPAddr, qTypes []uint16) []net.IPAddr { } if !wantsA && !wantsAAAA { - return []net.IPAddr{} + return []net.IP{} } - filteredMap := make(map[string]net.IPAddr) + filteredMap := make(map[string]net.IP) - for _, addr := range addrs { - addrStr := addr.IP.String() + for _, ip := range ips { + addrStr := ip.String() if _, exists := filteredMap[addrStr]; exists { continue } - isIPv4 := addr.IP.To4() != nil + isIPv4 := ip.To4() != nil if wantsA && isIPv4 { - filteredMap[addrStr] = addr + filteredMap[addrStr] = ip } if wantsAAAA && !isIPv4 { - filteredMap[addrStr] = addr + filteredMap[addrStr] = ip } } - filtered := make([]net.IPAddr, 0, len(filteredMap)) - for _, addr := range filteredMap { - filtered = append(filtered, addr) + filtered := make([]net.IP, 0, len(filteredMap)) + for _, ip := range filteredMap { + filtered = append(filtered, ip) } return filtered diff --git a/internal/dns/udp.go b/internal/dns/udp.go index 2fba04b1..3c7b1862 100644 --- a/internal/dns/udp.go +++ b/internal/dns/udp.go @@ -14,18 +14,23 @@ var _ Resolver = (*UDPResolver)(nil) type UDPResolver struct { logger zerolog.Logger - client *dns.Client - dnsOpts *config.DNSOptions + client *dns.Client + defaultDNSOpts *config.DNSOptions + defaultConnOpts *config.ConnOptions } func NewUDPResolver( logger zerolog.Logger, - dnsOpts *config.DNSOptions, + defaultDNSOpts *config.DNSOptions, + defaultConnOpts *config.ConnOptions, ) *UDPResolver { return &UDPResolver{ - client: &dns.Client{}, - dnsOpts: dnsOpts, - logger: logger, + client: &dns.Client{ + Timeout: *defaultConnOpts.DNSTimeout, + }, + defaultDNSOpts: defaultDNSOpts, + defaultConnOpts: defaultConnOpts, + logger: logger, } } @@ -33,7 +38,7 @@ func (ur *UDPResolver) Info() []ResolverInfo { return []ResolverInfo{ { Name: "udp", - Dst: ur.dnsOpts.Addr.String(), + Dst: ur.defaultDNSOpts.Addr.String(), }, } } @@ -44,7 +49,7 @@ func (ur *UDPResolver) Resolve( fallback Resolver, rule *config.Rule, ) (*RecordSet, error) { - opts := ur.dnsOpts.Clone() + opts := ur.defaultDNSOpts.Clone() if rule != nil { opts = opts.Merge(rule.DNS) } diff --git a/internal/logging/logging.go b/internal/logging/logging.go index a32a890d..58b9f245 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -127,9 +127,8 @@ func (h ctxHook) Run(e *zerolog.Event, level zerolog.Level, msg string) { // Request-scoped values like trace_id. // Scope is expected to be added at the component's creation time. - if traceID, ok := session.TraceIDFrom(ctx); ok { - e.Str(traceIDFieldName, traceID) - } + traceID, _ := session.TraceIDFrom(ctx) + e.Str(traceIDFieldName, traceID) if domain, ok := session.HostInfoFrom(ctx); ok { e.Str(hostInfoFieldName, domain) @@ -140,21 +139,21 @@ type joinableError interface { Unwrap() []error } -// ErrorUnwrapped tries to unwrap an error and prints each error separately. +// ErrorErrs tries to unwrap an error and prints each error separately. // If the error is not joined, it logs the single error normally. -func ErrorUnwrapped(logger *zerolog.Logger, msg string, err error) { - logUnwrapped(logger, zerolog.ErrorLevel, msg, err) +func ErrorErrs(logger *zerolog.Logger, err error, msg string) { + logErrs(logger, zerolog.ErrorLevel, err, msg) } -func WarnUnwrapped(logger *zerolog.Logger, msg string, err error) { - logUnwrapped(logger, zerolog.WarnLevel, msg, err) +func WarnErrs(logger *zerolog.Logger, err error, msg string) { + logErrs(logger, zerolog.WarnLevel, err, msg) } -func TraceUnwrapped(logger *zerolog.Logger, msg string, err error) { - logUnwrapped(logger, zerolog.TraceLevel, msg, err) +func TraceErrs(logger *zerolog.Logger, err error, msg string) { + logErrs(logger, zerolog.TraceLevel, err, msg) } -func logUnwrapped(logger *zerolog.Logger, level zerolog.Level, msg string, err error) { +func logErrs(logger *zerolog.Logger, level zerolog.Level, err error, msg string) { var joinedErrs joinableError if errors.As(err, &joinedErrs) { diff --git a/internal/matcher/addr.go b/internal/matcher/addr.go index 89b1fab2..6a358c48 100644 --- a/internal/matcher/addr.go +++ b/internal/matcher/addr.go @@ -6,8 +6,8 @@ import ( "sort" "sync" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/ptr" ) // AddrMatcher implements Matcher for IP/CIDR rules. @@ -43,11 +43,11 @@ func (m *AddrMatcher) Add(r *config.Rule) error { } if r.Priority == nil { - r.Priority = ptr.FromValue(uint16(0)) + r.Priority = lo.ToPtr(uint16(0)) } if r.Block == nil { - r.Block = ptr.FromValue(false) + r.Block = lo.ToPtr(false) } m.mu.Lock() diff --git a/internal/matcher/addr_test.go b/internal/matcher/addr_test.go index 90810ac0..ae2e61e9 100644 --- a/internal/matcher/addr_test.go +++ b/internal/matcher/addr_test.go @@ -4,37 +4,37 @@ import ( "net" "testing" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/ptr" ) func TestAddrMatcher(t *testing.T) { matcher := NewAddrMatcher() rule1 := &config.Rule{ - Name: ptr.FromValue("rule1"), - Priority: ptr.FromValue(uint16(10)), + Name: lo.ToPtr("rule1"), + Priority: lo.ToPtr(uint16(10)), Match: &config.MatchAttrs{ Addrs: []config.AddrMatch{ { - CIDR: ptr.FromValue(config.MustParseCIDR("192.168.1.0/24")), - PortFrom: ptr.FromValue(uint16(80)), - PortTo: ptr.FromValue(uint16(80)), + CIDR: lo.ToPtr(config.MustParseCIDR("192.168.1.0/24")), + PortFrom: lo.ToPtr(uint16(80)), + PortTo: lo.ToPtr(uint16(80)), }, }, }, } rule2 := &config.Rule{ - Name: ptr.FromValue("rule2"), - Priority: ptr.FromValue(uint16(20)), + Name: lo.ToPtr("rule2"), + Priority: lo.ToPtr(uint16(20)), Match: &config.MatchAttrs{ Addrs: []config.AddrMatch{ { - CIDR: ptr.FromValue(config.MustParseCIDR("10.0.0.0/8")), - PortFrom: ptr.FromValue(uint16(0)), - PortTo: ptr.FromValue(uint16(65535)), + CIDR: lo.ToPtr(config.MustParseCIDR("10.0.0.0/8")), + PortFrom: lo.ToPtr(uint16(0)), + PortTo: lo.ToPtr(uint16(65535)), }, }, }, @@ -42,14 +42,14 @@ func TestAddrMatcher(t *testing.T) { // Overlapping lower priority rule rule3 := &config.Rule{ - Name: ptr.FromValue("rule3"), - Priority: ptr.FromValue(uint16(5)), + Name: lo.ToPtr("rule3"), + Priority: lo.ToPtr(uint16(5)), Match: &config.MatchAttrs{ Addrs: []config.AddrMatch{ { - CIDR: ptr.FromValue(config.MustParseCIDR("172.16.0.0/16")), - PortFrom: ptr.FromValue(uint16(0)), - PortTo: ptr.FromValue(uint16(65535)), + CIDR: lo.ToPtr(config.MustParseCIDR("172.16.0.0/16")), + PortFrom: lo.ToPtr(uint16(0)), + PortTo: lo.ToPtr(uint16(65535)), }, }, }, @@ -57,14 +57,14 @@ func TestAddrMatcher(t *testing.T) { // Overlapping lower priority rule rule4 := &config.Rule{ - Name: ptr.FromValue("rule4"), - Priority: ptr.FromValue(uint16(4)), + Name: lo.ToPtr("rule4"), + Priority: lo.ToPtr(uint16(4)), Match: &config.MatchAttrs{ Addrs: []config.AddrMatch{ { - CIDR: ptr.FromValue(config.MustParseCIDR("172.16.0.0/16")), - PortFrom: ptr.FromValue(uint16(443)), - PortTo: ptr.FromValue(uint16(443)), + CIDR: lo.ToPtr(config.MustParseCIDR("172.16.0.0/16")), + PortFrom: lo.ToPtr(uint16(443)), + PortTo: lo.ToPtr(uint16(443)), }, }, }, @@ -142,7 +142,7 @@ func TestAddrMatcher(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ip := net.ParseIP(tc.ip) port := tc.port - selector := &Selector{IP: &ip, Port: ptr.FromValue(uint16(port))} + selector := &Selector{IP: &ip, Port: lo.ToPtr(uint16(port))} output := matcher.Search(selector) tc.assert(t, output) }) diff --git a/internal/matcher/domain.go b/internal/matcher/domain.go index 7aa3a6e5..9d021054 100644 --- a/internal/matcher/domain.go +++ b/internal/matcher/domain.go @@ -5,8 +5,8 @@ import ( "strings" "sync" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/ptr" ) // node represents a single node in the radix tree implementation. @@ -72,11 +72,11 @@ func (t *DomainMatcher) Add(r *config.Rule) error { } if r.Priority == nil { - r.Priority = ptr.FromValue(uint16(0)) + r.Priority = lo.ToPtr(uint16(0)) } if r.Block == nil { - r.Block = ptr.FromValue(false) + r.Block = lo.ToPtr(false) } t.mu.Lock() diff --git a/internal/matcher/domain_test.go b/internal/matcher/domain_test.go index 69ad5704..dab97f17 100644 --- a/internal/matcher/domain_test.go +++ b/internal/matcher/domain_test.go @@ -3,9 +3,9 @@ package matcher import ( "testing" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/ptr" ) func TestDomainMatcher(t *testing.T) { @@ -13,24 +13,24 @@ func TestDomainMatcher(t *testing.T) { matcher := NewDomainMatcher() rule1 := &config.Rule{ - Name: ptr.FromValue("rule1"), - Priority: ptr.FromValue(uint16(10)), + Name: lo.ToPtr("rule1"), + Priority: lo.ToPtr(uint16(10)), Match: &config.MatchAttrs{ Domains: []string{"example.com"}, }, } rule2 := &config.Rule{ - Name: ptr.FromValue("rule2"), - Priority: ptr.FromValue(uint16(20)), + Name: lo.ToPtr("rule2"), + Priority: lo.ToPtr(uint16(20)), Match: &config.MatchAttrs{ Domains: []string{"*.google.com"}, }, } rule3 := &config.Rule{ - Name: ptr.FromValue("rule3"), - Priority: ptr.FromValue(uint16(5)), + Name: lo.ToPtr("rule3"), + Priority: lo.ToPtr(uint16(5)), Match: &config.MatchAttrs{ Domains: []string{"**.youtube.com"}, }, @@ -38,8 +38,8 @@ func TestDomainMatcher(t *testing.T) { // Additional rule for priority check rule4 := &config.Rule{ - Name: ptr.FromValue("rule4"), - Priority: ptr.FromValue(uint16(30)), + Name: lo.ToPtr("rule4"), + Priority: lo.ToPtr(uint16(30)), Match: &config.MatchAttrs{ Domains: []string{"mail.google.com"}, }, @@ -57,7 +57,7 @@ func TestDomainMatcher(t *testing.T) { }{ { name: "exact match", - selector: &Selector{Domain: ptr.FromValue("example.com")}, + selector: &Selector{Domain: lo.ToPtr("example.com")}, assert: func(t *testing.T, output *config.Rule) { assert.NotNil(t, output) assert.Equal(t, "rule1", *output.Name) @@ -65,7 +65,7 @@ func TestDomainMatcher(t *testing.T) { }, { name: "wildcard match", - selector: &Selector{Domain: ptr.FromValue("maps.google.com")}, + selector: &Selector{Domain: lo.ToPtr("maps.google.com")}, assert: func(t *testing.T, output *config.Rule) { assert.NotNil(t, output) assert.Equal(t, "rule2", *output.Name) @@ -73,7 +73,7 @@ func TestDomainMatcher(t *testing.T) { }, { name: "globstar match", - selector: &Selector{Domain: ptr.FromValue("foo.bar.youtube.com")}, + selector: &Selector{Domain: lo.ToPtr("foo.bar.youtube.com")}, assert: func(t *testing.T, output *config.Rule) { assert.NotNil(t, output) assert.Equal(t, "rule3", *output.Name) @@ -81,7 +81,7 @@ func TestDomainMatcher(t *testing.T) { }, { name: "wildcard higher priority check", - selector: &Selector{Domain: ptr.FromValue("mail.google.com")}, + selector: &Selector{Domain: lo.ToPtr("mail.google.com")}, assert: func(t *testing.T, output *config.Rule) { // Should pick rule4 (priority 30) over rule2 (priority 20) assert.NotNil(t, output) @@ -90,7 +90,7 @@ func TestDomainMatcher(t *testing.T) { }, { name: "no match", - selector: &Selector{Domain: ptr.FromValue("naver.com")}, + selector: &Selector{Domain: lo.ToPtr("naver.com")}, assert: func(t *testing.T, output *config.Rule) { assert.Nil(t, output) }, diff --git a/internal/matcher/matcher.go b/internal/matcher/matcher.go index 5d3718c1..3e269a25 100644 --- a/internal/matcher/matcher.go +++ b/internal/matcher/matcher.go @@ -4,8 +4,8 @@ import ( "fmt" "net" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/ptr" ) // ----------------------------------------------------------------------------- @@ -142,7 +142,7 @@ func GetHigherPriorityRule(r1, r2 *config.Rule) *config.Rule { return r1 } - if ptr.FromPtr(r1.Priority) >= ptr.FromPtr(r2.Priority) { + if lo.FromPtr(r1.Priority) >= lo.FromPtr(r2.Priority) { return r1 } return r2 diff --git a/internal/matcher/matcher_test.go b/internal/matcher/matcher_test.go index 4e9156c2..5079963c 100644 --- a/internal/matcher/matcher_test.go +++ b/internal/matcher/matcher_test.go @@ -3,14 +3,14 @@ package matcher import ( "testing" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/ptr" ) func TestGetHigherPriorityRule(t *testing.T) { - r1 := &config.Rule{Priority: ptr.FromValue(uint16(10))} - r2 := &config.Rule{Priority: ptr.FromValue(uint16(20))} + r1 := &config.Rule{Priority: lo.ToPtr(uint16(10))} + r2 := &config.Rule{Priority: lo.ToPtr(uint16(20))} tcs := []struct { name string diff --git a/internal/netutil/conn.go b/internal/netutil/conn.go index 2f7a4d09..465679bd 100644 --- a/internal/netutil/conn.go +++ b/internal/netutil/conn.go @@ -1,18 +1,35 @@ package netutil import ( + "bufio" "context" "errors" "fmt" "io" "net" + "os" "sync" + "sync/atomic" "syscall" + "time" "github.com/rs/zerolog" - "github.com/xvzc/SpoofDPI/internal/logging" ) +type TunnelDirType int + +const ( + TunnelDirOut TunnelDirType = iota + TunnelDirIn +) + +// TransferResult holds the result of a unidirectional tunnel transfer. +type TransferResult struct { + Written int64 + Dir TunnelDirType + Err error +} + // bufferPool is a package-level pool of 32KB buffers used by io.CopyBuffer // to reduce memory allocations and GC pressure in the tunnel hot path. var bufferPool = sync.Pool{ @@ -24,16 +41,16 @@ var bufferPool = sync.Pool{ }, } +// TunnelConns copies data from src to dst. +// It sends the result to resCh upon completion. +// It filters out benign errors like EOF, pipe closed, or read timeouts (for UDP). func TunnelConns( ctx context.Context, - logger zerolog.Logger, - errCh chan<- error, - dst net.Conn, // Destination connection (io.Writer) + resCh chan<- TransferResult, src net.Conn, // Source connection (io.Reader) + dst net.Conn, // Destination connection (io.Writer) + dir TunnelDirType, ) { - var n int64 - logger = logging.WithLocalScope(ctx, logger, "tunnel") - var once sync.Once closeOnce := func() { once.Do(func() { @@ -46,11 +63,6 @@ func TunnelConns( defer func() { stop() closeOnce() - - logger.Trace(). - Int64("len", n). - Str("route", fmt.Sprintf("%s -> %s", src.RemoteAddr(), dst.RemoteAddr())). - Msgf("done") }() bufPtr := bufferPool.Get().(*[]byte) @@ -58,13 +70,79 @@ func TunnelConns( // Copy data from src to dst using the borrowed buffer. n, err := io.CopyBuffer(dst, src, *bufPtr) + + // Filter benign errors. + // os.IsTimeout is checked to handle UDP idle timeouts gracefully. if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) && - !errors.Is(err, syscall.EPIPE) { - errCh <- err + !errors.Is(err, syscall.EPIPE) && !os.IsTimeout(err) { + resCh <- TransferResult{Written: n, Dir: dir, Err: err} return } - errCh <- nil + resCh <- TransferResult{Written: n, Dir: dir, Err: nil} +} + +// WaitAndLogTunnel aggregates results and logs the summary. +// errHandler processes the list of errors to determine the final error. +func WaitAndLogTunnel( + ctx context.Context, + logger zerolog.Logger, + resCh <-chan TransferResult, + startedAt time.Time, + route string, + errHandler func(errs []error) error, // [Modified] Accepts slice handler +) error { + var ( + outBytes int64 + inBytes int64 + errs []error // Collect all errors + ) + + // Wait for exactly 2 results. + for range 2 { + res := <-resCh + + switch res.Dir { + case TunnelDirOut: + outBytes = res.Written + case TunnelDirIn: + inBytes = res.Written + default: + return fmt.Errorf("invalid tunnel dir") + } + + if res.Err != nil { + errs = append(errs, res.Err) + } + } + + duration := float64(time.Since(startedAt).Microseconds()) / 1000.0 + logger.Trace(). + Int64("out", outBytes). + Int64("in", inBytes). + Str("took", fmt.Sprintf("%.3fms", duration)). + Str("route", route). + Int("errs", len(errs)). + Msg("tunnel closed") + + if errHandler != nil { + return errHandler(errs) + } + + if len(errs) > 0 { + return errs[0] + } + + return nil +} + +func DescribeRoute(src, dst net.Conn) string { + return fmt.Sprintf("%s(%s) -> %s(%s)", + src.RemoteAddr(), + src.RemoteAddr().Network(), + dst.RemoteAddr(), + dst.RemoteAddr().Network(), + ) } // CloseConns safely closes one or more io.Closer (like net.Conn). @@ -74,13 +152,13 @@ func TunnelConns( func CloseConns(closers ...io.Closer) { for _, c := range closers { if c != nil { - // Intentionally ignore the error. _ = c.Close() } } } // SetTTL configures the TTL or Hop Limit depending on the IP version. +// The isIPv4 parameter is determined by examining the remote address of the connection. func SetTTL(conn net.Conn, isIPv4 bool, ttl uint8) error { tcpConn, ok := conn.(*net.TCPConn) if !ok { @@ -92,8 +170,19 @@ func SetTTL(conn net.Conn, isIPv4 bool, ttl uint8) error { return err } + // Re-check IP version using remote address to handle IPv4-mapped IPv6 addresses + // On Linux, when using dual-stack sockets, the local address might be IPv6 + // but the actual connection could be IPv4-mapped (::ffff:x.x.x.x) + actualIPv4 := isIPv4 + if tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr); ok { + // If the IP is IPv4 or IPv4-mapped IPv6, we should use IPv4 options + if ip4 := tcpAddr.IP.To4(); ip4 != nil { + actualIPv4 = true + } + } + var level, opt int - if isIPv4 { + if actualIPv4 { level = syscall.IPPROTO_IP opt = syscall.IP_TTL } else { @@ -103,7 +192,7 @@ func SetTTL(conn net.Conn, isIPv4 bool, ttl uint8) error { var sysErr error - // Invoke Control to manipulate file descriptor directly + // Invoke Control to manipulate file descriptor directly. err = rawConn.Control(func(fd uintptr) { sysErr = syscall.SetsockoptInt(int(fd), level, opt, int(ttl)) }) @@ -113,3 +202,117 @@ func SetTTL(conn net.Conn, isIPv4 bool, ttl uint8) error { return sysErr } + +// BufferedConn wraps a net.Conn with a bufio.Reader to support peeking. +type BufferedConn struct { + r *bufio.Reader + net.Conn +} + +func NewBufferedConn(c net.Conn) *BufferedConn { + return &BufferedConn{ + r: bufio.NewReader(c), + Conn: c, + } +} + +func (b *BufferedConn) Read(p []byte) (int, error) { + return b.r.Read(p) +} + +func (b *BufferedConn) Peek(n int) ([]byte, error) { + return b.r.Peek(n) +} + +// IdleTimeoutConn wraps a net.Conn to extend the deadline on every Read/Write call. +// This is useful for sessions which should stay alive as long as there is activity. +type IdleTimeoutConn struct { + net.Conn + Key any + Timeout time.Duration + + lastActivity int64 // UnixNano atomic + expiredAt int64 // UnixNano atomic + + onActivity func() + onClose func() +} + +// NewIdleTimeoutConn wraps a net.Conn and securely initializes its internal atomic deadlines. +func NewIdleTimeoutConn(conn net.Conn, timeout time.Duration) *IdleTimeoutConn { + c := &IdleTimeoutConn{ + Conn: conn, + Timeout: timeout, + } + + now := time.Now() + atomic.StoreInt64(&c.lastActivity, now.UnixNano()) + if timeout > 0 { + expTime := now.Add(timeout) + atomic.StoreInt64(&c.expiredAt, expTime.UnixNano()) + _ = c.SetReadDeadline(expTime) + _ = c.SetWriteDeadline(expTime) + } + + return c +} + +func (c *IdleTimeoutConn) Read(b []byte) (int, error) { + c.ExtendDeadline() + return c.Conn.Read(b) +} + +func (c *IdleTimeoutConn) Write(b []byte) (int, error) { + c.ExtendDeadline() + return c.Conn.Write(b) +} + +// ExtendDeadline attempts to extend the connection's deadline. +// Returns false if the connection was already expired. +func (c *IdleTimeoutConn) ExtendDeadline() bool { + now := time.Now() + nowUnix := now.UnixNano() + + // 1. Check if already expired (Thread-safe atomic read) + expUnix := atomic.LoadInt64(&c.expiredAt) + if expUnix != 0 && nowUnix > expUnix { + return false + } + + // 2. Throttle OnActivity to drastically reduce LRU Cache Lock Contention + lastActUnix := atomic.LoadInt64(&c.lastActivity) + if nowUnix-lastActUnix > int64(time.Second) { + atomic.StoreInt64(&c.lastActivity, nowUnix) + if c.onActivity != nil { + c.onActivity() + } + } + + // 3. Throttle SetDeadline overhead (System Call) + // Extends only if remaining time is under 70% of timeout + if c.Timeout > 0 { + if expUnix == 0 || (expUnix-nowUnix) < (c.Timeout.Nanoseconds()*7/10) { + newExpUnix := now.Add(c.Timeout).UnixNano() + atomic.StoreInt64(&c.expiredAt, newExpUnix) + + newExpTime := time.Unix(0, newExpUnix) + _ = c.SetReadDeadline(newExpTime) + _ = c.SetWriteDeadline(newExpTime) + } + } + + return true +} + +// IsExpired safely checks if the connection has surpassed its calculated expiration time. +func (c *IdleTimeoutConn) IsExpired(now time.Time) bool { + expUnix := atomic.LoadInt64(&c.expiredAt) + return expUnix != 0 && now.UnixNano() > expUnix +} + +func (c *IdleTimeoutConn) Close() error { + if c.onClose != nil { + c.onClose() + } + return c.Conn.Close() +} diff --git a/internal/netutil/conn_registry.go b/internal/netutil/conn_registry.go new file mode 100644 index 00000000..050a4a0c --- /dev/null +++ b/internal/netutil/conn_registry.go @@ -0,0 +1,130 @@ +package netutil + +import ( + "context" + "net" + "time" + + "github.com/xvzc/SpoofDPI/internal/cache" +) + +// ConnRegistry manages UDP connections with LRU eviction policy and idle timeout. +type ConnRegistry[K comparable] struct { + storage cache.Cache[K] + timeout time.Duration +} + +// NewConnRegistry creates a new pool with the specified capacity and timeout. +func NewConnRegistry[K comparable]( + capacity int, + timeout time.Duration, +) *ConnRegistry[K] { + p := &ConnRegistry[K]{ + timeout: timeout, + } + + onInvalidate := func(k K, v any) { + if conn, ok := v.(*IdleTimeoutConn); ok { + _ = conn.Conn.Close() + } + } + + p.storage = cache.NewLRUCache(capacity, onInvalidate) + + return p +} + +// RunCleanupLoop runs the background cleanup goroutine. +// It exits when appctx is cancelled, closing all remaining cached connections. +func (p *ConnRegistry[K]) RunCleanupLoop(appctx context.Context) { + // Cleanup interval: half of timeout, min 10s, max 60s + cleanupInterval := p.timeout / 2 + cleanupInterval = max(cleanupInterval, 10*time.Second) + cleanupInterval = min(cleanupInterval, 60*time.Second) + + go func() { + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-appctx.Done(): + p.CloseAll() + return + case <-ticker.C: + p.evictExpired() + } + } + }() +} + +// Store adds a connection to the cache and returns the wrapped connection. +// If the key already exists, the old connection is closed and evicted first. +// If capacity is full, evicts the least recently used connection. +func (p *ConnRegistry[K]) Store(key K, rawConn net.Conn) *IdleTimeoutConn { + wrapper := NewIdleTimeoutConn(rawConn, p.timeout) + wrapper.Key = key + + wrapper.onActivity = func() { + p.storage.Fetch(key) + } + + wrapper.onClose = func() { + p.Evict(key) + } + + p.storage.Store(key, wrapper, nil) + + return wrapper +} + +// Fetch retrieves a connection from the pool, refreshing its LRU status. +func (p *ConnRegistry[K]) Fetch(key K) (*IdleTimeoutConn, bool) { + if val, ok := p.storage.Fetch(key); ok { + return val.(*IdleTimeoutConn), true + } + return nil, false +} + +// Evict closes and removes the connection from the pool. +func (p *ConnRegistry[K]) Evict(key K) { + p.storage.Evict(key) +} + +// Has checks if the connection exists in the cache. +func (p *ConnRegistry[K]) Has(key K) bool { + return p.storage.Has(key) +} + +// Size returns the number of connections in the pool. +func (p *ConnRegistry[K]) Size() int { + return p.storage.Size() +} + +// CloseAll closes all connections in the pool. +func (p *ConnRegistry[K]) CloseAll() { + var toRemove []K + _ = p.storage.ForEach(func(key K, value any) error { + toRemove = append(toRemove, key) + return nil + }) + for _, k := range toRemove { + p.Evict(k) // safely removes without deadlocking + } +} + +func (p *ConnRegistry[K]) evictExpired() { + now := time.Now() + var toRemove []K + _ = p.storage.ForEach(func(key K, value any) error { + if conn, ok := value.(*IdleTimeoutConn); ok { + if conn.IsExpired(now) { + toRemove = append(toRemove, key) + } + } + return nil + }) + for _, k := range toRemove { + p.Evict(k) + } +} diff --git a/internal/netutil/dial.go b/internal/netutil/dial.go index 895e377d..81baa9a3 100644 --- a/internal/netutil/dial.go +++ b/internal/netutil/dial.go @@ -19,24 +19,22 @@ type dialResult struct { func DialFastest( ctx context.Context, network string, - addrs []net.IPAddr, - port int, - timeout time.Duration, + dst *Destination, ) (net.Conn, error) { - if len(addrs) == 0 { + if len(dst.Addrs) == 0 { return nil, fmt.Errorf("no addresses provided to dial") } ctx, cancel := context.WithCancel(ctx) defer cancel() - results := make(chan dialResult, len(addrs)) + results := make(chan dialResult, len(dst.Addrs)) const maxConcurrency = 10 sem := make(chan struct{}, maxConcurrency) // semaphore go func() { - for _, addr := range addrs { + for _, addr := range dst.Addrs { // Get semaphore select { case sem <- struct{}{}: @@ -47,10 +45,21 @@ func DialFastest( go func(ip net.IP) { defer func() { <-sem }() // Return semaphore - targetAddr := net.JoinHostPort(ip.String(), strconv.Itoa(port)) + targetAddr := net.JoinHostPort(ip.String(), strconv.Itoa(dst.Port)) dialer := &net.Dialer{} - if timeout > 0 { - dialer.Deadline = time.Now().Add(timeout) + if dst.Timeout > 0 { + dialer.Deadline = time.Now().Add(dst.Timeout) + } + + // If Iface is specified, bind to the interface + if dst.Iface != nil { + if err := bindToInterface(network, dialer, dst.Iface, ip); err != nil { + select { + case results <- dialResult{conn: nil, err: err}: + case <-ctx.Done(): + } + return + } } conn, err := dialer.DialContext(ctx, network, targetAddr) @@ -62,14 +71,14 @@ func DialFastest( _ = conn.Close() // Close on context cancel } } - }(addr.IP) + }(addr) } }() var firstError error failureCount := 0 - for range addrs { + for range dst.Addrs { select { case res := <-results: if res.err == nil { diff --git a/internal/netutil/addr.go b/internal/netutil/dst.go similarity index 67% rename from internal/netutil/addr.go rename to internal/netutil/dst.go index 2f856eca..89b42807 100644 --- a/internal/netutil/addr.go +++ b/internal/netutil/dst.go @@ -3,10 +3,25 @@ package netutil import ( "fmt" "net" + "strconv" + "time" ) +type Destination struct { + Domain string + Addrs []net.IP + Port int + Timeout time.Duration + Iface *net.Interface + Gateway string +} + +func (d *Destination) String() string { + return net.JoinHostPort(d.Domain, strconv.Itoa(d.Port)) +} + func ValidateDestination( - dstAddrs []net.IPAddr, + dstAddrs []net.IP, dstPort int, listenAddr *net.TCPAddr, ) (bool, error) { @@ -19,7 +34,7 @@ func ValidateDestination( ifAddrs, err = net.InterfaceAddrs() for _, dstAddr := range dstAddrs { - ip := dstAddr.IP + ip := dstAddr if ip.IsLoopback() { return false, fmt.Errorf("loopback addr detected %v", ip.String()) } diff --git a/internal/netutil/key.go b/internal/netutil/key.go new file mode 100644 index 00000000..615103ea --- /dev/null +++ b/internal/netutil/key.go @@ -0,0 +1,80 @@ +package netutil + +import ( + "fmt" + "net" +) + +// NATKey represents a 4-tuple (SrcIP, SrcPort, DstIP, DstPort) for zero-allocation NAT session mapping +type NATKey struct { + SrcIP [16]byte + SrcPort uint16 + DstIP [16]byte + DstPort uint16 +} + +// String returns the string representation of the session key. +// Only used for debugging / logging. +func (k NATKey) String() string { + var srcIP, dstIP net.IP + + // Check if IPv4-mapped IPv6 + if isIPv4Mapped(k.SrcIP) { + srcIP = net.IP(k.SrcIP[12:16]) + } else { + srcIP = net.IP(k.SrcIP[:]) + } + + if isIPv4Mapped(k.DstIP) { + dstIP = net.IP(k.DstIP[12:16]) + } else { + dstIP = net.IP(k.DstIP[:]) + } + + return fmt.Sprintf("%v:%d>%v:%d", srcIP, k.SrcPort, dstIP, k.DstPort) +} + +// IPKey represents an IP address for zero-allocation cache mapping +type IPKey [16]byte + +// String returns the string representation of the IPKey. +func (k IPKey) String() string { + var srcIP net.IP + if isIPv4Mapped(k) { + srcIP = net.IP(k[12:16]) + } else { + srcIP = net.IP(k[:]) + } + return srcIP.String() +} + +// NewIPKey zero-alloc constructs an IPKey from net.IP +func NewIPKey(ip net.IP) IPKey { + var k IPKey + ip16 := ip.To16() + if ip16 != nil { + copy(k[:], ip16) + } + return k +} + +// NewNATKey zero-alloc constructs a NATKey from two UDPAddr +func NewNATKey(srcIP net.IP, srcPort int, dstIP net.IP, dstPort int) NATKey { + var k NATKey + + // net.IP is a slice. Let's force it to 16 bytes for comparable struct key + srcIP16 := srcIP.To16() + if srcIP16 != nil { + copy(k.SrcIP[:], srcIP16) + } + + dstIP16 := dstIP.To16() + if dstIP16 != nil { + copy(k.DstIP[:], dstIP16) + } + + k.SrcPort = uint16(srcPort) + k.DstPort = uint16(dstPort) + + return k +} diff --git a/internal/netutil/netutil.go b/internal/netutil/netutil.go new file mode 100644 index 00000000..a79eacda --- /dev/null +++ b/internal/netutil/netutil.go @@ -0,0 +1,14 @@ +package netutil + +func isIPv4Mapped(ip [16]byte) bool { + // IPv4-mapped IPv6 address has the prefix 0:0:0:0:0:FFFF + for i := 0; i < 10; i++ { + if ip[i] != 0 { + return false + } + } + if ip[10] != 0xff || ip[11] != 0xff { + return false + } + return true +} diff --git a/internal/netutil/pac.go b/internal/netutil/pac.go new file mode 100644 index 00000000..03b936cc --- /dev/null +++ b/internal/netutil/pac.go @@ -0,0 +1,34 @@ +package netutil + +import ( + "fmt" + "net" + "net/http" +) + +func RunPACServer(content string) (string, *http.Server, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return "", nil, err + } + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/x-ns-proxy-autoconfig") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(content)) + }) + + server := &http.Server{ + Handler: mux, + } + + go func() { + _ = server.Serve(listener) + }() + + addr := listener.Addr().(*net.TCPAddr) + url := fmt.Sprintf("http://127.0.0.1:%d/proxy.pac", addr.Port) + + return url, server, nil +} diff --git a/internal/netutil/route.go b/internal/netutil/route.go new file mode 100644 index 00000000..88ba1af0 --- /dev/null +++ b/internal/netutil/route.go @@ -0,0 +1,96 @@ +package netutil + +import ( + "fmt" + "net" +) + +// FindSafeSubnet scans the 10.0.0.0/8 range to find an unused /30 subnet +func FindSafeSubnet() (string, string, error) { + // Retrieve all active interface addresses to prevent CIDR overlapping. + // Checking against existing networks is faster than sending probe packets. + addrs, err := net.InterfaceAddrs() + if err != nil { + return "", "", err + } + + var existingNets []*net.IPNet + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok { + existingNets = append(existingNets, ipnet) + } + } + + // Iterate through the 10.0.0.0/8 private range with a /30 step. + // A /30 subnet provides exactly two usable end-point IP addresses. + for i := 0; i < 256; i++ { + for j := 0; j < 256; j++ { + // Construct candidate IP pair: 10.i.j.1 and 10.i.j.2 + local := net.IPv4(10, byte(i), byte(j), 1) + remote := net.IPv4(10, byte(i), byte(j), 2) + + conflict := false + for _, ipnet := range existingNets { + if ipnet.Contains(local) || ipnet.Contains(remote) { + conflict = true + break + } + } + + if !conflict { + return local.String(), remote.String(), nil + } + } + } + + return "", "", fmt.Errorf("failed to find an available address in 10.0.0.0/8") +} + +// GetDefaultInterfaceAndGateway returns the name of the default network interface and the gateway IP +func GetDefaultInterfaceAndGateway() (string, string, error) { + // Dial a public DNS server to determine the default interface + conn, err := net.Dial("udp", "8.8.8.8:53") + if err != nil { + return "", "", err + } + defer func() { _ = conn.Close() }() + + localAddr := conn.LocalAddr().(*net.UDPAddr) + + ifaces, err := net.Interfaces() + if err != nil { + return "", "", err + } + + var ifaceName string + for _, iface := range ifaces { + addrs, err := iface.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok { + if ipnet.IP.Equal(localAddr.IP) { + ifaceName = iface.Name + break + } + } + } + if ifaceName != "" { + break + } + } + + if ifaceName == "" { + return "", "", fmt.Errorf("default interface not found") + } + + // Get gateway by parsing route table + gateway, err := getDefaultGateway() + if err != nil { + return "", "", fmt.Errorf("failed to get default gateway: %w", err) + } + + return ifaceName, gateway, nil +} diff --git a/internal/netutil/route_bsd.go b/internal/netutil/route_bsd.go new file mode 100644 index 00000000..c84b6329 --- /dev/null +++ b/internal/netutil/route_bsd.go @@ -0,0 +1,63 @@ +//go:build darwin || freebsd + +package netutil + +import ( + "fmt" + "net" + "os/exec" + "regexp" + "syscall" + + "golang.org/x/sys/unix" +) + +// bindToInterface sets the dialer's Control function to bind the socket +// to a specific network interface using IP_BOUND_IF on BSD systems. +func bindToInterface( + network string, + dialer *net.Dialer, + iface *net.Interface, + targetIP net.IP, +) error { + if iface == nil { + return nil + } + + ifaceIndex := iface.Index + dialer.Control = func(network, address string, c syscall.RawConn) error { + var setsockoptErr error + err := c.Control(func(fd uintptr) { + setsockoptErr = unix.SetsockoptInt( + int(fd), + unix.IPPROTO_IP, + unix.IP_BOUND_IF, + ifaceIndex, + ) + }) + if err != nil { + return err + } + return setsockoptErr + } + return nil +} + +// getDefaultGateway parses the system route table to find the default gateway on macOS/BSD +func getDefaultGateway() (string, error) { + // Use route to get the default route on macOS + cmd := exec.Command("route", "-n", "get", "default") + out, err := cmd.Output() + if err != nil { + return "", err + } + + // Parse output to find gateway line + re := regexp.MustCompile(`gateway:\s+(\d+\.\d+\.\d+\.\d+)`) + matches := re.FindSubmatch(out) + if len(matches) < 2 { + return "", fmt.Errorf("could not parse gateway from route output") + } + + return string(matches[1]), nil +} diff --git a/internal/netutil/route_linux.go b/internal/netutil/route_linux.go new file mode 100644 index 00000000..3273e124 --- /dev/null +++ b/internal/netutil/route_linux.go @@ -0,0 +1,81 @@ +//go:build linux + +package netutil + +import ( + "fmt" + "net" + "os/exec" + "strings" +) + +// bindToInterface sets the dialer's LocalAddr to use the interface's IP as the source address. +// On Linux, we only set LocalAddr because SO_BINDTODEVICE can cause issues with +// socket lookup for incoming packets. +func bindToInterface( + network string, + dialer *net.Dialer, + iface *net.Interface, + targetIP net.IP, +) error { + if iface == nil { + return nil + } + + // Find the interface's IP address to use as source + addrs, err := iface.Addrs() + if err != nil { + return fmt.Errorf("failed to get interface addresses: %w", err) + } + + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok { + // Match IP version: use IPv4 source for IPv4 target, IPv6 for IPv6 + if targetIP.To4() != nil && ipnet.IP.To4() != nil && !ipnet.IP.IsLoopback() { + if strings.HasPrefix(network, "tcp") { + dialer.LocalAddr = &net.TCPAddr{IP: ipnet.IP} + } else if strings.HasPrefix(network, "udp") { + dialer.LocalAddr = &net.UDPAddr{IP: ipnet.IP} + } else { + dialer.LocalAddr = &net.IPAddr{IP: ipnet.IP} + } + return nil + } else if targetIP.To4() == nil && ipnet.IP.To4() == nil && !ipnet.IP.IsLoopback() { + if strings.HasPrefix(network, "tcp") { + dialer.LocalAddr = &net.TCPAddr{IP: ipnet.IP} + } else if strings.HasPrefix(network, "udp") { + dialer.LocalAddr = &net.UDPAddr{IP: ipnet.IP} + } else { + dialer.LocalAddr = &net.IPAddr{IP: ipnet.IP} + } + return nil + } + } + } + + return fmt.Errorf( + "no suitable IP address found on interface %s for target %s", + iface.Name, + targetIP, + ) +} + +// getDefaultGateway parses the system route table to find the default gateway on Linux +func getDefaultGateway() (string, error) { + // Use ip route to get the default route on Linux + cmd := exec.Command("ip", "route", "show", "default") + out, err := cmd.Output() + if err != nil { + return "", err + } + + // Parse output: "default via 192.168.0.1 dev enp12s0 ..." + fields := strings.Fields(string(out)) + for i, field := range fields { + if field == "via" && i+1 < len(fields) { + return fields[i+1], nil + } + } + + return "", fmt.Errorf("could not parse gateway from ip route output: %s", string(out)) +} diff --git a/internal/netutil/route_unsupported.go b/internal/netutil/route_unsupported.go new file mode 100644 index 00000000..7e13f7f7 --- /dev/null +++ b/internal/netutil/route_unsupported.go @@ -0,0 +1,23 @@ +//go:build !linux && !darwin && !freebsd + +package netutil + +import ( + "fmt" + "net" +) + +// bindToInterface is a no-op on unsupported platforms. +func bindToInterface( + network string, + dialer *net.Dialer, + iface *net.Interface, + targetIP net.IP, +) error { + return nil +} + +// getDefaultGateway is not supported on this platform. +func getDefaultGateway() (string, error) { + return "", fmt.Errorf("getDefaultGateway not supported on this platform") +} diff --git a/internal/packet/LICENSE b/internal/packet/LICENSE deleted file mode 100644 index 8dada3ed..00000000 --- a/internal/packet/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/internal/packet/handle_linux.go b/internal/packet/handle_linux.go index 03247a5d..434a34b2 100644 --- a/internal/packet/handle_linux.go +++ b/internal/packet/handle_linux.go @@ -87,12 +87,29 @@ func (h *LinuxHandle) WritePacketData(data []byte) error { return unix.Sendto(h.fd, data, 0, addr) } -// LinkType logic for handling "any" interface (Linux SLL) vs Ethernet func (h *LinuxHandle) LinkType() layers.LinkType { + // 1. "Any" device (tcpdump -i any) uses Linux SLL (Cooked Mode) if h.ifIndex == 0 { - // "any" interface uses Linux SLL (Cooked Mode) return layers.LinkTypeLinuxSLL } + + // 2. Get interface information by Index + iface, err := net.InterfaceByIndex(h.ifIndex) + if err != nil { + // If fails to find interface, fallback to Ethernet + return layers.LinkTypeEthernet + } + + // 3. Check for VPN / TUN characteristics + // - Case A: No Hardware Address (MAC) -> Typical for TUN devices + // - Case B: Point-to-Point Flag -> Typical for VPN tunnels (WireGuard, OpenVPN) + if len(iface.HardwareAddr) == 0 || iface.Flags&net.FlagPointToPoint != 0 { + // Linux TUN devices provide Raw IP packets (No Link Header) + // This corresponds to DLT_RAW (12) or DLT_IPV4/IPV6 (101) + return layers.LinkTypeRaw + } + + // 4. Default to Ethernet return layers.LinkTypeEthernet } diff --git a/internal/packet/network_detector.go b/internal/packet/network_detector.go index 797dd6e8..c4fd4601 100644 --- a/internal/packet/network_detector.go +++ b/internal/packet/network_detector.go @@ -13,12 +13,12 @@ import ( "github.com/xvzc/SpoofDPI/internal/netutil" ) -var dnsServers = []net.IPAddr{ - {IP: net.ParseIP("8.8.8.8")}, - {IP: net.ParseIP("8.8.4.4")}, - {IP: net.ParseIP("1.1.1.1")}, - {IP: net.ParseIP("1.0.0.1")}, - {IP: net.ParseIP("9.9.9.9")}, +var dnsServers = []net.IP{ + net.ParseIP("8.8.8.8"), + net.ParseIP("8.8.4.4"), + net.ParseIP("1.1.1.1"), + net.ParseIP("1.0.0.1"), + net.ParseIP("9.9.9.9"), } type NetworkDetector struct { @@ -69,13 +69,46 @@ func (nd *NetworkDetector) Start(ctx context.Context) error { } }() - conn, err := netutil.DialFastest(ctx, "udp", dnsServers, 53, time.Duration(0)) + go func() { + // Wait for the packet capture to start + select { + case <-ctx.Done(): + return + case <-time.After(300 * time.Millisecond): + } + + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + nd.probe(ctx) + + for { + select { + case <-ctx.Done(): + return + case <-nd.found: + return + case <-ticker.C: + nd.probe(ctx) + } + } + }() + + return nil +} + +func (nd *NetworkDetector) probe(ctx context.Context) { + conn, err := netutil.DialFastest(ctx, "udp", &netutil.Destination{ + Addrs: dnsServers, + Port: 53, + Timeout: 2 * time.Second, + }) if err != nil { - return err + return } defer func() { _ = conn.Close() }() - return nil + _, _ = conn.Write([]byte(".")) } func (nd *NetworkDetector) processPacket(p gopacket.Packet) { @@ -202,6 +235,10 @@ func (nd *NetworkDetector) GetGatewayMAC() net.HardwareAddr { func (nd *NetworkDetector) WaitForGatewayMAC( ctx context.Context, ) (net.HardwareAddr, error) { + if nd.iface.HardwareAddr == nil { + return nil, nil + } + if nd.IsFound() { return nd.GetGatewayMAC(), nil } @@ -222,9 +259,11 @@ func findDefaultInterface(ctx context.Context) (*net.Interface, error) { conn, err := netutil.DialFastest( ctx, "udp", - dnsServers, - 53, - time.Duration(10)*time.Second, + &netutil.Destination{ + Addrs: dnsServers, + Port: 53, + Timeout: time.Duration(20) * time.Second, + }, ) if err != nil { return nil, fmt.Errorf( @@ -251,7 +290,10 @@ func findDefaultInterface(ctx context.Context) (*net.Interface, error) { ) } - for _, iface := range ifaces { + var defaultIface *net.Interface + + for i := range ifaces { + iface := ifaces[i] addrs, err := iface.Addrs() if err != nil { continue // Skip interfaces whose addresses cannot be retrieved @@ -261,11 +303,39 @@ func findDefaultInterface(ctx context.Context) (*net.Interface, error) { if ipnet, ok := addr.(*net.IPNet); ok { // Check if the IP used for Dial matches the interface's IP if ipnet.IP.Equal(localAddr.IP) { - // Return &iface (*net.Interface) instead of iface.Name (string) - return &iface, nil // Found the default interface + if len(iface.HardwareAddr) > 0 { + return &iface, nil // Found the default interface with MAC + } + defaultIface = &iface + } + } + } + } + + // fmt.Println("hello") + + // If we found a default interface but it has no MAC (e.g. VPN/TUN), + // try to find a physical interface as fallback. + if defaultIface != nil { + for i := range ifaces { + iface := ifaces[i] + if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 { + continue + } + if len(iface.HardwareAddr) == 0 { + continue + } + + addrs, _ := iface.Addrs() + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok { + if ipnet.IP.To4() != nil && !ipnet.IP.IsLoopback() { + return &iface, nil + } } } } + return defaultIface, nil } return nil, fmt.Errorf( diff --git a/internal/packet/sniffer.go b/internal/packet/sniffer.go index d69ed45a..07fe4c08 100644 --- a/internal/packet/sniffer.go +++ b/internal/packet/sniffer.go @@ -4,11 +4,66 @@ import ( "net" "github.com/xvzc/SpoofDPI/internal/cache" + "github.com/xvzc/SpoofDPI/internal/netutil" ) type Sniffer interface { StartCapturing() - RegisterUntracked(addrs []net.IPAddr, port int) - GetOptimalTTL(key string) uint8 - Cache() cache.Cache + RegisterUntracked(addrs []net.IP) + GetOptimalTTL(key netutil.IPKey) uint8 + Cache() cache.Cache[netutil.IPKey] +} + +// estimateHops estimates the number of hops based on TTL. +// This logic is based on the hop counting mechanism from GoodbyeDPI. +// It returns 0 if the TTL is not recognizable. +func estimateHops(ttlLeft uint8) uint8 { + // Unrecognizable initial TTL + estimatedInitialHops := uint8(255) + switch { + case ttlLeft <= 64: + estimatedInitialHops = 64 + case ttlLeft <= 128: + estimatedInitialHops = 128 + default: + estimatedInitialHops = 255 + } + + return estimatedInitialHops - ttlLeft +} + +// isLocalIP checks if an IP address is in a local/private range. +// This is used to filter out outgoing packets from local machine. +// Private ranges: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, 127.0.0.0/8 +func isLocalIP(ip []byte) bool { + if len(ip) < 4 { + return false + } + + // 10.0.0.0/8 + if ip[0] == 10 { + return true + } + + // 127.0.0.0/8 (loopback) + if ip[0] == 127 { + return true + } + + // 172.16.0.0/12 (172.16.x.x - 172.31.x.x) + if ip[0] == 172 && ip[1] >= 16 && ip[1] <= 31 { + return true + } + + // 192.168.0.0/16 + if ip[0] == 192 && ip[1] == 168 { + return true + } + + // 169.254.0.0/16 (link-local) + if ip[0] == 169 && ip[1] == 254 { + return true + } + + return false } diff --git a/internal/packet/tcp_sniffer.go b/internal/packet/tcp_sniffer.go index 9871afe2..1e4bddb3 100644 --- a/internal/packet/tcp_sniffer.go +++ b/internal/packet/tcp_sniffer.go @@ -2,16 +2,14 @@ package packet import ( "context" - "fmt" "net" - "strconv" "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/rs/zerolog" "github.com/xvzc/SpoofDPI/internal/cache" "github.com/xvzc/SpoofDPI/internal/logging" - "github.com/xvzc/SpoofDPI/internal/session" + "github.com/xvzc/SpoofDPI/internal/netutil" ) var _ Sniffer = (*TCPSniffer)(nil) @@ -19,7 +17,7 @@ var _ Sniffer = (*TCPSniffer)(nil) type TCPSniffer struct { logger zerolog.Logger - nhopCache cache.Cache + nhopCache cache.Cache[netutil.IPKey] defaultTTL uint8 handle Handle @@ -27,7 +25,7 @@ type TCPSniffer struct { func NewTCPSniffer( logger zerolog.Logger, - cache cache.Cache, + cache cache.Cache[netutil.IPKey], handle Handle, defaultTTL uint8, ) *TCPSniffer { @@ -41,7 +39,7 @@ func NewTCPSniffer( // --- HopTracker Methods --- -func (ts *TCPSniffer) Cache() cache.Cache { +func (ts *TCPSniffer) Cache() cache.Cache[netutil.IPKey] { return ts.nhopCache } @@ -52,25 +50,23 @@ func (ts *TCPSniffer) StartCapturing() { packets := packetSource.Packets() // _ = ht.handle.SetBPFRawInstructionFilter(generateSynAckFilter()) _ = ts.handle.ClearBPF() - _ = ts.handle.SetBPFRawInstructionFilter(generateSynAckFilter()) + _ = ts.handle.SetBPFRawInstructionFilter(generateSynAckFilter(ts.handle.LinkType())) // Start a dedicated goroutine to process incoming packets. go func() { // Create a base context for this goroutine. - ctx := session.WithNewTraceID(context.Background()) for packet := range packets { - ts.processPacket(ctx, packet) + ts.processPacket(context.Background(), packet) } }() } // RegisterUntracked registers new IP addresses for tracking. // Addresses that are already being tracked are ignored. -func (ts *TCPSniffer) RegisterUntracked(addrs []net.IPAddr, port int) { - portStr := strconv.Itoa(port) +func (ts *TCPSniffer) RegisterUntracked(addrs []net.IP) { for _, v := range addrs { - ts.nhopCache.Set( - v.String()+":"+portStr, + ts.nhopCache.Store( + netutil.NewIPKey(v), ts.defaultTTL, cache.Options().WithSkipExisting(true), ) @@ -79,9 +75,9 @@ func (ts *TCPSniffer) RegisterUntracked(addrs []net.IPAddr, port int) { // GetOptimalTTL retrieves the estimated hop count for a given key from the cache. // It returns the hop count and true if found, or 0 and false if not found. -func (ts *TCPSniffer) GetOptimalTTL(key string) uint8 { +func (ts *TCPSniffer) GetOptimalTTL(key netutil.IPKey) uint8 { hopCount := uint8(255) - if oTTL, ok := ts.nhopCache.Get(key); ok { + if oTTL, ok := ts.nhopCache.Fetch(key); ok { hopCount = oTTL.(uint8) } @@ -105,18 +101,23 @@ func (ts *TCPSniffer) processPacket(ctx context.Context, p gopacket.Packet) { } // Check for a TCP layer - var srcIP string + var srcIP []byte var ttlLeft uint8 // Handle IPv4 if ipLayer := p.Layer(layers.LayerTypeIPv4); ipLayer != nil { ip, _ := ipLayer.(*layers.IPv4) - srcIP = ip.SrcIP.String() + // Skip packets from local/private IPs (outgoing packets) + if isLocalIP(ip.SrcIP) { + return + } + + srcIP = ip.SrcIP ttlLeft = ip.TTL } else if ipLayer := p.Layer(layers.LayerTypeIPv6); ipLayer != nil { // Handle IPv6 ip, _ := ipLayer.(*layers.IPv6) - srcIP = ip.SrcIP.String() + srcIP = ip.SrcIP ttlLeft = ip.HopLimit } else { return // No IP layer found @@ -124,75 +125,120 @@ func (ts *TCPSniffer) processPacket(ctx context.Context, p gopacket.Packet) { // Create the cache key: ServerIP:ServerPort // (The source of the SYN/ACK is the server) - key := fmt.Sprintf("%s:%d", srcIP, tcp.SrcPort) + key := netutil.NewIPKey(srcIP) // Calculate hop count from the TTL nhops := estimateHops(ttlLeft) - ok := ts.nhopCache.Set(key, nhops, nil) - if ok { - logger.Trace(). - Str("host_info", key). - Uint8("nhops", nhops). - Uint8("ttlLeft", ttlLeft). - Msgf("received syn+ack") - } -} -// estimateHops estimates the number of hops based on TTL. -// This logic is based on the hop counting mechanism from GoodbyeDPI. -// It returns 0 if the TTL is not recognizable. -func estimateHops(ttlLeft uint8) uint8 { - // Unrecognizable initial TTL - estimatedInitialHops := uint8(255) - switch { - case ttlLeft <= 64: - estimatedInitialHops = 64 - case ttlLeft <= 128: - estimatedInitialHops = 128 - default: - estimatedInitialHops = 255 - } + stored, exists := ts.nhopCache.Fetch(key) - return estimatedInitialHops - ttlLeft + if ts.nhopCache.Store(key, nhops, cache.Options().WithUpdateExistingOnly(true)) { + if !exists || stored != nhops { + logger.Trace(). + Str("from", key.String()). + Uint8("nhops", nhops). + Uint8("ttlLeft", ttlLeft). + Msgf("ttl(tcp) update") + } + } } // GenerateSynAckFilter creates a BPF program for "ip and tcp and (tcp[13] & 18 == 18)". // This captures only TCP SYN-ACK packets (IPv4). -func generateSynAckFilter() []BPFInstruction { - instructions := []BPFInstruction{ - // 1. Check EtherType == IPv4 (0x0800) - {Op: 0x28, Jt: 0, Jf: 0, K: 12}, - {Op: 0x15, Jt: 0, Jf: 8, K: 0x0800}, - - // 2. Check Protocol == TCP (6) - {Op: 0x30, Jt: 0, Jf: 0, K: 23}, - {Op: 0x15, Jt: 0, Jf: 6, K: 6}, - - // 3. Check Fragmentation - {Op: 0x28, Jt: 0, Jf: 0, K: 20}, - {Op: 0x45, Jt: 4, Jf: 0, K: 0x1fff}, - - // 4. Find TCP Header Start (IP Header Length to X) - // Loads byte at offset 14 (IP Header Start), gets IHL, multiplies by 4, stores in X. - {Op: 0xb1, Jt: 0, Jf: 0, K: 14}, - - // 5. Check TCP Flags (SYN+ACK) - // We want to load: Ethernet(14) + IP_Len(X) + TCP_Flags(13) - // Instruction is: Load [X + K] - // So K must be 14 + 13 = 27. - - // [FIX] K was 13, changed to 27 - {Op: 0x50, Jt: 0, Jf: 0, K: 27}, - - // Bitwise AND with 18 (SYN=2 | ACK=16) - {Op: 0x54, Jt: 0, Jf: 0, K: 18}, - - // Compare Result == 18 - {Op: 0x15, Jt: 0, Jf: 1, K: 18}, +// GenerateSynAckFilter creates a BPF program adapted to the LinkType. +// It supports Ethernet, Null (Loopback/VPN), and Raw IP link types. +func generateSynAckFilter(linkType layers.LinkType) []BPFInstruction { + var baseOffset uint32 + + // Determine the offset where the IP header begins + switch linkType { + case layers.LinkTypeEthernet: + baseOffset = 14 + case layers.LinkTypeNull, layers.LinkTypeLoop: // BSD Loopback / macOS utun + baseOffset = 4 + case layers.LinkTypeRaw: // Linux TUN + baseOffset = 0 + default: + // Fallback to Ethernet or handle error if necessary + baseOffset = 14 + } - // 6. Capture - {Op: 0x6, Jt: 0, Jf: 0, K: 0x00040000}, - {Op: 0x6, Jt: 0, Jf: 0, K: 0x00000000}, + instructions := []BPFInstruction{} + + // 1. Protocol Verification (IPv4) + if linkType == layers.LinkTypeEthernet { + // Check EtherType == IPv4 (0x0800) at offset 12 + instructions = append( + instructions, + BPFInstruction{Op: 0x28, Jt: 0, Jf: 0, K: 12}, // Ldh [12] + BPFInstruction{ + Op: 0x15, + Jt: 0, + Jf: 8, + K: 0x0800, + }, // Jeq 0x800, True, False(Skip to End) + ) + } else { + // Check IP Version == 4 at the base offset + // Load byte at baseOffset, mask 0xF0, check if 0x40 + instructions = append( + instructions, + // BPFInstruction{Op: 0x30, Jt: 0, Jf: 0, K: baseOffset}, // Ldb [baseOffset] + BPFInstruction{Op: 0x54, Jt: 0, Jf: 0, K: 0xf0}, // And 0xf0 + BPFInstruction{ + Op: 0x15, + Jt: 0, + Jf: 8, + K: 0x40, + }, // Jeq 0x40, True, False(Skip to End) + ) } + // 2. Check Protocol == TCP (6) + // Protocol field is at IP header + 9 bytes + instructions = append(instructions, + BPFInstruction{Op: 0x30, Jt: 0, Jf: 0, K: baseOffset + 9}, // Ldb [baseOffset + 9] + BPFInstruction{Op: 0x15, Jt: 0, Jf: 6, K: 6}, // Jeq 6, True, False + ) + + // 3. Check Fragmentation (Flags & Fragment Offset) + // At IP header + 6 bytes + instructions = append( + instructions, + BPFInstruction{Op: 0x28, Jt: 0, Jf: 0, K: baseOffset + 6}, // Ldh [baseOffset + 6] + BPFInstruction{ + Op: 0x45, + Jt: 4, + Jf: 0, + K: 0x1fff, + }, // Jset 0x1fff, True(Skip), False + ) + + // 4. Find TCP Header Start + // Load IP IHL from (baseOffset), multiply by 4 to get length, store in X + instructions = append(instructions, + BPFInstruction{Op: 0xb1, Jt: 0, Jf: 0, K: baseOffset}, // Ldxb 4*([baseOffset]&0xf) + ) + + // 5. Check TCP Flags (SYN+ACK) + // We need to load: baseOffset + IP_Len(X) + TCP_Flags(13) + // Instruction: Load [X + K] -> K = baseOffset + 13 + instructions = append( + instructions, + BPFInstruction{ + Op: 0x50, + Jt: 0, + Jf: 0, + K: baseOffset + 13, + }, // Ldb [X + baseOffset + 13] + BPFInstruction{Op: 0x54, Jt: 0, Jf: 0, K: 18}, // And 18 (SYN|ACK) + BPFInstruction{Op: 0x15, Jt: 0, Jf: 1, K: 18}, // Jeq 18, True, False + ) + + // 6. Capture + instructions = append(instructions, + BPFInstruction{Op: 0x6, Jt: 0, Jf: 0, K: 0x00040000}, // Ret capture_len + BPFInstruction{Op: 0x6, Jt: 0, Jf: 0, K: 0x00000000}, // Ret 0 + ) + return instructions } diff --git a/internal/packet/tcp_writer.go b/internal/packet/tcp_writer.go index 5aa4bc9e..146557e9 100644 --- a/internal/packet/tcp_writer.go +++ b/internal/packet/tcp_writer.go @@ -124,12 +124,14 @@ func (tw *TCPWriter) createIPv4Layers( ) ([]gopacket.SerializableLayer, error) { var packetLayers []gopacket.SerializableLayer - eth := &layers.Ethernet{ - SrcMAC: srcMAC, - DstMAC: dstMAC, - EthernetType: layers.EthernetTypeIPv4, + if srcMAC != nil && dstMAC != nil { + eth := &layers.Ethernet{ + SrcMAC: srcMAC, + DstMAC: dstMAC, + EthernetType: layers.EthernetTypeIPv4, + } + packetLayers = append(packetLayers, eth) } - packetLayers = append(packetLayers, eth) // define ip layer ipLayer := &layers.IPv4{ @@ -171,12 +173,14 @@ func (tw *TCPWriter) createIPv6Layers( ) ([]gopacket.SerializableLayer, error) { var packetLayers []gopacket.SerializableLayer - eth := &layers.Ethernet{ - SrcMAC: srcMAC, - DstMAC: dstMAC, - EthernetType: layers.EthernetTypeIPv6, + if srcMAC != nil && dstMAC != nil { + eth := &layers.Ethernet{ + SrcMAC: srcMAC, + DstMAC: dstMAC, + EthernetType: layers.EthernetTypeIPv6, + } + packetLayers = append(packetLayers, eth) } - packetLayers = append(packetLayers, eth) ipLayer := &layers.IPv6{ Version: 6, diff --git a/internal/packet/udp_sniffer.go b/internal/packet/udp_sniffer.go new file mode 100644 index 00000000..acd46d28 --- /dev/null +++ b/internal/packet/udp_sniffer.go @@ -0,0 +1,200 @@ +package packet + +import ( + "context" + "net" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/cache" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/netutil" +) + +var _ Sniffer = (*UDPSniffer)(nil) + +type UDPSniffer struct { + logger zerolog.Logger + + nhopCache cache.Cache[netutil.IPKey] + defaultTTL uint8 + + handle Handle +} + +func NewUDPSniffer( + logger zerolog.Logger, + cache cache.Cache[netutil.IPKey], + handle Handle, + defaultTTL uint8, +) *UDPSniffer { + return &UDPSniffer{ + logger: logger, + nhopCache: cache, + handle: handle, + defaultTTL: defaultTTL, + } +} + +// --- HopTracker Methods --- + +func (us *UDPSniffer) Cache() cache.Cache[netutil.IPKey] { + return us.nhopCache +} + +// StartCapturing begins monitoring for UDP packets in a background goroutine. +func (us *UDPSniffer) StartCapturing() { + // Create a new packet source from the handle. + packetSource := gopacket.NewPacketSource(us.handle, us.handle.LinkType()) + packets := packetSource.Packets() + + _ = us.handle.ClearBPF() + _ = us.handle.SetBPFRawInstructionFilter(generateUdpFilter(us.handle.LinkType())) + + // Start a dedicated goroutine to process incoming packets. + go func() { + for packet := range packets { + us.processPacket(context.Background(), packet) + } + }() +} + +// RegisterUntracked registers new IP addresses for tracking. +// Addresses that are already being tracked are ignored. +func (us *UDPSniffer) RegisterUntracked(addrs []net.IP) { + for _, v := range addrs { + us.nhopCache.Store( + netutil.NewIPKey(v), + us.defaultTTL, + cache.Options().WithSkipExisting(true), + ) + } +} + +// GetOptimalTTL retrieves the estimated hop count for a given key from the cache. +// It returns the hop count and true if found, or 0 and false if not found. +func (us *UDPSniffer) GetOptimalTTL(key netutil.IPKey) uint8 { + hopCount := uint8(255) + if oTTL, ok := us.nhopCache.Fetch(key); ok { + hopCount = oTTL.(uint8) + } + + return max(hopCount, 2) - 1 +} + +// processPacket analyzes a single packet to store hop counts. +func (us *UDPSniffer) processPacket(ctx context.Context, p gopacket.Packet) { + logger := logging.WithLocalScope(ctx, us.logger, "sniff") + + udpLayer := p.Layer(layers.LayerTypeUDP) + if udpLayer == nil { + return + } + + var srcIP []byte + var ttlLeft uint8 + + // Handle IPv4 + if ipLayer := p.Layer(layers.LayerTypeIPv4); ipLayer != nil { + ip, _ := ipLayer.(*layers.IPv4) + + // Skip packets from local/private IPs (outgoing packets) + if isLocalIP(ip.SrcIP) { + return + } + // Skip packets where dst is not local (outgoing packets including our fake packets) + if !isLocalIP(ip.DstIP) { + return + } + + srcIP = ip.SrcIP + ttlLeft = ip.TTL + } else if ipLayer := p.Layer(layers.LayerTypeIPv6); ipLayer != nil { + // Handle IPv6 + ip, _ := ipLayer.(*layers.IPv6) + srcIP = ip.SrcIP + ttlLeft = ip.HopLimit + } else { + return // No IP layer found + } + + key := netutil.NewIPKey(srcIP) + // Calculate hop count from the TTL + nhops := estimateHops(ttlLeft) + + stored, exists := us.nhopCache.Fetch(key) + + if us.nhopCache.Store(key, nhops, nil) { + if !exists || stored != nhops { + logger.Trace(). + Str("from", key.String()). + Uint8("nhops", nhops). + Uint8("ttlLeft", ttlLeft). + Msgf("ttl(udp) update") + } + } +} + +// GenerateUdpFilter creates a BPF program for "ip and udp". +// It supports Ethernet, Null (Loopback/VPN), and Raw IP link types. +func generateUdpFilter(linkType layers.LinkType) []BPFInstruction { + var baseOffset uint32 + + // Determine the offset where the IP header begins + switch linkType { + case layers.LinkTypeEthernet: + baseOffset = 14 + case layers.LinkTypeNull, layers.LinkTypeLoop: // BSD Loopback / macOS utun + baseOffset = 4 + case layers.LinkTypeRaw: // Linux TUN + baseOffset = 0 + default: + // Fallback to Ethernet or handle error if necessary + baseOffset = 14 + } + + instructions := []BPFInstruction{} + + // 1. Protocol Verification (IPv4) + if linkType == layers.LinkTypeEthernet { + // Check EtherType == IPv4 (0x0800) at offset 12 + instructions = append( + instructions, + BPFInstruction{Op: 0x28, Jt: 0, Jf: 0, K: 12}, // Ldh [12] + BPFInstruction{ + Op: 0x15, + Jt: 0, + Jf: 3, + K: 0x0800, + }, // Jeq 0x800, True, False(Skip to End) + ) + } else { + // Check IP Version == 4 at the base offset + // Load byte at baseOffset, mask 0xF0, check if 0x40 + // Ldb [baseOffset] + // And 0xf0 + // Jeq 0x40, True, False(Skip to End) + instructions = append( + instructions, + BPFInstruction{Op: 0x30, Jt: 0, Jf: 0, K: baseOffset}, + BPFInstruction{Op: 0x54, Jt: 0, Jf: 0, K: 0xf0}, + BPFInstruction{Op: 0x15, Jt: 0, Jf: 3, K: 0x40}, + ) + } + + // 2. Check Protocol == UDP (17) + // Protocol field is at IP header + 9 bytes + instructions = append(instructions, + BPFInstruction{Op: 0x30, Jt: 0, Jf: 0, K: baseOffset + 9}, // Ldb [baseOffset + 9] + BPFInstruction{Op: 0x15, Jt: 0, Jf: 1, K: 17}, // Jeq 17, True, False + ) + + // 3. Capture + instructions = append(instructions, + BPFInstruction{Op: 0x6, Jt: 0, Jf: 0, K: 0x00040000}, // Ret capture_len + BPFInstruction{Op: 0x6, Jt: 0, Jf: 0, K: 0x00000000}, // Ret 0 + ) + + return instructions +} diff --git a/internal/packet/udp_writer.go b/internal/packet/udp_writer.go new file mode 100644 index 00000000..873562f3 --- /dev/null +++ b/internal/packet/udp_writer.go @@ -0,0 +1,199 @@ +package packet + +import ( + "context" + "errors" + "net" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/rs/zerolog" +) + +var _ Writer = (*UDPWriter)(nil) + +type UDPWriter struct { + logger zerolog.Logger + + handle Handle + iface *net.Interface + gatewayMAC net.HardwareAddr +} + +func NewUDPWriter( + logger zerolog.Logger, + handle Handle, + iface *net.Interface, + gatewayMAC net.HardwareAddr, +) *UDPWriter { + return &UDPWriter{ + logger: logger, + handle: handle, + iface: iface, + gatewayMAC: gatewayMAC, + } +} + +// --- Injector Methods --- + +// WriteCraftedPacket crafts and injects a full UDP packet from a payload. +// It uses the pre-configured gateway MAC address. +func (uw *UDPWriter) WriteCraftedPacket( + ctx context.Context, + src net.Addr, + dst net.Addr, + ttl uint8, + payload []byte, +) (int, error) { + // set variables for src/dst + srcMAC := uw.iface.HardwareAddr + dstMAC := uw.gatewayMAC + + srcUDP, ok := src.(*net.UDPAddr) + if !ok { + return 0, errors.New("src is not *net.UDPAddr") + } + + dstUDP, ok := dst.(*net.UDPAddr) + if !ok { + return 0, errors.New("dst is not *net.UDPAddr") + } + + srcPort := srcUDP.Port + dstPort := dstUDP.Port + + var err error + var packetLayers []gopacket.SerializableLayer + if dstUDP.IP.To4() != nil { + packetLayers, err = uw.createIPv4Layers( + srcMAC, + dstMAC, + srcUDP.IP, + dstUDP.IP, + srcPort, + dstPort, + ttl, + ) + } else { + packetLayers, err = uw.createIPv6Layers( + srcMAC, + dstMAC, + srcUDP.IP, + dstUDP.IP, + srcPort, + dstPort, + ttl, + ) + } + + if err != nil { + return 0, err + } + + // serialize the packet L2(optional) + L3 + L4 + payload + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, // Recalculate checksums + FixLengths: true, // Fix lengths + } + + packetLayers = append(packetLayers, gopacket.Payload(payload)) + + err = gopacket.SerializeLayers(buf, opts, packetLayers...) + if err != nil { + return 0, err + } + + // inject the raw L2 packet + if err := uw.handle.WritePacketData(buf.Bytes()); err != nil { + return 0, err + } + + return len(payload), nil +} + +func (uw *UDPWriter) createIPv4Layers( + srcMAC net.HardwareAddr, + dstMAC net.HardwareAddr, + srcIP net.IP, + dstIP net.IP, + srcPort int, + dstPort int, + ttl uint8, +) ([]gopacket.SerializableLayer, error) { + var packetLayers []gopacket.SerializableLayer + + if srcMAC != nil && dstMAC != nil { + eth := &layers.Ethernet{ + SrcMAC: srcMAC, + DstMAC: dstMAC, + EthernetType: layers.EthernetTypeIPv4, + } + packetLayers = append(packetLayers, eth) + } + + // define ip layer + ipLayer := &layers.IPv4{ + Version: 4, + TTL: ttl, + Protocol: layers.IPProtocolUDP, + SrcIP: srcIP, + DstIP: dstIP, + } + packetLayers = append(packetLayers, ipLayer) + + // define udp layer + udpLayer := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + packetLayers = append(packetLayers, udpLayer) + + if err := udpLayer.SetNetworkLayerForChecksum(ipLayer); err != nil { + return nil, err + } + + return packetLayers, nil +} + +func (uw *UDPWriter) createIPv6Layers( + srcMAC net.HardwareAddr, + dstMAC net.HardwareAddr, + srcIP net.IP, + dstIP net.IP, + srcPort int, + dstPort int, + ttl uint8, +) ([]gopacket.SerializableLayer, error) { + var packetLayers []gopacket.SerializableLayer + + if srcMAC != nil && dstMAC != nil { + eth := &layers.Ethernet{ + SrcMAC: srcMAC, + DstMAC: dstMAC, + EthernetType: layers.EthernetTypeIPv6, + } + packetLayers = append(packetLayers, eth) + } + + ipLayer := &layers.IPv6{ + Version: 6, + HopLimit: ttl, + NextHeader: layers.IPProtocolUDP, + SrcIP: srcIP, + DstIP: dstIP, + } + packetLayers = append(packetLayers, ipLayer) + + udpLayer := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + packetLayers = append(packetLayers, udpLayer) + + if err := udpLayer.SetNetworkLayerForChecksum(ipLayer); err != nil { + return nil, err + } + + return packetLayers, nil +} diff --git a/internal/proto/http.go b/internal/proto/http.go index 7f718a8b..1ca656f0 100644 --- a/internal/proto/http.go +++ b/internal/proto/http.go @@ -68,8 +68,8 @@ func ReadHttpRequest(rdr io.Reader) (*HTTPRequest, error) { return NewHttpRequest(req), nil } -// ExtractDomain returns the host without port information -func (r *HTTPRequest) ExtractDomain() string { +// ExtractHost returns the host without port information +func (r *HTTPRequest) ExtractHost() string { host, _, err := net.SplitHostPort(r.Host) if err != nil { return r.Host diff --git a/internal/proto/socks5.go b/internal/proto/socks5.go index 92054979..f99a46e5 100644 --- a/internal/proto/socks5.go +++ b/internal/proto/socks5.go @@ -13,42 +13,34 @@ const ( SOCKSVersion = 0x05 // Auth - AuthNone = 0x00 - AuthGSSAPI = 0x01 - AuthUserPass = 0x02 - AuthNoAccept = 0xFF + SOCKS5AuthNone = 0x00 + SOCKS5AuthGSSAPI = 0x01 + SOCKS5AuthUserPass = 0x02 + SOCKS5AuthNoAccept = 0xFF // Command - CmdConnect = 0x01 - CmdBind = 0x02 - CmdUDPAssociate = 0x03 + SOCKS5CmdConnect = 0x01 + SOCKS5CmdBind = 0x02 + SOCKS5CmdUDPAssociate = 0x03 // ATYP - ATYPIPv4 = 0x01 - ATYPFQDN = 0x03 - ATYPIPv6 = 0x04 + SOCKS5AddrTypeIPv4 = 0x01 + SOCKS5AddrTypeFQDN = 0x03 + SOCKS5AddrTypeIPv6 = 0x04 // Reply codes - ReplyCodeSuccess = 0x00 - ReplyCodeGenFailure = 0x01 - ReplyCodeCmdNotSupport = 0x07 - ReplyCodeAddrNotSupport = 0x08 + SOCKS5RCodeSuccess = 0x00 + SOCKS5RCodeGenFailure = 0x01 + SOCKS5RCodeCmdNotSupported = 0x07 + SOCKS5RCodeAddrNotSupported = 0x08 ) type SOCKS5Request struct { - Cmd byte - Domain string - IP net.IP - Port int -} - -func (r *SOCKS5Request) Host() string { - ret := r.Domain - if ret == "" { - ret = r.IP.String() - } - - return ret + Cmd byte + ATYP byte + FQDN string + IP net.IP + Port int } // ReadSocks5Request parses the SOCKS5 request details. @@ -73,18 +65,20 @@ func ReadSocks5Request(conn net.Conn) (*SOCKS5Request, error) { var ip net.IP switch atyp { - case ATYPIPv4: + case SOCKS5AddrTypeIPv4: buf := make([]byte, 4) if _, err := io.ReadFull(conn, buf); err != nil { return nil, err } + ip = net.IP(buf) - case ATYPFQDN: + case SOCKS5AddrTypeFQDN: lenBuf := make([]byte, 1) if _, err := io.ReadFull(conn, lenBuf); err != nil { return nil, err } + domainLen := int(lenBuf[0]) domainBuf := make([]byte, domainLen) if _, err := io.ReadFull(conn, domainBuf); err != nil { @@ -92,7 +86,7 @@ func ReadSocks5Request(conn net.Conn) (*SOCKS5Request, error) { } domain = string(domainBuf) - case ATYPIPv6: + case SOCKS5AddrTypeIPv6: buf := make([]byte, 16) if _, err := io.ReadFull(conn, buf); err != nil { return nil, err @@ -110,10 +104,11 @@ func ReadSocks5Request(conn net.Conn) (*SOCKS5Request, error) { port := int(binary.BigEndian.Uint16(portBuf)) return &SOCKS5Request{ - Cmd: cmd, - Domain: domain, - IP: ip, - Port: port, + Cmd: cmd, + ATYP: atyp, + FQDN: domain, + IP: ip, + Port: port, }, nil } @@ -132,15 +127,15 @@ func NewSOCKS5Reply(rep byte) *SOCKS5Reply { } func SOCKS5SuccessResponse() *SOCKS5Reply { - return NewSOCKS5Reply(ReplyCodeSuccess) + return NewSOCKS5Reply(SOCKS5RCodeSuccess) } func SOCKS5FailureResponse() *SOCKS5Reply { - return NewSOCKS5Reply(ReplyCodeGenFailure) + return NewSOCKS5Reply(SOCKS5RCodeGenFailure) } func SOCKS5CommandNotSupportedResponse() *SOCKS5Reply { - return NewSOCKS5Reply(ReplyCodeCmdNotSupport) + return NewSOCKS5Reply(SOCKS5RCodeCmdNotSupported) } func (r *SOCKS5Reply) Bind(ip net.IP) *SOCKS5Reply { @@ -157,7 +152,7 @@ func (r *SOCKS5Reply) Port(port int) *SOCKS5Reply { func (r *SOCKS5Reply) Write(w io.Writer) error { buf := make([]byte, 0, 10) - buf = append(buf, SOCKSVersion, r.Rep, 0x00, ATYPIPv4) + buf = append(buf, SOCKSVersion, r.Rep, 0x00, SOCKS5AddrTypeIPv4) // Use To4() to ensure 4 bytes if it's an IPv4 address stored in IPv6 format if ip4 := r.BindIP.To4(); ip4 != nil { diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go deleted file mode 100644 index b7a47163..00000000 --- a/internal/proxy/proxy.go +++ /dev/null @@ -1,9 +0,0 @@ -package proxy - -import ( - "context" -) - -type ProxyServer interface { - ListenAndServe(ctx context.Context, wait chan struct{}) -} diff --git a/internal/proxy/socks5/socks5_proxy.go b/internal/proxy/socks5/socks5_proxy.go deleted file mode 100644 index f9097cae..00000000 --- a/internal/proxy/socks5/socks5_proxy.go +++ /dev/null @@ -1,327 +0,0 @@ -package socks5 - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net" - "time" - - "github.com/rs/zerolog" - "github.com/xvzc/SpoofDPI/internal/config" - "github.com/xvzc/SpoofDPI/internal/dns" - "github.com/xvzc/SpoofDPI/internal/logging" - "github.com/xvzc/SpoofDPI/internal/matcher" - "github.com/xvzc/SpoofDPI/internal/netutil" - "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/proxy" - "github.com/xvzc/SpoofDPI/internal/proxy/http" - "github.com/xvzc/SpoofDPI/internal/ptr" - "github.com/xvzc/SpoofDPI/internal/session" -) - -type SOCKS5Proxy struct { - logger zerolog.Logger - - resolver dns.Resolver - httpsHandler *http.HTTPSHandler // SOCKS5 primarily uses CONNECT, so we leverage HTTPSHandler - ruleMatcher matcher.RuleMatcher - serverOpts *config.ServerOptions - policyOpts *config.PolicyOptions -} - -func NewSOCKS5Proxy( - logger zerolog.Logger, - resolver dns.Resolver, - httpsHandler *http.HTTPSHandler, - ruleMatcher matcher.RuleMatcher, - serverOpts *config.ServerOptions, - policyOpts *config.PolicyOptions, -) proxy.ProxyServer { - return &SOCKS5Proxy{ - logger: logger, - resolver: resolver, - httpsHandler: httpsHandler, - ruleMatcher: ruleMatcher, - serverOpts: serverOpts, - policyOpts: policyOpts, - } -} - -func (p *SOCKS5Proxy) ListenAndServe(ctx context.Context, wait chan struct{}) { - <-wait - - logger := p.logger.With().Ctx(ctx).Logger() - - // Using ListenTCP to match HTTPProxy style, though net.Listen is also fine - listener, err := net.ListenTCP("tcp", p.serverOpts.ListenAddr) - if err != nil { - p.logger.Fatal(). - Err(err). - Msgf("error creating socks5 listener on %s", p.serverOpts.ListenAddr.String()) - } - - logger.Info(). - Msgf("created a socks5 listener on %s", p.serverOpts.ListenAddr) - - for { - conn, err := listener.Accept() - if err != nil { - p.logger.Error(). - Err(err). - Msg("failed to accept new connection") - continue - } - - go p.handleConnection(session.WithNewTraceID(context.Background()), conn) - } -} - -func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { - logger := logging.WithLocalScope(ctx, p.logger, "socks5") - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - defer netutil.CloseConns(conn) - - // 1. Negotiation Phase - if err := p.negotiate(conn); err != nil { - logger.Debug().Err(err).Msg("socks5 negotiation failed") - return - } - - // 2. Request Phase - req, err := proto.ReadSocks5Request(conn) - if err != nil { - if err != io.EOF { - logger.Warn().Err(err).Msg("failed to read socks5 request") - } - return - } - - // Only support CONNECT for now - if req.Cmd != proto.CmdConnect { - _ = proto.SOCKS5CommandNotSupportedResponse().Write(conn) - logger.Warn().Uint8("cmd", req.Cmd).Msg("unsupported socks5 command") - return - } - - // Setup Logging Context - remoteInfo := req.Domain - if remoteInfo == "" { - remoteInfo = req.IP.String() - } - ctx = session.WithHostInfo(ctx, remoteInfo) - logger = logger.With().Ctx(ctx).Logger() - - logger.Debug(). - Str("from", conn.RemoteAddr().String()). - Msg("new socks5 request") - - // 3. Match Domain Rules (if domain provided) - var nameMatch *config.Rule - if req.Domain != "" { - nameMatch = p.ruleMatcher.Search( - &matcher.Selector{ - Kind: matcher.MatchKindDomain, - Domain: ptr.FromValue(req.Domain), - }, - ) - if nameMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - jsonAttrs, _ := json.Marshal(nameMatch) - logger.Trace().RawJSON("values", jsonAttrs).Msg("name match") - } - } - - // 4. DNS Resolution - // SOCKS5 allows IP or Domain. If Domain, we resolve. If IP, we use it directly. - t1 := time.Now() - var addrs []net.IPAddr - - if req.Domain != "" { - // Resolve Domain - rSet, err := p.resolver.Resolve(ctx, req.Domain, nil, nameMatch) - if err != nil { - _ = proto.SOCKS5FailureResponse().Write(conn) - logging.ErrorUnwrapped(&logger, "dns lookup failed", err) - return - } - addrs = rSet.Addrs - } else { - // IP Request - Just wrap the IP - addrs = []net.IPAddr{{IP: req.IP}} - } - - dt := time.Since(t1).Milliseconds() - logger.Debug(). - Int("cnt", len(addrs)). - Str("took", fmt.Sprintf("%dms", dt)). - Msg("dns lookup ok") - - // Avoid recursively querying self. - ok, err := validateDestination(addrs, req.Port, p.serverOpts.ListenAddr) - if err != nil { - logger.Debug().Err(err).Msg("error determining if valid destination") - if !ok { - _ = proto.SOCKS5FailureResponse().Write(conn) - return - } - } - - // 6. Match IP Rules - var selectors []*matcher.Selector - for _, v := range addrs { - selectors = append(selectors, &matcher.Selector{ - Kind: matcher.MatchKindAddr, - IP: ptr.FromValue(v.IP), - Port: ptr.FromValue(uint16(req.Port)), - }) - } - - addrMatch := p.ruleMatcher.SearchAll(selectors) - if addrMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - jsonAttrs, _ := json.Marshal(addrMatch) - logger.Trace().RawJSON("values", jsonAttrs).Msg("addr match") - } - - bestMatch := matcher.GetHigherPriorityRule(addrMatch, nameMatch) - if bestMatch != nil && *bestMatch.Block { - logger.Debug().Msg("request is blocked by policy") - _ = proto.SOCKS5FailureResponse().Write(conn) - // Or specific code for blocked - return - } - - // 7. Handover to Handler - // Important: We must send a success reply to the client BEFORE handing over to the handler, - // because the handler (SpoofDPI) typically expects a raw stream where it can start TLS handshake immediately. - // However, standard SOCKS5 expects the proxy to connect to the target FIRST, then send success. - // Since SpoofDPI handler does the connection, we might need to send success here optimistically. - - // [Optimistic Success] - // We tell the client "OK, we are connected" so it starts sending data (e.g. ClientHello). - // The real connection happens inside p.tcpHandler.HandleRequest. - // Note: BIND addr/port is usually 0.0.0.0:0 if we don't care. - if err := proto.SOCKS5SuccessResponse().Bind(net.IPv4zero).Port(0).Write(conn); err != nil { - logger.Error().Err(err).Msg("failed to write socks5 success reply") - return - } - - dst := &http.Destination{ - Domain: req.Domain, - Addrs: addrs, - Port: req.Port, - Timeout: *p.serverOpts.Timeout, - } - - // Note: 'req' is nil here because it's not an HTTP request yet. - // The handler must be able to handle nil req or we wrap a dummy one. - // Assuming Handler is adapted to deal with raw streams or nil reqs for SOCKS mode. - handleErr := p.httpsHandler.HandleRequest(ctx, conn, dst, bestMatch) - - if handleErr == nil { - return - } - - logger.Warn().Err(handleErr).Msg("error handling request") - if !errors.Is(handleErr, netutil.ErrBlocked) { - return - } - - // 8. Auto Config (Duplicate logic from HTTPProxy) - if nameMatch != nil { - logger.Info(). - Interface("match", nameMatch.Match.Domains). - Str("name", *nameMatch.Name). - Msg("skipping auto-config (duplicate policy)") - return - } - - if addrMatch != nil { - logger.Info(). - Interface("match", addrMatch.Match.Addrs). - Str("name", *addrMatch.Name). - Msg("skipping auto-config (duplicate policy)") - return - } - - if *p.policyOpts.Auto && p.policyOpts.Template != nil { - newRule := p.policyOpts.Template.Clone() - targetDomain := req.Domain - if targetDomain == "" && len(addrs) > 0 { - // If request was by IP, we can't really add a domain rule, - // maybe add IP rule or skip. Use domain if available. - targetDomain = addrs[0].IP.String() - } - - newRule.Match = &config.MatchAttrs{Domains: []string{targetDomain}} - - if err := p.ruleMatcher.Add(newRule); err != nil { - logger.Info().Err(err).Msg("failed to add config automatically") - } else { - logger.Info().Msg("automatically added to config") - } - } -} - -// validateDestination checks if we are recursively querying ourselves. -// This function needs to be duplicated or moved to a shared utility if common logic is identical. -// For now, I'll use a local helper or assume the logic is specific enough. -// Since `isRecursiveDst` was a method on HTTPProxy, I'll assume similar logic is needed here. -// I'll implement a local version for now to keep it self-contained in this package or duplicated from http/proxy.go. -// Or I can just omit it if I don't want to duplicate, but safety is important. -// I'll add a simple check against listening port/loopback. -func validateDestination( - dstAddrs []net.IPAddr, - dstPort int, - listenAddr *net.TCPAddr, -) (bool, error) { - if dstPort != int(listenAddr.Port) { - return true, nil - } - - for _, dstAddr := range dstAddrs { - ip := dstAddr.IP - if ip.IsLoopback() { - return false, nil - } - - ifAddrs, err := net.InterfaceAddrs() - if err != nil { - return false, err - } - - for _, addr := range ifAddrs { - if ipnet, ok := addr.(*net.IPNet); ok { - if ipnet.IP.Equal(ip) { - return false, nil - } - } - } - } - return true, nil -} - -// negotiate performs SOCKS5 auth negotiation (NoAuth only). -func (p *SOCKS5Proxy) negotiate(conn net.Conn) error { - header := make([]byte, 2) - if _, err := io.ReadFull(conn, header); err != nil { - return err - } - - if header[0] != proto.SOCKSVersion { - return fmt.Errorf("unsupported version: %d", header[0]) - } - - nMethods := int(header[1]) - methods := make([]byte, nMethods) - if _, err := io.ReadFull(conn, methods); err != nil { - return err - } - - // Respond: Version 5, Method NoAuth(0) - _, err := conn.Write([]byte{proto.SOCKSVersion, proto.AuthNone}) - return err -} diff --git a/internal/proxy/http/http_handler.go b/internal/server/http/http.go similarity index 68% rename from internal/proxy/http/http_handler.go rename to internal/server/http/http.go index 3cc415f6..80ee9b73 100644 --- a/internal/proxy/http/http_handler.go +++ b/internal/server/http/http.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "time" "github.com/rs/zerolog" "github.com/xvzc/SpoofDPI/internal/config" @@ -16,9 +17,7 @@ type HTTPHandler struct { logger zerolog.Logger } -func NewHTTPHandler( - logger zerolog.Logger, -) *HTTPHandler { +func NewHTTPHandler(logger zerolog.Logger) *HTTPHandler { return &HTTPHandler{ logger: logger, } @@ -28,13 +27,14 @@ func (h *HTTPHandler) HandleRequest( ctx context.Context, lConn net.Conn, // Use the net.Conn interface, not a concrete *net.TCPConn. req *proto.HTTPRequest, // Assumes HttpRequest is a custom type for request parsing. - dst *Destination, + dst *netutil.Destination, rule *config.Rule, ) error { logger := logging.WithLocalScope(ctx, h.logger, "http") - rConn, err := netutil.DialFastest(ctx, "tcp", dst.Addrs, dst.Port, dst.Timeout) + rConn, err := netutil.DialFastest(ctx, "tcp", dst) if err != nil { + _ = proto.HTTPBadGatewayResponse().Write(lConn) return err } @@ -50,27 +50,22 @@ func (h *HTTPHandler) HandleRequest( return fmt.Errorf("failed to send request: %w", err) } - errCh := make(chan error, 2) + // Start bi-directional tunneling + resCh := make(chan netutil.TransferResult, 2) ctx, cancel := context.WithCancel(ctx) defer cancel() - go netutil.TunnelConns(ctx, logger, errCh, rConn, lConn) - go netutil.TunnelConns(ctx, logger, errCh, lConn, rConn) + startedAt := time.Now() + go netutil.TunnelConns(ctx, resCh, lConn, rConn, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, rConn, lConn, netutil.TunnelDirIn) - for range 2 { - e := <-errCh - if e == nil { - continue - } - - return fmt.Errorf( - "unsuccessful tunnel %s -> %s: %w", - lConn.RemoteAddr(), - rConn.RemoteAddr(), - e, - ) - } - - return nil + return netutil.WaitAndLogTunnel( + ctx, + logger, + resCh, + startedAt, + netutil.DescribeRoute(lConn, rConn), + nil, + ) } diff --git a/internal/proxy/http/https_handler.go b/internal/server/http/https.go similarity index 52% rename from internal/proxy/http/https_handler.go rename to internal/server/http/https.go index 2a9f8a1e..502307be 100644 --- a/internal/proxy/http/https_handler.go +++ b/internal/server/http/https.go @@ -6,59 +6,87 @@ import ( "fmt" "io" "net" + "slices" + "time" "github.com/rs/zerolog" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/desync" "github.com/xvzc/SpoofDPI/internal/logging" "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/packet" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/ptr" ) type HTTPSHandler struct { - logger zerolog.Logger - desyncer *desync.TLSDesyncer - sniffer packet.Sniffer - httpsOpts *config.HTTPSOptions + logger zerolog.Logger + desyncer *desync.TLSDesyncer + sniffer packet.Sniffer + defaultHTTPSOpts *config.HTTPSOptions + defaultConnOpts *config.ConnOptions } func NewHTTPSHandler( logger zerolog.Logger, desyncer *desync.TLSDesyncer, sniffer packet.Sniffer, - httpsOpts *config.HTTPSOptions, + defaultHTTPSOpts *config.HTTPSOptions, + defaultConnOpts *config.ConnOptions, ) *HTTPSHandler { return &HTTPSHandler{ - logger: logger, - desyncer: desyncer, - sniffer: sniffer, - httpsOpts: httpsOpts, + logger: logger, + desyncer: desyncer, + sniffer: sniffer, + defaultHTTPSOpts: defaultHTTPSOpts, + defaultConnOpts: defaultConnOpts, } } -// func (h *HTTPSHandler) DefaultRule() *policy.Rule { -// return h.defaultAttrs.Clone() -// } func (h *HTTPSHandler) HandleRequest( ctx context.Context, lConn net.Conn, - dst *Destination, + dst *netutil.Destination, rule *config.Rule, ) error { - httpsOpts := h.httpsOpts + httpsOpts := h.defaultHTTPSOpts.Clone() + connOpts := h.defaultConnOpts.Clone() if rule != nil { httpsOpts = httpsOpts.Merge(rule.HTTPS) + connOpts = connOpts.Merge(rule.Conn) + } + + logger := logging.WithLocalScope(ctx, h.logger, "handshake") + + // 1. Send 200 Connection Established + if err := proto.HTTPConnectionEstablishedResponse().Write(lConn); err != nil { + if !netutil.IsConnectionResetByPeer(err) && !errors.Is(err, io.EOF) { + logger.Trace().Err(err).Msgf("proxy handshake error") + return fmt.Errorf("failed to handle proxy handshake: %w", err) + } + return nil } + logger.Trace().Msgf("sent 200 connection established -> %s", lConn.RemoteAddr()) - if h.sniffer != nil && ptr.FromPtr(httpsOpts.FakeCount) > 0 { - h.sniffer.RegisterUntracked(dst.Addrs, dst.Port) + // 2. Tunnel + return h.tunnel(ctx, lConn, dst, httpsOpts, connOpts) +} + +func (h *HTTPSHandler) tunnel( + ctx context.Context, + lConn net.Conn, + dst *netutil.Destination, + httpsOpts *config.HTTPSOptions, + connOpts *config.ConnOptions, +) error { + if h.sniffer != nil && lo.FromPtr(httpsOpts.FakeCount) > 0 { + h.sniffer.RegisterUntracked(dst.Addrs) } logger := logging.WithLocalScope(ctx, h.logger, "https") - rConn, err := netutil.DialFastest(ctx, "tcp", dst.Addrs, dst.Port, dst.Timeout) + dst.Timeout = *connOpts.TCPTimeout + rConn, err := netutil.DialFastest(ctx, "tcp", dst) if err != nil { return err } @@ -66,21 +94,26 @@ func (h *HTTPSHandler) HandleRequest( logger.Debug().Msgf("new remote conn -> %s", rConn.RemoteAddr()) - tlsMsg, err := h.handleProxyHandshake(ctx, lConn) + // Read the first message from the client (expected to be ClientHello) + tlsMsg, err := proto.ReadTLSMessage(lConn) if err != nil { - if !netutil.IsConnectionResetByPeer(err) && !errors.Is(err, io.EOF) { - logger.Trace().Err(err).Msgf("proxy handshake error") - return fmt.Errorf("failed to handle proxy handshake: %w", err) + if err == io.EOF || err.Error() == "unexpected EOF" { + return nil } - - return nil + logger.Trace().Err(err).Msgf("failed to read first message from client") + return err } + logger.Debug(). + Int("len", tlsMsg.Len()). + Msgf("client hello received <- %s", lConn.RemoteAddr()) + if !tlsMsg.IsClientHello() { logger.Trace().Int("len", tlsMsg.Len()).Msg("not a client hello. aborting") return nil } + // Send ClientHello to the remote server (with desync if configured) n, err := h.sendClientHello(ctx, rConn, tlsMsg, httpsOpts) if err != nil { return fmt.Errorf("failed to send client hello: %w", err) @@ -90,67 +123,43 @@ func (h *HTTPSHandler) HandleRequest( Int("len", n). Msgf("sent client hello -> %s", rConn.RemoteAddr()) - errCh := make(chan error, 2) + // Start bi-directional tunneling + resCh := make(chan netutil.TransferResult, 2) ctx, cancel := context.WithCancel(ctx) defer cancel() - go netutil.TunnelConns(ctx, logger, errCh, rConn, lConn) - go netutil.TunnelConns(ctx, logger, errCh, lConn, rConn) + startedAt := time.Now() + go netutil.TunnelConns(ctx, resCh, lConn, rConn, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, rConn, lConn, netutil.TunnelDirIn) - for range 2 { - e := <-errCh - if e == nil { - continue + handleErrs := func(errs []error) error { + if len(errs) == 0 { + return nil } - if netutil.IsConnectionResetByPeer(e) { + if slices.ContainsFunc(errs, netutil.IsConnectionResetByPeer) { return netutil.ErrBlocked } - return fmt.Errorf( - "unsuccessful tunnel %s -> %s: %w", - lConn.RemoteAddr(), - rConn.RemoteAddr(), - e, - ) - } - - return nil -} - -// handleProxyHandshake sends "200 Connection Established" -// and reads the subsequent Client Hello. -func (h *HTTPSHandler) handleProxyHandshake( - ctx context.Context, - lConn net.Conn, -) (*proto.TLSMessage, error) { - logger := logging.WithLocalScope(ctx, h.logger, "handshake") - - if err := proto.HTTPConnectionEstablishedResponse().Write(lConn); err != nil { - return nil, err - } - logger.Trace().Msgf("sent 200 connection established -> %s", lConn.RemoteAddr()) - - tlsMsg, err := proto.ReadTLSMessage(lConn) - if err != nil { - return nil, err + return errs[0] } - logger.Debug(). - Int("len", tlsMsg.Len()). - Msgf("client hello received <- %s", lConn.RemoteAddr()) - - return tlsMsg, nil + return netutil.WaitAndLogTunnel( + ctx, + logger, + resCh, + startedAt, + netutil.DescribeRoute(lConn, rConn), + handleErrs, + ) } -// sendClientHello decides whether to spoof and sends the Client Hello accordingly. func (h *HTTPSHandler) sendClientHello( ctx context.Context, - conn net.Conn, + rConn net.Conn, msg *proto.TLSMessage, - httpsOpts *config.HTTPSOptions, + opts *config.HTTPSOptions, ) (int, error) { - logger := logging.WithLocalScope(ctx, h.logger, "client_hello") - return h.desyncer.Send(ctx, logger, conn, msg, httpsOpts) + return h.desyncer.Desync(ctx, h.logger, rConn, msg, opts) } diff --git a/internal/server/http/network.go b/internal/server/http/network.go new file mode 100644 index 00000000..bf8f070b --- /dev/null +++ b/internal/server/http/network.go @@ -0,0 +1,11 @@ +//go:build !darwin + +package http + +import ( + "github.com/rs/zerolog" +) + +func setSystemProxy(logger zerolog.Logger, port uint16) (func() error, error) { + return func() error { return nil }, nil +} diff --git a/internal/server/http/network_darwin.go b/internal/server/http/network_darwin.go new file mode 100644 index 00000000..b2a3d6a8 --- /dev/null +++ b/internal/server/http/network_darwin.go @@ -0,0 +1,106 @@ +//go:build darwin + +package http + +import ( + "errors" + "fmt" + "os/exec" + "strconv" + "strings" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/netutil" +) + +const ( + permissionErrorHelpText = "By default SpoofDPI tries to set itself up as a system-wide proxy server.\n" + + "Doing so may require root access on machines with\n" + + "'Settings > Privacy & Security > Advanced > Require" + + " an administrator password to access system-wide settings' enabled.\n" + + "If you do not want SpoofDPI to act as a system-wide proxy, provide" + + " -system-proxy=false." +) + +func setSystemProxy(logger zerolog.Logger, port uint16) (func() error, error) { + network, err := getDefaultNetwork() + if err != nil { + return nil, err + } + + portStr := strconv.Itoa(int(port)) + pacContent := fmt.Sprintf(`function FindProxyForURL(url, host) { + return "PROXY 127.0.0.1:%s; DIRECT"; +}`, portStr) + + pacURL, pacServer, err := netutil.RunPACServer(pacContent) + if err != nil { + return nil, fmt.Errorf("error creating pac server: %w", err) + } + + // Enable Auto Proxy Configuration + // networksetup -setautoproxyurl + if err := networkSetup("-setautoproxyurl", network, pacURL); err != nil { + _ = pacServer.Close() + return nil, fmt.Errorf("setting autoproxyurl: %w", err) + } + + // networksetup -setproxyautodiscovery + if err := networkSetup("-setproxyautodiscovery", network, "on"); err != nil { + _ = pacServer.Close() + return nil, fmt.Errorf("setting proxyautodiscovery: %w", err) + } + + unset := func() error { + _ = pacServer.Close() + + // Disable Auto Proxy Configuration + if err := networkSetup("-setautoproxystate", network, "off"); err != nil { + return fmt.Errorf("unsetting autoproxystate: %w", err) + } + + if err := networkSetup("-setproxyautodiscovery", network, "off"); err != nil { + return fmt.Errorf("unsetting proxyautodiscovery: %w", err) + } + + return nil + } + + return unset, nil +} + +func getDefaultNetwork() (string, error) { + const cmd = "networksetup -listnetworkserviceorder | grep" + + " `(route -n get default | grep 'interface' || route -n get -inet6 default | grep 'interface') | cut -d ':' -f2`" + + " -B 1 | head -n 1 | cut -d ' ' -f 2-" + + out, err := exec.Command("sh", "-c", cmd).Output() + if err != nil { + return "", err + } + + network := strings.TrimSpace(string(out)) + if network == "" { + return "", errors.New("no available networks") + } + return network, nil +} + +func networkSetup(args ...string) error { + cmd := exec.Command("networksetup", args...) + out, err := cmd.CombinedOutput() + if err != nil { + msg := string(out) + if isPermissionError(err) { + msg += permissionErrorHelpText + } + return fmt.Errorf("%s", msg) + } + return nil +} + +func isPermissionError(err error) bool { + var exitErr *exec.ExitError + ok := errors.As(err, &exitErr) + return ok && exitErr.ExitCode() == 14 +} diff --git a/internal/proxy/http/http_proxy.go b/internal/server/http/server.go similarity index 51% rename from internal/proxy/http/http_proxy.go rename to internal/server/http/server.go index 16c8a7ff..4c7d4cb2 100644 --- a/internal/proxy/http/http_proxy.go +++ b/internal/server/http/server.go @@ -2,32 +2,23 @@ package http import ( "context" - "encoding/json" "errors" "fmt" "io" "net" - "time" "github.com/rs/zerolog" + "github.com/samber/lo" "github.com/xvzc/SpoofDPI/internal/config" "github.com/xvzc/SpoofDPI/internal/dns" "github.com/xvzc/SpoofDPI/internal/logging" "github.com/xvzc/SpoofDPI/internal/matcher" "github.com/xvzc/SpoofDPI/internal/netutil" "github.com/xvzc/SpoofDPI/internal/proto" - "github.com/xvzc/SpoofDPI/internal/proxy" - "github.com/xvzc/SpoofDPI/internal/ptr" + "github.com/xvzc/SpoofDPI/internal/server" "github.com/xvzc/SpoofDPI/internal/session" ) -type Destination struct { - Domain string - Addrs []net.IPAddr - Port int - Timeout time.Duration -} - type HTTPProxy struct { logger zerolog.Logger @@ -35,8 +26,11 @@ type HTTPProxy struct { httpHandler *HTTPHandler httpsHandler *HTTPSHandler ruleMatcher matcher.RuleMatcher - serverOpts *config.ServerOptions + appOpts *config.AppOptions + connOpts *config.ConnOptions policyOpts *config.PolicyOptions + + listener net.Listener } func NewHTTPProxy( @@ -45,38 +39,51 @@ func NewHTTPProxy( httpHandler *HTTPHandler, httpsHandler *HTTPSHandler, ruleMatcher matcher.RuleMatcher, - serverOpts *config.ServerOptions, + appOpts *config.AppOptions, + connOpts *config.ConnOptions, policyOpts *config.PolicyOptions, -) proxy.ProxyServer { +) server.Server { return &HTTPProxy{ logger: logger, resolver: resolver, httpHandler: httpHandler, httpsHandler: httpsHandler, ruleMatcher: ruleMatcher, - serverOpts: serverOpts, + appOpts: appOpts, + connOpts: connOpts, policyOpts: policyOpts, } } -func (p *HTTPProxy) ListenAndServe(ctx context.Context, wait chan struct{}) { - <-wait - - logger := p.logger.With().Ctx(ctx).Logger() - - listener, err := net.ListenTCP("tcp", p.serverOpts.ListenAddr) +func (p *HTTPProxy) ListenAndServe( + appctx context.Context, + ready chan<- struct{}, +) error { + listener, err := net.ListenTCP("tcp", p.appOpts.ListenAddr) if err != nil { - p.logger.Fatal(). - Err(err). - Msgf("error creating listener on %s", p.serverOpts.ListenAddr.String()) + return fmt.Errorf( + "error creating listener on %s: %w", + p.appOpts.ListenAddr.String(), + err, + ) } + p.listener = listener - logger.Info(). - Msgf("created a listener on %s", p.serverOpts.ListenAddr) + go func() { + <-appctx.Done() + _ = listener.Close() + }() + + if ready != nil { + close(ready) + } for { conn, err := listener.Accept() if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil // Normal shutdown + } p.logger.Error(). Err(err). Msgf("failed to accept new connection") @@ -88,8 +95,16 @@ func (p *HTTPProxy) ListenAndServe(ctx context.Context, wait chan struct{}) { } } +func (p *HTTPProxy) SetNetworkConfig() (func() error, error) { + return setSystemProxy(p.logger, uint16(p.appOpts.ListenAddr.Port)) +} + +func (p *HTTPProxy) Addr() string { + return p.appOpts.ListenAddr.String() +} + func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { - logger := logging.WithLocalScope(ctx, p.logger, "conn") + logger := logging.WithLocalScope(ctx, p.logger, "conn-init") ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -104,6 +119,9 @@ func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { return } + logger.Debug().Str("from", conn.RemoteAddr().String()).Str("host", req.Host). + Msg("new request") + if !req.IsValidMethod() { logger.Warn().Str("method", req.Method).Msg("unsupported method. abort") _ = proto.HTTPNotImplementedResponse().Write(conn) @@ -111,7 +129,7 @@ func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { return } - domain := req.ExtractDomain() + host := req.ExtractHost() dstPort, err := req.ExtractPort() if err != nil { logger.Warn().Str("host", req.Host).Msg("failed to extract port") @@ -120,39 +138,35 @@ func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { return } - ctx = session.WithHostInfo(ctx, domain) - logger = logger.With().Ctx(ctx).Logger() - logger.Debug(). Str("method", req.Method). Str("from", conn.RemoteAddr().String()). Msg("new request") - nameMatch := p.ruleMatcher.Search( - &matcher.Selector{Kind: matcher.MatchKindDomain, Domain: ptr.FromValue(domain)}, - ) - if nameMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - jsonAttrs, _ := json.Marshal(nameMatch) - logger.Trace().RawJSON("values", jsonAttrs).Msg("name match") - } + var addrs []net.IP + var nameMatch *config.Rule + if net.ParseIP(host) != nil { + addrs = []net.IP{net.ParseIP(host)} + logger.Trace().Msgf("skipping dns lookup for non-domain host %q", host) + } else { + nameMatch = p.ruleMatcher.Search( + &matcher.Selector{Kind: matcher.MatchKindDomain, Domain: lo.ToPtr(host)}, + ) - t1 := time.Now() - rSet, err := p.resolver.Resolve(ctx, domain, nil, nameMatch) - dt := time.Since(t1).Milliseconds() - if err != nil { - _ = proto.HTTPBadGatewayResponse().Write(conn) - logging.ErrorUnwrapped(&logger, "dns lookup failed", err) + rSet, err := p.resolver.Resolve(ctx, host, nil, nameMatch) + if err != nil { + _ = proto.HTTPBadGatewayResponse().Write(conn) + // logging.ErrorUnwrapped is not available, using standard error logging + logger.Error().Err(err).Msgf("dns lookup failed for %s", host) - return - } + return + } - logger.Debug(). - Int("cnt", len(rSet.Addrs)). - Str("took", fmt.Sprintf("%dms", dt)). - Msgf("dns lookup ok") + addrs = rSet.Addrs + } // Avoid recursively querying self. - ok, err := netutil.ValidateDestination(rSet.Addrs, dstPort, p.serverOpts.ListenAddr) + ok, err := netutil.ValidateDestination(addrs, dstPort, p.appOpts.ListenAddr) if err != nil { logger.Debug().Err(err).Msg("error validating dst addrs") if !ok { @@ -161,24 +175,19 @@ func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { } var selectors []*matcher.Selector - for _, v := range rSet.Addrs { + for _, v := range addrs { selectors = append(selectors, &matcher.Selector{ Kind: matcher.MatchKindAddr, - IP: ptr.FromValue(v.IP), - Port: ptr.FromValue(uint16(dstPort)), + IP: lo.ToPtr(v), + Port: lo.ToPtr(uint16(dstPort)), }) } addrMatch := p.ruleMatcher.SearchAll(selectors) - if addrMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - jsonAttrs, _ := json.Marshal(addrMatch) - logger.Trace().RawJSON("values", jsonAttrs).Msg("addr match") - } bestMatch := matcher.GetHigherPriorityRule(addrMatch, nameMatch) if bestMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - jsonAttrs, _ := json.Marshal(bestMatch) - logger.Trace().RawJSON("values", jsonAttrs).Msg("best match") + logger.Trace().RawJSON("summary", bestMatch.JSON()).Msg("match") } if bestMatch != nil && *bestMatch.Block { @@ -186,11 +195,11 @@ func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { return } - dst := &Destination{ - Domain: domain, - Addrs: rSet.Addrs, + dst := &netutil.Destination{ + Domain: host, // Updated from Domain to Host + Addrs: addrs, Port: dstPort, - Timeout: *p.serverOpts.Timeout, + Timeout: *p.connOpts.TCPTimeout, } var handleErr error @@ -205,27 +214,4 @@ func (p *HTTPProxy) handleNewConnection(ctx context.Context, conn net.Conn) { } logger.Warn().Err(handleErr).Msg("error handling request") - if !errors.Is(handleErr, netutil.ErrBlocked) { // Early exit if not blocked - return - } - - // ┌─────────────┐ - // │ AUTO config │ - // └─────────────┘ - if bestMatch != nil && logger.GetLevel() == zerolog.TraceLevel { - logger.Info().Msg("skipping auto-config (duplicate policy)") - return - } - - // Perform auto config if enabled and RuleTemplate is not nil - if *p.policyOpts.Auto && p.policyOpts.Template != nil { - newRule := p.policyOpts.Template.Clone() - newRule.Match = &config.MatchAttrs{Domains: []string{domain}} - - if err := p.ruleMatcher.Add(newRule); err != nil { - logger.Info().Err(err).Msg("failed to add config automatically") - } else { - logger.Info().Msg("automatically added to config") - } - } } diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 00000000..46d2225a --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,13 @@ +package server + +import "context" + +// Server represents a core component that processes network traffic. +// ListenAndServe blocks until ctx is cancelled, then releases all resources. +type Server interface { + ListenAndServe(ctx context.Context, ready chan<- struct{}) error + SetNetworkConfig() (func() error, error) + + // Addr returns the network address or interface name the server is bound to + Addr() string +} diff --git a/internal/server/socks5/bind.go b/internal/server/socks5/bind.go new file mode 100644 index 00000000..b1dce109 --- /dev/null +++ b/internal/server/socks5/bind.go @@ -0,0 +1,98 @@ +package socks5 + +import ( + "context" + "net" + "time" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/proto" +) + +type BindHandler struct { + logger zerolog.Logger +} + +func NewBindHandler(logger zerolog.Logger) *BindHandler { + return &BindHandler{ + logger: logger, + } +} + +func (h *BindHandler) Handle( + ctx context.Context, + conn net.Conn, + req *proto.SOCKS5Request, +) error { + logger := logging.WithLocalScope(ctx, h.logger, "bind") + + // 1. Listen on a random TCP port + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + logger.Error().Err(err).Msg("failed to create bind listener") + _ = proto.SOCKS5FailureResponse().Write(conn) + return err + } + defer func() { _ = listener.Close() }() + + logger.Debug(). + Str("addr", listener.Addr().String()). + Str("network", listener.Addr().Network()). + Msg("new listener") + + lAddr := listener.Addr().(*net.TCPAddr) + + // 2. First Reply: Send the address/port we are listening on + err = proto.SOCKS5SuccessResponse().Bind(lAddr.IP).Port(lAddr.Port).Write(conn) + if err != nil { + logger.Error().Err(err).Msg("failed to write first bind reply") + return err + } + + logger.Debug(). + Str("bind_addr", lAddr.String()). + Msg("waiting for incoming connection") + + // 3. Accept Incoming Connection + // The client should now tell the application server to connect to lAddr. + remoteConn, err := listener.Accept() + if err != nil { + logger.Error().Err(err).Msg("failed to accept incoming connection") + _ = proto.SOCKS5FailureResponse().Write(conn) + return err + } + defer netutil.CloseConns(remoteConn) + + rAddr := remoteConn.RemoteAddr().(*net.TCPAddr) + + logger.Debug(). + Str("remote_addr", rAddr.String()). + Msg("accepted incoming connection") + + // 4. Second Reply: Send the address/port of the connecting host + err = proto.SOCKS5SuccessResponse().Bind(rAddr.IP).Port(rAddr.Port).Write(conn) + if err != nil { + logger.Error().Err(err).Msg("failed to write second bind reply") + return err + } + + // 5. Start bi-directional tunneling + resCh := make(chan netutil.TransferResult, 2) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + startedAt := time.Now() + go netutil.TunnelConns(ctx, resCh, remoteConn, conn, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, conn, remoteConn, netutil.TunnelDirIn) + + return netutil.WaitAndLogTunnel( + ctx, + logger, + resCh, + startedAt, + netutil.DescribeRoute(conn, remoteConn), + nil, + ) +} diff --git a/internal/server/socks5/connect.go b/internal/server/socks5/connect.go new file mode 100644 index 00000000..b83117ba --- /dev/null +++ b/internal/server/socks5/connect.go @@ -0,0 +1,207 @@ +package socks5 + +import ( + "context" + "fmt" + "io" + "net" + "time" + + "github.com/rs/zerolog" + "github.com/samber/lo" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/desync" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/packet" + "github.com/xvzc/SpoofDPI/internal/proto" +) + +type ConnectHandler struct { + logger zerolog.Logger + desyncer *desync.TLSDesyncer + sniffer packet.Sniffer + appOpts *config.AppOptions + defaultConnOpts *config.ConnOptions + defaultHTTPSOpts *config.HTTPSOptions +} + +func NewConnectHandler( + logger zerolog.Logger, + desyncer *desync.TLSDesyncer, + sniffer packet.Sniffer, + appOpts *config.AppOptions, + defaultConnOpts *config.ConnOptions, + defaultHTTPSOpts *config.HTTPSOptions, +) *ConnectHandler { + return &ConnectHandler{ + logger: logger, + desyncer: desyncer, + sniffer: sniffer, + appOpts: appOpts, + defaultConnOpts: defaultConnOpts, + defaultHTTPSOpts: defaultHTTPSOpts, + } +} + +func (h *ConnectHandler) Handle( + ctx context.Context, + lConn net.Conn, + req *proto.SOCKS5Request, + dst *netutil.Destination, + rule *config.Rule, +) error { + httpsOpts := h.defaultHTTPSOpts.Clone() + connOpts := h.defaultConnOpts.Clone() + if rule != nil { + httpsOpts = httpsOpts.Merge(rule.HTTPS) + connOpts = connOpts.Merge(rule.Conn) + } + + logger := logging.WithLocalScope(ctx, h.logger, "connect") + + // 1. Validate Destination + ok, err := netutil.ValidateDestination(dst.Addrs, dst.Port, h.appOpts.ListenAddr) + if err != nil { + logger.Debug().Err(err).Msg("error determining if valid destination") + if !ok { + _ = proto.SOCKS5FailureResponse().Write(lConn) + return err + } + } + + // 2. Check if blocked + if rule != nil && *rule.Block { + logger.Debug().Msg("request is blocked by policy") + _ = proto.SOCKS5FailureResponse().Write(lConn) + return netutil.ErrBlocked + } + + dst.Timeout = *connOpts.TCPTimeout + + rConn, err := netutil.DialFastest(ctx, "tcp", dst) + if err != nil { + _ = proto.SOCKS5FailureResponse().Write(lConn) + return err + } + defer netutil.CloseConns(rConn) + + // 3. Send Success Response + err = proto.SOCKS5SuccessResponse().Bind(net.IPv4zero).Port(0).Write(lConn) + if err != nil { + logger.Error().Err(err).Msg("failed to write socks5 success reply") + return err + } + + logger.Debug().Msgf("new remote conn -> %s", rConn.RemoteAddr()) + + // Wrap lConn with a buffered reader to peek for TLS + bufConn := netutil.NewBufferedConn(lConn) + + // Peek first byte to check for TLS Handshake (0x16) + // We try to peek 1 byte. + b, err := bufConn.Peek(1) + if err == nil && b[0] == byte(proto.TLSHandshake) { // 0x16 + + if h.sniffer != nil && lo.FromPtr(httpsOpts.FakeCount) > 0 { + h.sniffer.RegisterUntracked(dst.Addrs) + } + + return h.handleHTTPS(ctx, bufConn, rConn, httpsOpts) + } + + // If not TLS, fall back to pure TCP tunnel + logger.Debug().Msg("not a tls handshake. fallback to pure tcp") + + resCh := make(chan netutil.TransferResult, 2) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + startedAt := time.Now() + go netutil.TunnelConns(ctx, resCh, rConn, bufConn, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, bufConn, rConn, netutil.TunnelDirIn) + + return netutil.WaitAndLogTunnel( + ctx, + logger, + resCh, + startedAt, + netutil.DescribeRoute(bufConn, rConn), + nil, + ) +} + +func (h *ConnectHandler) handleHTTPS( + ctx context.Context, + lConn net.Conn, // This is expected to be the BufferedConn + rConn net.Conn, + opts *config.HTTPSOptions, +) error { + logger := logging.WithLocalScope(ctx, h.logger, "connect(tls)") + + // Read the first message from the client (expected to be ClientHello) + tlsMsg, err := proto.ReadTLSMessage(lConn) + if err != nil { + if err == io.EOF || err.Error() == "unexpected EOF" { + return nil + } + logger.Trace().Err(err).Msgf("failed to read first message from client") + return err + } + + // It starts with 0x16, but is it a ClientHello? + if !tlsMsg.IsClientHello() { + logger.Debug(). + Int("len", tlsMsg.Len()). + Msg("not a client hello. fallback to pure tcp") + + // Forward the initial bytes we read + if _, err := rConn.Write(tlsMsg.Raw()); err != nil { + return fmt.Errorf("failed to write initial bytes to remote: %w", err) + } + } else { + logger.Debug(). + Int("len", tlsMsg.Len()). + Msgf("client hello received <- %s", lConn.RemoteAddr()) + + // Send ClientHello to the remote server (with desync if configured) + n, err := h.sendClientHello(ctx, rConn, tlsMsg, opts) + if err != nil { + return fmt.Errorf("failed to send client hello: %w", err) + } + + logger.Debug(). + Int("len", n). + Msgf("sent client hello -> %s", rConn.RemoteAddr()) + } + + resCh := make(chan netutil.TransferResult, 2) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + startedAt := time.Now() + go netutil.TunnelConns(ctx, resCh, rConn, lConn, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, lConn, rConn, netutil.TunnelDirIn) + + return netutil.WaitAndLogTunnel( + ctx, + logger, + resCh, + startedAt, + netutil.DescribeRoute(lConn, rConn), + nil, + ) +} + +func (h *ConnectHandler) sendClientHello( + ctx context.Context, + rConn net.Conn, + msg *proto.TLSMessage, + opts *config.HTTPSOptions, +) (int, error) { + if lo.FromPtr(opts.Skip) { + return rConn.Write(msg.Raw()) + } + + return h.desyncer.Desync(ctx, h.logger, rConn, msg, opts) +} diff --git a/internal/server/socks5/network.go b/internal/server/socks5/network.go new file mode 100644 index 00000000..d53f5133 --- /dev/null +++ b/internal/server/socks5/network.go @@ -0,0 +1,11 @@ +//go:build !darwin + +package socks5 + +import ( + "github.com/rs/zerolog" +) + +func setSystemProxy(logger zerolog.Logger, port uint16) (func() error, error) { + return func() error { return nil }, nil +} diff --git a/internal/server/socks5/network_darwin.go b/internal/server/socks5/network_darwin.go new file mode 100644 index 00000000..bed35848 --- /dev/null +++ b/internal/server/socks5/network_darwin.go @@ -0,0 +1,105 @@ +//go:build darwin + +package socks5 + +import ( + "errors" + "fmt" + "os/exec" + "strconv" + "strings" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/netutil" +) + +const ( + permissionErrorHelpText = "By default SpoofDPI tries to set itself up as a system-wide proxy server.\n" + + "Doing so may require root access on machines with\n" + + "'Settings > Privacy & Security > Advanced > Require" + + " an administrator password to access system-wide settings' enabled.\n" + + "If you do not want SpoofDPI to act as a system-wide proxy, provide" + + " -system-proxy=false." +) + +func setSystemProxy(logger zerolog.Logger, port uint16) (func() error, error) { + network, err := getDefaultNetwork() + if err != nil { + return nil, err + } + + portStr := strconv.Itoa(int(port)) + pacContent := fmt.Sprintf(`function FindProxyForURL(url, host) { + return "SOCKS5 127.0.0.1:%s; DIRECT"; +}`, portStr) + + pacURL, pacServer, err := netutil.RunPACServer(pacContent) + if err != nil { + return nil, fmt.Errorf("error creating pac server: %w", err) + } + + // Enable Auto Proxy Configuration + // networksetup -setautoproxyurl + if err := networkSetup("-setautoproxyurl", network, pacURL); err != nil { + _ = pacServer.Close() + return nil, fmt.Errorf("setting autoproxyurl: %w", err) + } + + // networksetup -setproxyautodiscovery + if err := networkSetup("-setproxyautodiscovery", network, "on"); err != nil { + _ = pacServer.Close() + return nil, fmt.Errorf("setting proxyautodiscovery: %w", err) + } + + unset := func() error { + _ = pacServer.Close() + + if err := networkSetup("-setautoproxystate", network, "off"); err != nil { + return fmt.Errorf("unsetting autoproxystate: %w", err) + } + + if err := networkSetup("-setproxyautodiscovery", network, "off"); err != nil { + return fmt.Errorf("unsetting proxyautodiscovery: %w", err) + } + + return nil + } + + return unset, nil +} + +func getDefaultNetwork() (string, error) { + const cmd = "networksetup -listnetworkserviceorder | grep" + + " `(route -n get default | grep 'interface' || route -n get -inet6 default | grep 'interface') | cut -d ':' -f2`" + + " -B 1 | head -n 1 | cut -d ' ' -f 2-" + + out, err := exec.Command("sh", "-c", cmd).Output() + if err != nil { + return "", err + } + + network := strings.TrimSpace(string(out)) + if network == "" { + return "", errors.New("no available networks") + } + return network, nil +} + +func networkSetup(args ...string) error { + cmd := exec.Command("networksetup", args...) + out, err := cmd.CombinedOutput() + if err != nil { + msg := string(out) + if isPermissionError(err) { + msg += permissionErrorHelpText + } + return fmt.Errorf("%s", msg) + } + return nil +} + +func isPermissionError(err error) bool { + var exitErr *exec.ExitError + ok := errors.As(err, &exitErr) + return ok && exitErr.ExitCode() == 14 +} diff --git a/internal/server/socks5/server.go b/internal/server/socks5/server.go new file mode 100644 index 00000000..bded9e9b --- /dev/null +++ b/internal/server/socks5/server.go @@ -0,0 +1,260 @@ +package socks5 + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + + "github.com/rs/zerolog" + "github.com/samber/lo" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/dns" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/matcher" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/proto" + "github.com/xvzc/SpoofDPI/internal/server" + "github.com/xvzc/SpoofDPI/internal/session" +) + +type SOCKS5Proxy struct { + logger zerolog.Logger + + resolver dns.Resolver + ruleMatcher matcher.RuleMatcher + connectHandler *ConnectHandler + bindHandler *BindHandler + udpAssociateHandler *UdpAssociateHandler + + appOpts *config.AppOptions + connOpts *config.ConnOptions + policyOpts *config.PolicyOptions +} + +func NewSOCKS5Proxy( + logger zerolog.Logger, + resolver dns.Resolver, + ruleMatcher matcher.RuleMatcher, + connectHandler *ConnectHandler, + bindHandler *BindHandler, + udpAssociateHandler *UdpAssociateHandler, + appOpts *config.AppOptions, + connOpts *config.ConnOptions, + policyOpts *config.PolicyOptions, +) server.Server { + return &SOCKS5Proxy{ + logger: logger, + resolver: resolver, + ruleMatcher: ruleMatcher, + connectHandler: connectHandler, + bindHandler: bindHandler, + udpAssociateHandler: udpAssociateHandler, + appOpts: appOpts, + connOpts: connOpts, + policyOpts: policyOpts, + } +} + +func (p *SOCKS5Proxy) ListenAndServe( + appctx context.Context, + ready chan<- struct{}, +) error { + listener, err := net.ListenTCP("tcp", p.appOpts.ListenAddr) + if err != nil { + return fmt.Errorf( + "error creating listener on %s: %w", + p.appOpts.ListenAddr.String(), + err, + ) + } + + go func() { + <-appctx.Done() + _ = listener.Close() + }() + + if ready != nil { + close(ready) + } + + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + p.logger.Error(). + Err(err). + Msg("failed to accept new connection") + continue + } + + go p.handleConnection(session.WithNewTraceID(appctx), conn) + } +} + +func (p *SOCKS5Proxy) SetNetworkConfig() (func() error, error) { + return setSystemProxy(p.logger, uint16(p.appOpts.ListenAddr.Port)) +} + +func (p *SOCKS5Proxy) Addr() string { + return p.appOpts.ListenAddr.String() +} + +func (p *SOCKS5Proxy) handleConnection(ctx context.Context, conn net.Conn) { + logger := logging.WithLocalScope(ctx, p.logger, "socks5") + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + defer netutil.CloseConns(conn) + + // 1. Negotiation Phase + if err := p.negotiate(logger, conn); err != nil { + logger.Debug().Err(err).Msg("negotiation failed") + return + } + + // 2. Request Phase + req, err := proto.ReadSocks5Request(conn) + if err != nil { + if err != io.EOF { + logger.Warn().Err(err).Msg("failed to read request") + } + return + } + + // ctx = session.WithHostInfo(ctx, req.Host()) + // logger = logger.With().Ctx(ctx).Logger() + + logger.Trace(). + Uint8("cmd", req.Cmd). + Int("port", req.Port). + Str("fqdn", req.FQDN). + Str("ip", req.IP.String()). + Msg("new request") + + var addrs []net.IP + var nameMatch *config.Rule + + if req.IP != nil { + addrs = []net.IP{req.IP} + } else if req.ATYP == proto.SOCKS5AddrTypeFQDN && len(req.FQDN) > 1 { + nameMatch = p.ruleMatcher.Search( + &matcher.Selector{ + Kind: matcher.MatchKindDomain, + Domain: lo.ToPtr(req.FQDN), // req.Domain -> req.FQDN + }, + ) + + // Resolve Domain + rSet, err := p.resolver.Resolve(ctx, req.FQDN, nil, nameMatch) + if err != nil { + logger.Error().Str("domain", req.FQDN).Err(err).Msgf("dns lookup failed") + return + } + + addrs = rSet.Addrs + } else { + logger.Trace().Msg("no addrs specified for this request. skipping") + } + + var selectors []*matcher.Selector + for _, v := range addrs { + selectors = append(selectors, &matcher.Selector{ + Kind: matcher.MatchKindAddr, + IP: lo.ToPtr(v), + Port: lo.ToPtr(uint16(req.Port)), + }) + } + + addrMatch := p.ruleMatcher.SearchAll(selectors) + + bestMatch := matcher.GetHigherPriorityRule(addrMatch, nameMatch) + if bestMatch != nil && logger.GetLevel() == zerolog.TraceLevel { + logger.Trace().RawJSON("summary", bestMatch.JSON()).Msg("match") + } + + switch req.Cmd { + case proto.SOCKS5CmdConnect: + dst := &netutil.Destination{ + Domain: req.FQDN, + Addrs: addrs, + Port: req.Port, + Timeout: *p.connOpts.TCPTimeout, + } + if err = p.connectHandler.Handle(ctx, conn, req, dst, bestMatch); err != nil { + return // Handler logs error + } + + case proto.SOCKS5CmdBind: + // Bind command usually implies user wants the server to listen. + // Destination address in request is usually zero or the IP of the client, + // but SOCKS5 spec says "DST.ADDR and DST.PORT fields of the BIND request contains + // the address and port of the party the client expects to connect to the application server." + // For our basic BindHandler, we might not strictly validate this yet. + if err = p.bindHandler.Handle(ctx, conn, req); err != nil { + return + } + + case proto.SOCKS5CmdUDPAssociate: + // UDP Associate usually doesn't have destination info in the request + if err = p.udpAssociateHandler.Handle(ctx, conn, req, nil, nil); err != nil { + logger.Error().Err(err).Msg("failed to handle udp_associate") + return + } + default: + err = proto.SOCKS5CommandNotSupportedResponse().Write(conn) + logger.Warn().Uint8("cmd", req.Cmd).Msg("unsupported command") + } + + if err == nil { + return + } + + logger.Error().Err(err).Msg("failed to handle") +} + +func (p *SOCKS5Proxy) negotiate(logger zerolog.Logger, conn net.Conn) error { + header := make([]byte, 2) + if _, err := io.ReadFull(conn, header); err != nil { + return err + } + + if header[0] != proto.SOCKSVersion { + // Check if the first byte is 'C'(67), and the second byte is 'O'(79) + // indicating a potential HTTP CONNECT method + if len(header) > 1 && header[0] == 67 && header[1] == 79 { + // Reconstruct the stream using the already read header and the remaining connection + // This allows http.ReadRequest to parse the full request line including the method + mr := io.MultiReader(bytes.NewReader(header), conn) + bufReader := bufio.NewReader(mr) + + // Parse the HTTP request headers without waiting for EOF + // ReadRequest reads only the header section and stops + req, err := http.ReadRequest(bufReader) + if err != nil { + return fmt.Errorf("invalid request(unknown): %w", err) + } + + // req.Host contains the target domain (e.g., "google.com:443") + return fmt.Errorf("invalid request: http connect to %s", req.Host) + } + + return fmt.Errorf("invalid version: %d", header[0]) + } + + nMethods := int(header[1]) + methods := make([]byte, nMethods) + if _, err := io.ReadFull(conn, methods); err != nil { + return err + } + + // Respond: Version 5, Method NoAuth(0) + _, err := conn.Write([]byte{proto.SOCKSVersion, proto.SOCKS5AuthNone}) + return err +} diff --git a/internal/server/socks5/udp_associate.go b/internal/server/socks5/udp_associate.go new file mode 100644 index 00000000..2061c875 --- /dev/null +++ b/internal/server/socks5/udp_associate.go @@ -0,0 +1,295 @@ +package socks5 + +import ( + "context" + "encoding/binary" + "fmt" + "io" + "net" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/desync" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/proto" +) + +type UdpAssociateHandler struct { + logger zerolog.Logger + pool *netutil.ConnRegistry[netutil.NATKey] + desyncer *desync.UDPDesyncer + defaultUDPOpts *config.UDPOptions +} + +func NewUdpAssociateHandler( + logger zerolog.Logger, + pool *netutil.ConnRegistry[netutil.NATKey], + desyncer *desync.UDPDesyncer, + defaultUDPOpts *config.UDPOptions, +) *UdpAssociateHandler { + return &UdpAssociateHandler{ + logger: logger, + pool: pool, + desyncer: desyncer, + defaultUDPOpts: defaultUDPOpts, + } +} + +func (h *UdpAssociateHandler) Handle( + ctx context.Context, + lConn net.Conn, + req *proto.SOCKS5Request, + dst *netutil.Destination, + rule *config.Rule, +) error { + logger := logging.WithLocalScope(ctx, h.logger, "udp_associate") + + // 1. Listen on a random UDP port + lTCPAddr := lConn.LocalAddr().(*net.TCPAddr) // SOCKS5 listens on TCP + lUDPConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: lTCPAddr.IP, Port: 0}) + if err != nil { + logger.Error().Err(err).Msg("failed to create udp listener") + _ = proto.SOCKS5FailureResponse().Write(lConn) + return err + } + defer netutil.CloseConns(lUDPConn) + + logger.Debug(). + Str("addr", lUDPConn.LocalAddr().String()). + Str("network", lUDPConn.LocalAddr().Network()). + Msg("new conn") + + lUDPAddr := lUDPConn.LocalAddr().(*net.UDPAddr) + + logger.Debug(). + Str("bind_addr", lUDPAddr.String()). + Msg("socks5 udp associate established") + + // 2. Reply with the bound address + err = proto.SOCKS5SuccessResponse().Bind(lUDPAddr.IP).Port(lUDPAddr.Port).Write(lConn) + if err != nil { + logger.Error().Err(err).Msg("failed to write socks5 success reply") + return err + } + + // 3. Keep TCP Alive & Relay + // According to [RFC1928](https://datatracker.ietf.org/doc/html/rfc1928#section-6), + // > A UDP association terminates when the TCP connection that the UDP + // > ASSOCIATE request arrived on terminates. + // Therefore, we need to monitor TCP for closure. + done := make(chan struct{}) + go func() { + _, _ = io.Copy(io.Discard, lConn) // Block until TCP closes + close(done) // Close the channel to signal UDP handler to exit + _ = lUDPConn.Close() // Force ReadFromUDP to unblock and avoid goroutine leak + }() + + buf := make([]byte, 65535) + rTCPAddr := lConn.RemoteAddr().(*net.TCPAddr).IP + + for { + // Wait for data + n, srcAddr, err := lUDPConn.ReadFromUDP(buf) + if err != nil { + // Normal closure check + select { + case <-done: + return nil + default: + if err != io.EOF { + logger.Debug().Err(err).Msg("error reading from udp") + } + return err + } + } + + // Security: Only accept UDP packets from the same IP that established the TCP connection + if !srcAddr.IP.Equal(rTCPAddr) { + logger.Debug(). + Str("expected", rTCPAddr.String()). + Str("actual", srcAddr.IP.String()). + Msg("dropped udp packet from unexpected ip") + continue + } + + // Outbound: Client -> Proxy -> Target + dstAddrStr, payload, err := parseUDPHeader(buf[:n]) + if err != nil { + logger.Warn().Err(err).Msg("failed to parse socks5 udp header") + continue + } + + // Resolve address to construct Destination + dstAddr, err := net.ResolveUDPAddr("udp", dstAddrStr) + if err != nil { + logger.Warn(). + Err(err). + Str("addr", dstAddrStr). + Msg("failed to resolve udp target") + continue + } + + // Key: Client Addr -> Target Addr (Zero Allocation Struct) + key := netutil.NewNATKey(srcAddr.IP, srcAddr.Port, dstAddr.IP, dstAddr.Port) + + // Check if connection already exists in the pool + if cachedConn, ok := h.pool.Fetch(key); ok { + logger.Debug(). + Str("key", fmt.Sprintf("%s > %s", srcAddr.String(), dstAddr.String())). + Msg("session cache hit") + + // Write payload to target + if _, err := cachedConn.Write(payload); err != nil { + logger.Warn().Err(err).Msg("failed to write udp to target") + } + continue + } else { + logger.Debug(). + Str("key", fmt.Sprintf("%s > %s", srcAddr.String(), dstAddr.String())). + Msg("session cache miss") + } + + dst := &netutil.Destination{ + Addrs: []net.IP{dstAddr.IP}, + Port: dstAddr.Port, + } + + rRawConn, err := netutil.DialFastest(ctx, "udp", dst) + if err != nil { + logger.Warn().Err(err).Str("addr", dstAddrStr).Msg("failed to dial udp target") + continue + } + + // Add to pool (pool handles LRU eviction and deadline) + // returns IdleTimeoutConn with the actual net.Conn inside + rConn := h.pool.Store(key, rRawConn) + + // Apply UDP options from rule if matched + udpOpts := h.defaultUDPOpts.Clone() + if rule != nil && rule.UDP != nil { + udpOpts = udpOpts.Merge(rule.UDP) + } + + // Send fake packets before real payload (UDP desync) + if h.desyncer != nil { + _, _ = h.desyncer.Desync(ctx, lUDPConn, rConn.Conn, udpOpts) + } + + // Start a goroutine to read from the target and forward to the client. + // rConn is a connected UDP socket, so all responses come from the single remote. + // Using rConn.Read() (via IdleTimeoutConn) properly extends the idle deadline + // on each inbound packet, preventing premature timeout on asymmetric flows. + // dstAddr is already *net.UDPAddr (resolved above), same as rConn.RemoteAddr(). + go h.relayInboundUDP(logger, lUDPConn, rConn, srcAddr, dstAddr, key) + + // Write payload to target + if _, err := rConn.Write(payload); err != nil { + logger.Warn().Err(err).Msg("failed to write udp to target") + } + } +} + +func (h *UdpAssociateHandler) relayInboundUDP( + logger zerolog.Logger, + lUDPConn *net.UDPConn, + rConn *netutil.IdleTimeoutConn, + clientAddr *net.UDPAddr, + targetAddr *net.UDPAddr, + key netutil.NATKey, +) { + respBuf := make([]byte, 65535) + for { + // Read via IdleTimeoutConn so each inbound packet extends the deadline. + n, err := rConn.Read(respBuf) + if err != nil { + // Connection closed or network issues + h.pool.Evict(key) + return + } + + // Inbound: Target -> Proxy -> Client + // Wrap with SOCKS5 Header + header := createUDPHeaderFromAddr(targetAddr) + response := append(header, respBuf[:n]...) + + if _, err := lUDPConn.WriteToUDP(response, clientAddr); err != nil { + // If we can't write back to the client, it might be gone or network issue. + // Exit this goroutine to avoid busy looping. + logger.Warn().Err(err).Msg("failed to write udp to client") + return + } + } +} + +func parseUDPHeader(b []byte) (string, []byte, error) { + if len(b) < 4 { + return "", nil, fmt.Errorf("header too short") + } + // RSV(2) FRAG(1) ATYP(1) + if b[0] != 0 || b[1] != 0 { + return "", nil, fmt.Errorf("invalid rsv") + } + frag := b[2] + if frag != 0 { + return "", nil, fmt.Errorf("fragmentation not supported") + } + + atyp := b[3] + var host string + var pos int + + switch atyp { + case proto.SOCKS5AddrTypeIPv4: + if len(b) < 10 { + return "", nil, fmt.Errorf("header too short for ipv4") + } + host = net.IP(b[4:8]).String() + pos = 8 + case proto.SOCKS5AddrTypeIPv6: + if len(b) < 22 { + return "", nil, fmt.Errorf("header too short for ipv6") + } + host = net.IP(b[4:20]).String() + pos = 20 + case proto.SOCKS5AddrTypeFQDN: + if len(b) < 5 { + return "", nil, fmt.Errorf("header too short for fqdn") + } + l := int(b[4]) + if len(b) < 5+l+2 { + return "", nil, fmt.Errorf("header too short for fqdn data") + } + host = string(b[5 : 5+l]) + pos = 5 + l + default: + return "", nil, fmt.Errorf("unsupported atyp: %d", atyp) + } + + port := binary.BigEndian.Uint16(b[pos : pos+2]) + pos += 2 + + addr := net.JoinHostPort(host, fmt.Sprintf("%d", port)) + return addr, b[pos:], nil +} + +func createUDPHeaderFromAddr(addr *net.UDPAddr) []byte { + // RSV(2) FRAG(1) ATYP(1) ... + buf := make([]byte, 0, 24) + buf = append(buf, 0, 0, 0) // RSV, FRAG + + ip4 := addr.IP.To4() + if ip4 != nil { + buf = append(buf, proto.SOCKS5AddrTypeIPv4) + buf = append(buf, ip4...) + } else { + buf = append(buf, proto.SOCKS5AddrTypeIPv6) + buf = append(buf, addr.IP.To16()...) + } + + portBuf := make([]byte, 2) + binary.BigEndian.PutUint16(portBuf, uint16(addr.Port)) + buf = append(buf, portBuf...) + + return buf +} diff --git a/internal/server/tun/network.go b/internal/server/tun/network.go new file mode 100644 index 00000000..1cbfb5c4 --- /dev/null +++ b/internal/server/tun/network.go @@ -0,0 +1,23 @@ +//go:build !darwin && !linux && !freebsd + +package tun + +func SetRoute(iface string, subnets []string) error { + return nil +} + +func UnsetRoute(iface string, subnets []string) error { + return nil +} + +func SetInterfaceAddress(iface string, local string, remote string) error { + return nil +} + +func UnsetGatewayRoute(gateway, iface string) error { + return nil +} + +func SetGatewayRoute(gateway, iface string) error { + return nil +} diff --git a/internal/server/tun/network_bsd.go b/internal/server/tun/network_bsd.go new file mode 100644 index 00000000..c4bc2f3d --- /dev/null +++ b/internal/server/tun/network_bsd.go @@ -0,0 +1,121 @@ +//go:build darwin || freebsd + +package tun + +import ( + "fmt" + "os/exec" + "strings" +) + +// SetRoute configures network routes for specified subnets +func SetRoute(iface string, subnets []string) error { + for _, subnet := range subnets { + targets := []string{subnet} + + /* Expand default route into two /1 subnets to override the default gateway + without removing the existing 0.0.0.0/0 entry. + */ + if subnet == "0.0.0.0/0" { + targets = []string{"0.0.0.0/1", "128.0.0.0/1"} + } + + for _, target := range targets { + cmd := exec.Command("route", "-n", "add", "-net", target, "-interface", iface) + out, err := cmd.CombinedOutput() + if err != nil { + // Check if it's a permission error + if strings.Contains(string(out), "must be root") { + return fmt.Errorf( + "permission denied: must run as root to modify routing table (sudo required)", + ) + } + return fmt.Errorf( + "failed to add route for %s on %s: %s: %w", + target, + iface, + string(out), + err, + ) + } + } + } + return nil +} + +// UnsetRoute removes previously configured network routes +func UnsetRoute(iface string, subnets []string) error { + for _, subnet := range subnets { + targets := []string{subnet} + + if subnet == "0.0.0.0/0" { + targets = []string{"0.0.0.0/1", "128.0.0.0/1"} + } + + for _, target := range targets { + /* Delete specific routes to revert traffic flow to the original gateway. */ + cmd := exec.Command("route", "-n", "delete", "-net", target, "-interface", iface) + if out, err := cmd.CombinedOutput(); err != nil { + _ = out + } + } + } + return nil +} + +// SetInterfaceAddress configures the TUN interface with local and remote endpoints +func SetInterfaceAddress(iface string, local string, remote string) error { + cmd := exec.Command("ifconfig", iface, local, remote, "up") + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to set interface address: %s: %w", string(out), err) + } + return nil +} + +// UnsetGatewayRoute removes the scoped gateway route +func UnsetGatewayRoute(gateway, iface string) error { + // Remove the scoped default route for the physical interface + // This undoes the SetGatewayRoute ifscope route + cmd := exec.Command("route", "delete", "-ifscope", iface, "default") + if out, err := cmd.CombinedOutput(); err != nil { + _ = out + } + + // Remove the direct host route to the gateway + cmd = exec.Command("route", "-n", "delete", "-host", gateway, "-interface", iface) + if out, err := cmd.CombinedOutput(); err != nil { + _ = out + } + + return nil +} + +// SetGatewayRoute adds a scoped default route via the physical interface +// This ensures IP_BOUND_IF sockets on the physical interface can reach the gateway +func SetGatewayRoute(gateway, iface string) error { + // Add a scoped default route for the physical interface + // When a socket is bound to this interface via IP_BOUND_IF, + // this scoped route will be used instead of the TUN routes (0/1, 128/1) + cmd := exec.Command("route", "add", "-ifscope", iface, "default", gateway) + out, err := cmd.CombinedOutput() + if err != nil { + // Ignore "File exists" error - route already exists + if !strings.Contains(string(out), "File exists") { + return fmt.Errorf("failed to add scoped default route: %s: %w", string(out), err) + } + } + + // Also add a host route to the gateway via the physical interface + // This ensures packets to the gateway itself go through the right interface + cmd = exec.Command("route", "-n", "add", "-host", gateway, "-interface", iface) + out, err = cmd.CombinedOutput() + if err != nil { + // Ignore "File exists" error + if !strings.Contains(string(out), "File exists") { + // This is optional, don't fail if it doesn't work + _ = out + } + } + + return nil +} diff --git a/internal/server/tun/network_linux.go b/internal/server/tun/network_linux.go new file mode 100644 index 00000000..ba8f9190 --- /dev/null +++ b/internal/server/tun/network_linux.go @@ -0,0 +1,271 @@ +//go:build linux + +package tun + +import ( + "fmt" + "os/exec" + "regexp" + "strconv" + "strings" + "sync" +) + +var ( + allocatedTableID int + allocatedTableOnce sync.Once +) + +// getOrAllocateTableID returns a routing table ID, allocating one if needed. +// The ID is cached after first allocation. +func getOrAllocateTableID() (int, error) { + var initErr error + allocatedTableOnce.Do(func() { + allocatedTableID, initErr = findAvailableTableID() + }) + if allocatedTableID == 0 { + return 0, initErr + } + return allocatedTableID, nil +} + +// findAvailableTableID finds an unused routing table ID in the range 100-252. +func findAvailableTableID() (int, error) { + usedTables := make(map[int]bool) + + // Parse "ip rule show" output to find used table IDs + cmd := exec.Command("ip", "rule", "show") + if out, err := cmd.Output(); err == nil { + re := regexp.MustCompile(`lookup\s+(\d+)`) + matches := re.FindAllStringSubmatch(string(out), -1) + for _, match := range matches { + if len(match) >= 2 { + if id, err := strconv.Atoi(match[1]); err == nil { + usedTables[id] = true + } + } + } + } + + // Also check /etc/iproute2/rt_tables for reserved tables + rtTablesCmd := exec.Command("cat", "/etc/iproute2/rt_tables") + if rtOut, err := rtTablesCmd.Output(); err == nil { + lines := strings.Split(string(rtOut), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + fields := strings.Fields(line) + if len(fields) >= 1 { + if id, err := strconv.Atoi(fields[0]); err == nil { + usedTables[id] = true + } + } + } + } + + // Find first available ID in range 100-252 (253-255 are reserved) + for id := 100; id <= 252; id++ { + if !usedTables[id] { + return id, nil + } + } + + return 0, fmt.Errorf("no available routing table ID in range 100-252") +} + +// SetRoute configures network routes for specified subnets using ip route +func SetRoute(iface string, subnets []string) error { + for _, subnet := range subnets { + targets := []string{subnet} + + /* Expand default route into two /1 subnets to override the default gateway + without removing the existing 0.0.0.0/0 entry. + */ + if subnet == "0.0.0.0/0" { + targets = []string{"0.0.0.0/1", "128.0.0.0/1"} + } + + for _, target := range targets { + cmd := exec.Command("ip", "route", "add", target, "dev", iface) + out, err := cmd.CombinedOutput() + if err != nil { + // Check if it's a permission error + if strings.Contains(string(out), "Operation not permitted") || + strings.Contains(string(out), "RTNETLINK answers: Operation not permitted") { + return fmt.Errorf( + "permission denied: must run as root to modify routing table (sudo required)", + ) + } + // Ignore "File exists" error - route already exists + if strings.Contains(string(out), "File exists") { + continue + } + return fmt.Errorf( + "failed to add route for %s on %s: %s: %w", + target, + iface, + string(out), + err, + ) + } + } + } + return nil +} + +// UnsetRoute removes previously configured network routes using ip route +func UnsetRoute(iface string, subnets []string) error { + for _, subnet := range subnets { + targets := []string{subnet} + + if subnet == "0.0.0.0/0" { + targets = []string{"0.0.0.0/1", "128.0.0.0/1"} + } + + for _, target := range targets { + /* Delete specific routes to revert traffic flow to the original gateway. */ + cmd := exec.Command("ip", "route", "del", target, "dev", iface) + if out, err := cmd.CombinedOutput(); err != nil { + _ = out + } + } + } + return nil +} + +// SetInterfaceAddress configures the TUN interface with local and remote endpoints using ip addr +func SetInterfaceAddress(iface string, local string, remote string) error { + // Add the IP address to the interface + cmd := exec.Command("ip", "addr", "add", local, "peer", remote, "dev", iface) + out, err := cmd.CombinedOutput() + if err != nil { + // Ignore if address already exists + if !strings.Contains(string(out), "File exists") { + return fmt.Errorf("failed to set interface address: %s: %w", string(out), err) + } + } + + // Bring the interface up + cmd = exec.Command("ip", "link", "set", "dev", iface, "up") + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to bring interface up: %s: %w", string(out), err) + } + + return nil +} + +// UnsetGatewayRoute removes the gateway route and policy routing rules +func UnsetGatewayRoute(gateway, iface string) error { + tableID, err := getOrAllocateTableID() + if err != nil { + return err + } + tableIDStr := strconv.Itoa(tableID) + + // Get the interface's IP address for policy routing cleanup + ifaceIP := getInterfaceIP(iface) + if ifaceIP != "" { + // Remove the policy rule + cmd := exec.Command("ip", "rule", "del", "from", ifaceIP, "lookup", tableIDStr) + if out, err := cmd.CombinedOutput(); err != nil { + _ = out + } + } + + // Remove the route from the allocated table + cmd := exec.Command("ip", "route", "del", "default", "table", tableIDStr) + if out, err := cmd.CombinedOutput(); err != nil { + _ = out + } + + // Remove the direct route to the gateway + cmd = exec.Command("ip", "route", "del", gateway, "dev", iface) + if out, err := cmd.CombinedOutput(); err != nil { + _ = out + } + + return nil +} + +// SetGatewayRoute configures policy routing so that packets from the physical interface +// are routed via the gateway, while other packets go through TUN +func SetGatewayRoute(gateway, iface string) error { + tableID, err := getOrAllocateTableID() + if err != nil { + return fmt.Errorf("failed to allocate routing table ID: %w", err) + } + tableIDStr := strconv.Itoa(tableID) + + // First, add a direct route to the gateway via the interface + cmd := exec.Command("ip", "route", "add", gateway, "dev", iface) + out, err := cmd.CombinedOutput() + if err != nil { + // Ignore "File exists" error - route already exists + if !strings.Contains(string(out), "File exists") { + return fmt.Errorf("failed to add gateway route: %s: %w", string(out), err) + } + } + + // Add default route to the allocated table via the gateway + cmd = exec.Command( + "ip", + "route", + "add", + "default", + "via", + gateway, + "dev", + iface, + "table", + tableIDStr, + ) + out, err = cmd.CombinedOutput() + if err != nil { + if !strings.Contains(string(out), "File exists") { + // Optional, don't fail + _ = out + } + } + + // Get the interface's IP address for policy routing + ifaceIP := getInterfaceIP(iface) + if ifaceIP == "" { + return fmt.Errorf("failed to get IP address for interface %s", iface) + } + + // Add policy rule: packets from this IP use the allocated table + cmd = exec.Command("ip", "rule", "add", "from", ifaceIP, "lookup", tableIDStr) + out, err = cmd.CombinedOutput() + if err != nil { + if !strings.Contains(string(out), "File exists") { + return fmt.Errorf("failed to add policy rule: %s: %w", string(out), err) + } + } + + return nil +} + +// getInterfaceIP returns the first IPv4 address of the given interface +func getInterfaceIP(ifaceName string) string { + cmd := exec.Command("ip", "-4", "addr", "show", ifaceName) + out, err := cmd.Output() + if err != nil { + return "" + } + + // Parse output to find inet line + lines := strings.Split(string(out), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "inet ") { + parts := strings.Fields(line) + if len(parts) >= 2 { + ip := strings.Split(parts[1], "/")[0] + return ip + } + } + } + return "" +} diff --git a/internal/server/tun/server.go b/internal/server/tun/server.go new file mode 100644 index 00000000..aff39d42 --- /dev/null +++ b/internal/server/tun/server.go @@ -0,0 +1,373 @@ +package tun + +import ( + "context" + "errors" + "fmt" + "io" + "io/fs" + "net" + "os" + "strconv" + + "github.com/rs/zerolog" + "github.com/samber/lo" + "github.com/songgao/water" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/matcher" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/server" + "github.com/xvzc/SpoofDPI/internal/session" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +// Ensure tcpip is used to avoid "imported and not used" error +var _ tcpip.NetworkProtocolNumber = ipv4.ProtocolNumber + +type TunServer struct { + logger zerolog.Logger + config *config.Config + matcher matcher.RuleMatcher // For IP-based rule matching + + tcpHandler *TCPHandler + udpHandler *UDPHandler + + tunDevice *water.Interface + iface string + gateway string +} + +func NewTunServer( + logger zerolog.Logger, + config *config.Config, + matcher matcher.RuleMatcher, + tcpHandler *TCPHandler, + udpHandler *UDPHandler, + iface string, + gateway string, +) server.Server { + return &TunServer{ + logger: logger, + config: config, + matcher: matcher, + tcpHandler: tcpHandler, + udpHandler: udpHandler, + iface: iface, + gateway: gateway, + } +} + +func (s *TunServer) ListenAndServe( + appctx context.Context, + ready chan<- struct{}, +) error { + var err error + s.tunDevice, err = newTunDevice() + if err != nil { + return fmt.Errorf("failed to create tun device: %w", err) + } + + go func() { + <-appctx.Done() + _ = s.tunDevice.Close() + }() + + if ready != nil { + close(ready) + } + + return s.handle(appctx) +} + +func (s *TunServer) SetNetworkConfig() (func() error, error) { + if s.tunDevice == nil { + return nil, fmt.Errorf("tun device not initialized") + } + + local, remote, err := netutil.FindSafeSubnet() + if err != nil { + return nil, fmt.Errorf("failed to find safe subnet: %w", err) + } + + if err := SetInterfaceAddress(s.tunDevice.Name(), local, remote); err != nil { + return nil, fmt.Errorf("failed to set interface address: %w", err) + } + + // Add route for the TUN interface subnet to ensure packets can return + // This is crucial for the TUN interface to receive packets destined for its own subnet + // Calculate the network address for /30 subnet (e.g., 10.0.0.1 -> 10.0.0.0/30) + localIP := net.ParseIP(local) + networkAddr := net.IPv4( + localIP[12], + localIP[13], + localIP[14], + localIP[15]&0xFC, + ) // Mask with /30 + + err = SetRoute(s.tunDevice.Name(), []string{networkAddr.String() + "/30"}) + if err != nil { + return nil, fmt.Errorf("failed to set local route: %w", err) + } + + // Add a host route to the gateway via the physical interface + // This ensures SpoofDPI's outbound traffic goes through en0, not utun8 + if err := SetGatewayRoute(s.gateway, s.iface); err != nil { + s.logger.Error().Err(err).Msg("failed to set gateway route") + } + + err = SetRoute(s.tunDevice.Name(), []string{"0.0.0.0/0"}) // Default Route + if err != nil { + return nil, fmt.Errorf("failed to set default route: %w", err) + } + + unset := func() error { + if s.tunDevice == nil { + return nil + } + + // Remove the gateway route + if s.gateway != "" && s.iface != "" { + if err := UnsetGatewayRoute(s.gateway, s.iface); err != nil { + s.logger.Warn().Err(err).Msg("failed to unset gateway route") + } + } + + return UnsetRoute(s.tunDevice.Name(), []string{"0.0.0.0/0"}) // Default Route + } + + return unset, nil +} + +func (s *TunServer) Addr() string { + if s.tunDevice != nil { + return s.tunDevice.Name() + } + return "tun" +} + +// matchRuleByAddr extracts IP and port from net.Addr and performs rule matching +func (s *TunServer) matchRuleByAddr(addr net.Addr) *config.Rule { + if s.matcher == nil { + return nil + } + + host, portStr, err := net.SplitHostPort(addr.String()) + if err != nil { + return nil + } + + ip := net.ParseIP(host) + if ip == nil { + return nil + } + + port, _ := strconv.Atoi(portStr) + + selector := &matcher.Selector{ + Kind: matcher.MatchKindAddr, + IP: lo.ToPtr(ip), + Port: lo.ToPtr(uint16(port)), + } + + return s.matcher.Search(selector) +} + +func (s *TunServer) handle(appctx context.Context) error { + logger := logging.WithLocalScope(appctx, s.logger, "tun") + + // 1. Create gVisor stack + stk := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{ + tcp.NewProtocol, + udp.NewProtocol, + }, + }) + + // 2. Create channel endpoint + ep := channel.New(256, 1500, "") + + const nicID = 1 + if err := stk.CreateNIC(nicID, ep); err != nil { + return fmt.Errorf("failed to create NIC: %v", err) + } + + // 3. Enable Promiscuous mode & Spoofing + stk.SetPromiscuousMode(nicID, true) + stk.SetSpoofing(nicID, true) + + // 3.5. Add default route to the stack + // Define a subnet that matches all IPv4 addresses (0.0.0.0/0) + defaultSubnet, _ := tcpip.NewSubnet( + tcpip.AddrFrom4([4]byte{0, 0, 0, 0}), + tcpip.MaskFrom("\x00\x00\x00\x00"), + ) + + stk.SetRouteTable([]tcpip.Route{ + { + Destination: defaultSubnet, + NIC: nicID, + }, + }) + + // 4. Register TCP Forwarder + tcpFwd := tcp.NewForwarder(stk, 0, 65535, func(r *tcp.ForwarderRequest) { + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + logger.Error().Msgf("failed to create endpoint: %v", err) + r.Complete(true) + return + } + r.Complete(false) + + conn := gonet.NewTCPConn(&wq, ep) + + // Match rule by IP before passing to handler + rule := s.matchRuleByAddr(conn.LocalAddr()) + go s.tcpHandler.Handle(session.WithNewTraceID(context.Background()), conn, rule) + }) + stk.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket) + + // 5. Register UDP Forwarder + udpFwd := udp.NewForwarder(stk, func(r *udp.ForwarderRequest) bool { + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + logger.Error().Msgf("failed to create udp endpoint: %v", err) + return true + } + + conn := gonet.NewUDPConn(&wq, ep) + + // Match rule by IP before passing to handler + rule := s.matchRuleByAddr(conn.LocalAddr()) + go s.udpHandler.Handle(session.WithNewTraceID(context.Background()), conn, rule) + return true + }) + stk.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket) + + // 6. Start packet pump + go s.tunToStack(appctx, logger, ep) + s.stackToTun(appctx, logger, ep) + + return nil +} + +func (s *TunServer) tunToStack( + appctx context.Context, + logger zerolog.Logger, + ep *channel.Endpoint, +) { + buf := make([]byte, 2000) + for { + n, err := s.tunDevice.Read(buf) + if err != nil { + if errors.Is(err, fs.ErrClosed) || errors.Is(err, os.ErrClosed) { + return + } + + select { + case <-appctx.Done(): + return + default: + if err != io.EOF { + logger.Error().Err(err).Msg("failed to read from tun") + } + return + } + } + + if n < 1 { + continue + } + + version := (buf[0] >> 4) + if version != 4 { + logger.Trace().Int("version", int(version)).Msg("skipping non-ipv4 packet") + continue + } + + // Parse source and destination IP for debugging + // if n >= 20 { + // srcIP := net.IP(buf[12:16]) + // dstIP := net.IP(buf[16:20]) + // protocol := buf[9] + // logger.Trace(). + // Str("src", srcIP.String()). + // Str("dst", dstIP.String()). + // Uint8("proto", protocol). + // Int("len", n). + // Msg("injecting packet to stack") + // } + + payload := buffer.MakeWithData(append([]byte(nil), buf[:n]...)) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: payload, + }) + ep.InjectInbound(ipv4.ProtocolNumber, pkt) + pkt.DecRef() + } +} + +type notifier struct { + ch chan<- struct{} +} + +func (n *notifier) WriteNotify() { + select { + case n.ch <- struct{}{}: + default: + } +} + +func (s *TunServer) stackToTun( + appctx context.Context, + logger zerolog.Logger, + ep *channel.Endpoint, +) { + ch := make(chan struct{}, 1) + n := ¬ifier{ch: ch} + ep.AddNotify(n) + + for { + select { + case <-appctx.Done(): + return + default: + } + + pkt := ep.Read() + if pkt == nil { + select { + case <-ch: + continue + case <-appctx.Done(): + return + } + } + + views := pkt.ToView().AsSlice() + if len(views) > 0 { + _, _ = s.tunDevice.Write(views) + } + pkt.DecRef() + } +} + +func newTunDevice() (*water.Interface, error) { + config := water.Config{ + DeviceType: water.TUN, + } + return water.New(config) +} diff --git a/internal/server/tun/tcp.go b/internal/server/tun/tcp.go new file mode 100644 index 00000000..fc2ec4d3 --- /dev/null +++ b/internal/server/tun/tcp.go @@ -0,0 +1,237 @@ +package tun + +import ( + "context" + "fmt" + "net" + "strconv" + "time" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/desync" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/matcher" + "github.com/xvzc/SpoofDPI/internal/netutil" + "github.com/xvzc/SpoofDPI/internal/packet" + "github.com/xvzc/SpoofDPI/internal/proto" +) + +type TCPHandler struct { + logger zerolog.Logger + domainMatcher matcher.RuleMatcher // For TLS domain matching only + defaultHTTPSOpts *config.HTTPSOptions + defaultConnOpts *config.ConnOptions + desyncer *desync.TLSDesyncer + sniffer packet.Sniffer // For TTL tracking + iface string + gateway string +} + +func NewTCPHandler( + logger zerolog.Logger, + domainMatcher matcher.RuleMatcher, + defaultHTTPSOpts *config.HTTPSOptions, + defaultConnOpts *config.ConnOptions, + desyncer *desync.TLSDesyncer, + sniffer packet.Sniffer, + iface string, + gateway string, +) *TCPHandler { + return &TCPHandler{ + logger: logger, + domainMatcher: domainMatcher, + defaultHTTPSOpts: defaultHTTPSOpts, + defaultConnOpts: defaultConnOpts, + desyncer: desyncer, + sniffer: sniffer, + iface: iface, + gateway: gateway, + } +} + +func (h *TCPHandler) Handle(ctx context.Context, lConn net.Conn, rule *config.Rule) { + logger := logging.WithLocalScope(ctx, h.logger, "tcp") + + defer netutil.CloseConns(lConn) + + // Set a read deadline for the first byte to avoid hanging indefinitely + _ = lConn.SetReadDeadline(time.Now().Add(1 * time.Second)) + + lBufferedConn := netutil.NewBufferedConn(lConn) + buf, err := lBufferedConn.Peek(1) + if err != nil { + return + } + + // Reset deadline + _ = lConn.SetReadDeadline(time.Time{}) + + // Parse destination from local address (which is the original destination in TUN) + host, portStr, err := net.SplitHostPort(lConn.LocalAddr().String()) + if err != nil { + return + } + port, _ := strconv.Atoi(portStr) + + ip := net.ParseIP(host) + var iface *net.Interface + if h.iface != "" { + iface, _ = net.InterfaceByName(h.iface) + logger.Debug().Str("iface", h.iface).Msg("using interface for dial") + } else { + logger.Debug().Msg("no interface specified for dial") + } + + dst := &netutil.Destination{ + Domain: host, + Port: port, + Addrs: []net.IP{}, + Iface: iface, + Gateway: h.gateway, + } + if h.defaultConnOpts != nil && h.defaultConnOpts.TCPTimeout != nil { + dst.Timeout = *h.defaultConnOpts.TCPTimeout + } + if ip != nil { + dst.Addrs = append(dst.Addrs, ip) + } + + // Check if it's a TLS Handshake (Content Type 0x16) + if buf[0] == 0x16 { + logger.Debug().Msg("detected tls handshake") + if err := h.handleTLS(ctx, logger, lBufferedConn, dst, rule); err != nil { + logger.Debug().Err(err).Msg("tls handler failed") + } + return + } + + // Handle as plain TCP + rConn, err := netutil.DialFastest(ctx, "tcp", dst) + if err != nil { + logger.Error().Msgf("failed to dial %v", err) + return + } + + logger.Debug().Msgf("new remote conn -> %s", rConn.RemoteAddr()) + + resCh := make(chan netutil.TransferResult, 2) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + startedAt := time.Now() + go netutil.TunnelConns(ctx, resCh, lBufferedConn, rConn, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, rConn, lBufferedConn, netutil.TunnelDirIn) + + err = netutil.WaitAndLogTunnel( + ctx, + logger, + resCh, + startedAt, + netutil.DescribeRoute(lConn, rConn), + nil, + ) + if err != nil { + logger.Error().Err(err).Msg("error handling request") + } +} + +func (h *TCPHandler) handleTLS( + ctx context.Context, + logger zerolog.Logger, + lConn net.Conn, + dst *netutil.Destination, + addrRule *config.Rule, // Rule matched by IP in server.go +) error { + // Read ClientHello + tlsMsg, err := proto.ReadTLSMessage(lConn) + if err != nil { + return err + } + + if !tlsMsg.IsClientHello() { + return fmt.Errorf("not a client hello") + } + + // Extract SNI + start, end, err := tlsMsg.ExtractSNIOffset() + if err != nil { + return fmt.Errorf("failed to extract sni: %w", err) + } + dst.Domain = string(tlsMsg.Raw()[start:end]) + + logger.Trace().Str("value", dst.Domain).Msgf("extracted sni feild") + + // Match Rules + httpsOpts := h.defaultHTTPSOpts.Clone() + connOpts := h.defaultConnOpts.Clone() + + // First, apply IP-based rule if matched in server.go + if addrRule != nil { + logger.Trace().RawJSON("summary", addrRule.JSON()).Msg("addr match") + httpsOpts = httpsOpts.Merge(addrRule.HTTPS) + connOpts = connOpts.Merge(addrRule.Conn) + } + + // Then, try domain-based matching (TLS-specific) + if h.domainMatcher != nil { + domainSelector := &matcher.Selector{ + Kind: matcher.MatchKindDomain, + Domain: &dst.Domain, + } + if domainRule := h.domainMatcher.Search(domainSelector); domainRule != nil { + logger.Trace().RawJSON("summary", domainRule.JSON()).Msg("domain match") + // Domain rule takes priority if it has higher priority + finalRule := matcher.GetHigherPriorityRule(addrRule, domainRule) + if finalRule == domainRule { + httpsOpts = h.defaultHTTPSOpts.Clone().Merge(domainRule.HTTPS) + connOpts = h.defaultConnOpts.Clone().Merge(domainRule.Conn) + } + } + } + + dst.Timeout = *connOpts.TCPTimeout + + // Dial Remote + if h.sniffer != nil { + h.sniffer.RegisterUntracked(dst.Addrs) + } + rConn, err := netutil.DialFastest(ctx, "tcp", dst) + if err != nil { + return err + } + defer netutil.CloseConns(rConn) + + logger.Debug(). + Msgf("new remote conn (%s -> %s)", lConn.RemoteAddr(), rConn.RemoteAddr()) + + // Send ClientHello with Desync + if _, err := h.desyncer.Desync(ctx, logger, rConn, tlsMsg, httpsOpts); err != nil { + return err + } + + // Tunnel rest + resCh := make(chan netutil.TransferResult, 2) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + startedAt := time.Now() + go netutil.TunnelConns(ctx, resCh, lConn, rConn, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, rConn, lConn, netutil.TunnelDirIn) + + return netutil.WaitAndLogTunnel( + ctx, + logger, + resCh, + startedAt, + netutil.DescribeRoute(lConn, rConn), + nil, + ) +} + +func (h *TCPHandler) SetNetworkInfo(iface, gateway string) { + h.iface = iface + h.gateway = gateway +} diff --git a/internal/server/tun/udp.go b/internal/server/tun/udp.go new file mode 100644 index 00000000..39b0bee8 --- /dev/null +++ b/internal/server/tun/udp.go @@ -0,0 +1,124 @@ +package tun + +import ( + "context" + "net" + "strconv" + "time" + + "github.com/rs/zerolog" + "github.com/xvzc/SpoofDPI/internal/config" + "github.com/xvzc/SpoofDPI/internal/desync" + "github.com/xvzc/SpoofDPI/internal/logging" + "github.com/xvzc/SpoofDPI/internal/netutil" +) + +type UDPHandler struct { + logger zerolog.Logger + defaultUDPOpts *config.UDPOptions + defaultConnOpts *config.ConnOptions + desyncer *desync.UDPDesyncer + iface string + gateway string +} + +func NewUDPHandler( + logger zerolog.Logger, + desyncer *desync.UDPDesyncer, + defaultUDPOpts *config.UDPOptions, + defaultConnOpts *config.ConnOptions, + iface string, + gateway string, +) *UDPHandler { + return &UDPHandler{ + logger: logger, + desyncer: desyncer, + defaultUDPOpts: defaultUDPOpts, + defaultConnOpts: defaultConnOpts, + iface: iface, + gateway: gateway, + } +} + +func (h *UDPHandler) SetNetworkInfo(iface, gateway string) { + h.iface = iface + h.gateway = gateway +} + +func (h *UDPHandler) Handle(ctx context.Context, lConn net.Conn, rule *config.Rule) { + logger := logging.WithLocalScope(ctx, h.logger, "udp") + + defer netutil.CloseConns(lConn) + + host, portStr, err := net.SplitHostPort(lConn.LocalAddr().String()) + if err != nil { + return + } + port, _ := strconv.Atoi(portStr) + + var iface *net.Interface + if h.iface != "" { + iface, _ = net.InterfaceByName(h.iface) + } + + dst := &netutil.Destination{ + Domain: host, + Port: port, + Iface: iface, + Gateway: h.gateway, + } + if ip := net.ParseIP(host); ip != nil { + dst.Addrs = []net.IP{ip} + } + + // Apply rule if matched in server.go + udpOpts := h.defaultUDPOpts.Clone() + connOpts := h.defaultConnOpts.Clone() + if rule != nil { + logger.Trace().RawJSON("summary", rule.JSON()).Msg("match") + udpOpts = udpOpts.Merge(rule.UDP) + connOpts = connOpts.Merge(rule.Conn) + } + + // Dial remote connection + rawConn, err := netutil.DialFastest(ctx, "udp", dst) + if err != nil { + logger.Error().Msgf("error dialing to %s", dst.String()) + return + } + + timeout := *connOpts.UDPIdleTimeout + + // Wrap rConn with IdleTimeoutConn + rConnWrapped := netutil.NewIdleTimeoutConn(rawConn, timeout) + + // Wrap lConn with IdleTimeoutConn as well + lConnWrapped := netutil.NewIdleTimeoutConn(lConn, timeout) + + // Desync + _, _ = h.desyncer.Desync(ctx, lConnWrapped, rConnWrapped, udpOpts) + + logger.Debug(). + Msgf("new remote conn (%s -> %s)", lConn.RemoteAddr(), rConnWrapped.RemoteAddr()) + + resCh := make(chan netutil.TransferResult, 2) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + startedAt := time.Now() + go netutil.TunnelConns(ctx, resCh, lConnWrapped, rConnWrapped, netutil.TunnelDirOut) + go netutil.TunnelConns(ctx, resCh, rConnWrapped, lConnWrapped, netutil.TunnelDirIn) + + err = netutil.WaitAndLogTunnel( + ctx, + logger, + resCh, + startedAt, + netutil.DescribeRoute(lConnWrapped, rConnWrapped), + nil, + ) + if err != nil { + logger.Error().Err(err).Msg("error handling request") + } +} diff --git a/internal/session/session.go b/internal/session/session.go index 1d9e3f06..45408056 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -3,7 +3,6 @@ package session import ( "context" "math/rand/v2" - "unsafe" ) // We define unexported key types to prevent key collisions with other packages. @@ -34,7 +33,8 @@ func TraceIDFrom(ctx context.Context) (string, bool) { if ok { return traceID, true } - return "", false + + return "0000000000000000", false } // WithHostInfo returns a new context carrying the given domain name string. @@ -87,5 +87,6 @@ func generateTraceID() string { b[i] = r + 0x30 } - return unsafe.String(unsafe.SliceData(b), 16) + return string(b) + // return unsafe.String(unsafe.SliceData(b), 16) } diff --git a/internal/system/sysproxy.go b/internal/system/sysproxy.go deleted file mode 100644 index 1419fdab..00000000 --- a/internal/system/sysproxy.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build !darwin && !linux - -package system - -import "github.com/rs/zerolog" - -func SetProxy(logger zerolog.Logger, port uint16) error { - return nil -} - -func UnsetProxy(logger zerolog.Logger) error { - return nil -} diff --git a/internal/system/sysproxy_darwin.go b/internal/system/sysproxy_darwin.go deleted file mode 100644 index 77973e1d..00000000 --- a/internal/system/sysproxy_darwin.go +++ /dev/null @@ -1,100 +0,0 @@ -//go:build darwin - -package system - -import ( - "errors" - "fmt" - "os/exec" - "strconv" - "strings" - - "github.com/rs/zerolog" -) - -const ( - getDefaultNetworkCMD = "networksetup -listnetworkserviceorder | grep" + - " `(route -n get default | grep 'interface' || route -n get -inet6 default | grep 'interface') | cut -d ':' -f2`" + - " -B 1 | head -n 1 | cut -d ' ' -f 2-" - permissionErrorHelpText = "By default SpoofDPI tries to set itself up as a system-wide proxy server.\n" + - "Doing so may require root access on machines with\n" + - "'Settings > Privacy & Security > Advanced > Require" + - " an administrator password to access system-wide settings' enabled.\n" + - "If you do not want SpoofDPI to act as a system-wide proxy, provide" + - " -system-proxy=false." -) - -func SetProxy(logger zerolog.Logger, port uint16) error { - network, err := getDefaultNetwork() - if err != nil { - return err - } - - return setProxyInternal(getProxyTypes(), network, "127.0.0.1", int(port)) -} - -func UnsetProxy(logger zerolog.Logger) error { - network, err := getDefaultNetwork() - if err != nil { - return err - } - - return unsetProxyInternal(getProxyTypes(), network) -} - -func getDefaultNetwork() (string, error) { - network, err := exec.Command("sh", "-c", getDefaultNetworkCMD).Output() - if err != nil { - return "", err - } else if len(network) == 0 { - return "", errors.New("no available networks") - } - return strings.TrimSpace(string(network)), nil -} - -func getProxyTypes() []string { - return []string{"webproxy", "securewebproxy"} -} - -func setProxyInternal(proxyTypes []string, network, domain string, port int) error { - args := []string{"", network, domain, strconv.FormatUint(uint64(port), 10)} - - for _, proxyType := range proxyTypes { - args[0] = "-set" + proxyType - if err := networkSetup(args); err != nil { - return fmt.Errorf("setting %s: %w", proxyType, err) - } - } - return nil -} - -func unsetProxyInternal(proxyTypes []string, network string) error { - args := []string{"", network, "off"} - - for _, proxyType := range proxyTypes { - args[0] = "-set" + proxyType + "state" - if err := networkSetup(args); err != nil { - return fmt.Errorf("unsetting %s: %w", proxyType, err) - } - } - return nil -} - -func networkSetup(args []string) error { - cmd := exec.Command("networksetup", args...) - out, err := cmd.CombinedOutput() - if err != nil { - msg := string(out) - if isPermissionError(err) { - msg += permissionErrorHelpText - } - return fmt.Errorf("%s", msg) - } - return nil -} - -func isPermissionError(err error) bool { - var exitErr *exec.ExitError - ok := errors.As(err, &exitErr) - return ok && exitErr.ExitCode() == 14 -} diff --git a/internal/system/sysproxy_linux.go b/internal/system/sysproxy_linux.go deleted file mode 100644 index 1916a7cd..00000000 --- a/internal/system/sysproxy_linux.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build linux - -package system - -import "github.com/rs/zerolog" - -func SetProxy(logger zerolog.Logger, port uint16) error { - logger.Info().Msgf("automatic system-wide proxy setup is not implemented on Linux") - return nil -} - -func UnsetProxy(logger zerolog.Logger) error { - return nil -} diff --git a/mkdocs.yml b/mkdocs.yml index f0c5002c..60d19877 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -56,10 +56,11 @@ nav: - 'User Guide': - 'Overview': user-guide/overview.md - 'Options': - - 'General': user-guide/general.md - - 'Server': user-guide/server.md + - 'App': user-guide/app.md + - 'Connection': user-guide/connection.md - 'DNS': user-guide/dns.md - 'HTTPS': user-guide/https.md + - 'UDP': user-guide/udp.md - 'Policy': user-guide/policy.md - 'Configuration Examples': user-guide/examples.md # - 'Geographical Recipes':