diff --git a/config.go b/config.go index f59bb3f85..ebef9ecb6 100644 --- a/config.go +++ b/config.go @@ -24,6 +24,7 @@ type Config struct { ConnectionsPerHost int `long:"connections-per-host" default:"1" description:"Number of times to connect to each host (results in more output)"` ReadLimitPerHost int `long:"read-limit-per-host" default:"96" description:"Maximum total kilobytes to read for a single host (default 96kb)"` Prometheus string `long:"prometheus" description:"Address to use for Prometheus server (e.g. localhost:8080). If empty, Prometheus is disabled."` + LocalAddrStr string `long:"local-addr" description:"Local source address for outgoing connections (e.g. 192.168.10.2:0, port is required even if it's 0)"` CustomDNS string `long:"dns" description:"Address of a custom DNS server for lookups. Default port is 53."` Multiple MultipleCommand `command:"multiple" description:"Multiple module actions"` inputFile *os.File @@ -100,6 +101,15 @@ func validateFrameworkConfiguration() { } runtime.GOMAXPROCS(config.GOMAXPROCS) + // Parse and validate the local address if specified + if config.LocalAddrStr != "" { + var err error + config.localAddr, err = net.ResolveTCPAddr("tcp", config.LocalAddrStr) + if err != nil { + log.Fatalf("could not resolve local address %s: %v", config.LocalAddrStr, err) + } + } + //validate/start prometheus if config.Prometheus != "" { go func() { diff --git a/conn.go b/conn.go index e4dc43e2e..2d8a71bee 100644 --- a/conn.go +++ b/conn.go @@ -236,6 +236,13 @@ func NewTimeoutConnection(ctx context.Context, conn net.Conn, timeout, readTimeo func DialTimeoutConnectionEx(proto string, target string, dialTimeout, sessionTimeout, readTimeout, writeTimeout time.Duration, bytesReadLimit int) (net.Conn, error) { var conn net.Conn var err error + dialer := &net.Dialer{ + Timeout: dialTimeout, + } + if config.localAddr != nil { + dialer.LocalAddr = config.localAddr + } + conn, err = dialer.Dial(proto, target) if dialTimeout > 0 { conn, err = net.DialTimeout(proto, target, dialTimeout) } else { @@ -300,7 +307,9 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. d.Dialer.KeepAlive = d.Timeout // Copy over the source IP if set, or nil - d.Dialer.LocalAddr = config.localAddr + if config.localAddr != nil { + d.Dialer.LocalAddr = config.localAddr + } dialContext, cancelDial := context.WithTimeout(ctx, d.Dialer.Timeout) defer cancelDial()