From 05e558e2a81e27a956013dbe8e61aab006c511d5 Mon Sep 17 00:00:00 2001 From: Daniel Jampen Date: Sun, 2 Mar 2025 10:04:29 +0100 Subject: [PATCH 1/8] implement incoming handshake filtering --- examples/config.yml | 13 ++- firewall.go | 149 +++++++++++++++++++++++-- firewall/packet.go | 1 + firewall_test.go | 259 ++++++++++++++++++++++++++++++++++++++++---- handshake_ix.go | 8 ++ interface.go | 4 +- lighthouse.go | 26 ++++- lighthouse_test.go | 12 +- main.go | 4 +- 9 files changed, 434 insertions(+), 42 deletions(-) diff --git a/examples/config.yml b/examples/config.yml index 4e7a4ae99..07eefb4ca 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -61,6 +61,16 @@ lighthouse: hosts: - "192.168.100.1" + # This feature allows handshakes only from hosts whose hostname, groups, or IP address are specified in + # any inbound firewall rule. By doing so, it prevents unauthorized nodes from establishing a tunnel, conserving + # resources on the host and adding an additional security layer that reduces the potential attack surface. + # - If enabled on a lighthouse, this will prevent nodes from accessing lighthouse features unless there is an + # incoming rule allowing access. Use allow rules with `port: nebula` to grant access to nodes. + # - Similar considerations should be made before enabling the feature on relay nodes, as it can restrict + # access to certain nodes only. Both ends of the relayed tunnel must be allowed to communicate with this node + # for relaying to function. + #incoming_handshake_filtering: false + # remote_allow_list allows you to control ip ranges that this node will # consider when handshaking to another node. By default, any remote IPs are # allowed. You can provide CIDRs here with `true` to allow and `false` to @@ -340,7 +350,8 @@ firewall: # The firewall is default deny. There is no way to write a deny rule. # Rules are comprised of a protocol, port, and one or more of host, group, or CIDR # Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr) - # - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available). + # - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, `fragment` to match second and further fragments of fragmented + # packets (since there is no port available) or `nebula` to create nebula access rules (udp) for the handshake filtering feature. # code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any` # proto: `any`, `tcp`, `udp`, or `icmp` # host: `any` or a literal hostname, ie `test-host` diff --git a/firewall.go b/firewall.go index e9f454deb..6814ed587 100644 --- a/firewall.go +++ b/firewall.go @@ -82,6 +82,15 @@ type FirewallConntrack struct { TimerWheel *TimerWheel[firewall.Packet] } +type HandshakeFilter struct { + AllowedHosts map[string]struct{} + AllowedGroups map[string]struct{} + AllowedGroupsCombos []map[string]struct{} + AllowedCidrs []netip.Prefix + AllowedCANames map[string]struct{} + AllowedCAShas map[string]struct{} +} + // FirewallTable is the entry point for a rule, the evaluation order is: // Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR) type FirewallTable struct { @@ -190,7 +199,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D } } -func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { +func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, *HandshakeFilter, error) { certificate := cs.getCertificate(cert.Version2) if certificate == nil { certificate = cs.getCertificate(cert.Version1) @@ -209,6 +218,8 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew //TODO: max_connections ) + hf := NewHandshakeFilter() + fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false) inboundAction := c.GetString("firewall.inbound_action", "drop") @@ -233,17 +244,17 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew fw.OutSendReject = false } - err := AddFirewallRulesFromConfig(l, false, c, fw) + err := AddFirewallRulesFromConfig(l, false, c, fw, nil) if err != nil { - return nil, err + return nil, nil, err } - err = AddFirewallRulesFromConfig(l, true, c, fw) + err = AddFirewallRulesFromConfig(l, true, c, fw, hf) if err != nil { - return nil, err + return nil, nil, err } - return fw, nil + return fw, hf, nil } // AddRule properly creates the in memory rule structure for a firewall table. @@ -318,7 +329,7 @@ func (f *Firewall) GetRuleHashes() string { return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10) } -func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error { +func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface, hf *HandshakeFilter) error { var table string if inbound { table = "firewall.inbound" @@ -412,6 +423,10 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw if err != nil { return fmt.Errorf("%s rule #%v; `%s`", table, i, err) } + + if hf != nil { + hf.AddRule(groups, r.Host, cidr, r.CAName, r.CASha) + } } return nil @@ -981,6 +996,10 @@ func parsePort(s string) (startPort, endPort int32, err error) { startPort = firewall.PortFragment endPort = firewall.PortFragment + } else if s == "nebula" { + startPort = firewall.PortNebula + endPort = firewall.PortNebula + } else if strings.Contains(s, `-`) { sPorts := strings.SplitN(s, `-`, 2) sPorts[0] = strings.Trim(sPorts[0], " ") @@ -1018,3 +1037,119 @@ func parsePort(s string) (startPort, endPort int32, err error) { return } + +func NewHandshakeFilter() *HandshakeFilter { + return &HandshakeFilter{ + AllowedHosts: make(map[string]struct{}), + AllowedGroups: make(map[string]struct{}), + AllowedGroupsCombos: make([]map[string]struct{}, 0), + AllowedCidrs: make([]netip.Prefix, 0), + AllowedCANames: make(map[string]struct{}), + AllowedCAShas: make(map[string]struct{}), + } +} + +func (hfws *HandshakeFilter) AddRule(groups []string, host string, localIp netip.Prefix, CAName string, CASha string) { + if host != "" { + hfws.AllowedHosts[host] = struct{}{} + } + + if len(groups) > 1 { + gs := make(map[string]struct{}, len(groups)) + for i := range groups { + gs[groups[i]] = struct{}{} + } + + if !containsMap(hfws.AllowedGroupsCombos, gs) { + hfws.AllowedGroupsCombos = append( + hfws.AllowedGroupsCombos, + gs, + ) + } + } else if len(groups) == 1 { + hfws.AllowedGroups[groups[0]] = struct{}{} + } + + if localIp.IsValid() { + hfws.AllowedCidrs = append( + hfws.AllowedCidrs, + localIp, + ) + } + + if CAName != "" { + hfws.AllowedCANames[CAName] = struct{}{} + } + + if CASha != "" { + hfws.AllowedCAShas[CASha] = struct{}{} + } +} + +func (hfws *HandshakeFilter) IsHandshakeAllowed(groups []string, host string, vpnAddrs []netip.Addr, CAName string, CASha string) bool { + if _, ok := hfws.AllowedHosts["any"]; ok { + return true + } + if _, ok := hfws.AllowedHosts[host]; ok { + return true + } + + if _, ok := hfws.AllowedCANames[CAName]; ok { + return true + } + + if _, ok := hfws.AllowedCAShas[CASha]; ok { + return true + } + + if len(groups) != 0 { + if _, ok := hfws.AllowedGroups["any"]; ok { + return true + } + } + for _, g := range groups { + if _, ok := hfws.AllowedGroups[g]; ok { + return true + } + } + + for _, c := range hfws.AllowedCidrs { + for _, a := range vpnAddrs { + if c.Contains(a) { + return true + } + } + } + + for _, gc := range hfws.AllowedGroupsCombos { + if len(groups) < len(gc) { + continue + } + + if isSubset(gc, groups) { + return true + } + } + + return false +} + +func isSubset(subset map[string]struct{}, superset []string) bool { + ls := len(subset) + s := make(map[string]struct{}, ls) + for _, value := range superset { + if _, ok := subset[value]; ok { + s[value] = struct{}{} + } + } + return len(s) == ls +} + +func containsMap(slice []map[string]struct{}, target map[string]struct{}) bool { + for _, m := range slice { + if reflect.DeepEqual(m, target) { + return true + } + } + return false +} diff --git a/firewall/packet.go b/firewall/packet.go index 1d8f12a0c..912e5aef9 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -17,6 +17,7 @@ const ( PortAny = 0 // Special value for matching `port: any` PortFragment = -1 // Special value for matching `port: fragment` + PortNebula = -2 // Special value for matching `port: nebula` ) type Packet struct { diff --git a/firewall_test.go b/firewall_test.go index 8c2eeb058..c6ad76854 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -510,6 +510,223 @@ func TestFirewall_DropConntrackReload(t *testing.T) { assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) } +func TestHandshakeFilter_AddRuleToHandshakeFilter(t *testing.T) { + hf := NewHandshakeFilter() + assert.Empty(t, hf.AllowedHosts) + assert.Empty(t, hf.AllowedGroups) + assert.Empty(t, hf.AllowedGroupsCombos) + assert.Empty(t, hf.AllowedCidrs) + assert.Empty(t, hf.AllowedCANames) + assert.Empty(t, hf.AllowedCAShas) + + hf.AddRule([]string{}, "", netip.Prefix{}, "", "") + assert.Empty(t, hf.AllowedCidrs) + assert.Empty(t, hf.AllowedGroups) + assert.Empty(t, hf.AllowedGroupsCombos) + assert.Empty(t, hf.AllowedHosts) + assert.Empty(t, hf.AllowedCANames) + assert.Empty(t, hf.AllowedCAShas) + + hf = NewHandshakeFilter() + g := "g1" + hf.AddRule([]string{g}, "", netip.Prefix{}, "", "") + assert.Empty(t, hf.AllowedCidrs) + assert.Contains(t, hf.AllowedGroups, g) + assert.Empty(t, hf.AllowedGroupsCombos) + assert.Empty(t, hf.AllowedHosts) + assert.Empty(t, hf.AllowedCANames) + assert.Empty(t, hf.AllowedCAShas) + + hf = NewHandshakeFilter() + h := "h1" + hf.AddRule([]string{}, h, netip.Prefix{}, "", "") + assert.Empty(t, hf.AllowedCidrs) + assert.Empty(t, hf.AllowedGroups) + assert.Empty(t, hf.AllowedGroupsCombos) + assert.Contains(t, hf.AllowedHosts, h) + assert.Empty(t, hf.AllowedCANames) + assert.Empty(t, hf.AllowedCAShas) + + hf = NewHandshakeFilter() + ti, err := netip.ParsePrefix("1.2.3.4/32") + assert.NoError(t, err) + hf.AddRule([]string{}, "", ti, "", "") + assert.Contains(t, hf.AllowedCidrs, ti) + assert.Empty(t, hf.AllowedGroups) + assert.Empty(t, hf.AllowedGroupsCombos) + assert.Empty(t, hf.AllowedHosts) + assert.Empty(t, hf.AllowedCANames) + assert.Empty(t, hf.AllowedCAShas) + + hf = NewHandshakeFilter() + groups := []string{"g1", "g2"} + hf.AddRule(groups, "", netip.Prefix{}, "", "") + assert.Empty(t, hf.AllowedCidrs) + assert.Empty(t, hf.AllowedGroups) + assert.Empty(t, hf.AllowedCANames) + assert.Empty(t, hf.AllowedCAShas) + + gs := make(map[string]struct{}) + for i := range groups { + gs[groups[i]] = struct{}{} + } + assert.Len(t, gs, len(groups)) + assert.Equal(t, hf.AllowedGroupsCombos[0], gs) + assert.Empty(t, hf.AllowedHosts) + assert.Empty(t, hf.AllowedCANames) + assert.Empty(t, hf.AllowedCAShas) + + hf = NewHandshakeFilter() + i := "TestCA" + hf.AddRule([]string{}, "", netip.Prefix{}, i, "") + assert.Empty(t, hf.AllowedCidrs) + assert.Empty(t, hf.AllowedGroups) + assert.Empty(t, hf.AllowedGroupsCombos) + assert.Empty(t, hf.AllowedHosts) + assert.Contains(t, hf.AllowedCANames, i) + assert.Empty(t, hf.AllowedCAShas) + + hf = NewHandshakeFilter() + s := "3fc204e4d45e8b22ed0879bcd7cb5bf93cdc1c7a309c5dcedddc03aed33a47c6" + hf.AddRule([]string{}, "", netip.Prefix{}, "", s) + assert.Empty(t, hf.AllowedCidrs) + assert.Empty(t, hf.AllowedGroups) + assert.Empty(t, hf.AllowedGroupsCombos) + assert.Empty(t, hf.AllowedHosts) + assert.Empty(t, hf.AllowedCANames) + assert.Contains(t, hf.AllowedCAShas, s) +} + +func TestHandshakeFilter_IsHandshakeAllowed(t *testing.T) { + hf := NewHandshakeFilter() + assert.NotNil(t, hf.AllowedHosts) + assert.NotNil(t, hf.AllowedGroups) + assert.NotNil(t, hf.AllowedGroupsCombos) + assert.NotNil(t, hf.AllowedCidrs) + + ti, err := netip.ParsePrefix("1.2.3.0/24") + assert.NoError(t, err) + ais := make([]netip.Addr, 2) + ai, err := netip.ParseAddr("1.2.3.5") + assert.NoError(t, err) + ai2, err := netip.ParseAddr("1.1.1.2") + assert.NoError(t, err) + ais[0] = ai + ais[1] = ai2 + + aos := make([]netip.Addr, 2) + ao, err := netip.ParseAddr("1.2.0.1") + assert.NoError(t, err) + ao2, err := netip.ParseAddr("1.10.0.1") + assert.NoError(t, err) + aos[0] = ao + aos[1] = ao2 + + hf.AddRule([]string{"g1"}, "", netip.Prefix{}, "", "") + hf.AddRule([]string{}, "h1", netip.Prefix{}, "", "") + hf.AddRule([]string{}, "", ti, "", "") + hf.AddRule([]string{"g1", "g2"}, "", netip.Prefix{}, "", "") + hf.AddRule([]string{"g1", "g2"}, "", netip.Prefix{}, "", "") + hf.AddRule([]string{}, "", netip.Prefix{}, "TestCA", "") + hf.AddRule([]string{}, "", netip.Prefix{}, "", "3fc204e4d45e8b22ed0879bcd7cb5bf93cdc1c7a309c5dcedddc03aed33a47c6") + + g := []string{"g2", "g3"} + g2 := []string{"g1", "g2"} + assert.True(t, hf.IsHandshakeAllowed([]string{"g1"}, "", []netip.Addr{netip.Addr{}}, "", "")) + assert.False(t, hf.IsHandshakeAllowed([]string{"g2"}, "", []netip.Addr{netip.Addr{}}, "", "")) + assert.False(t, hf.IsHandshakeAllowed(g, "", []netip.Addr{netip.Addr{}}, "", "")) + assert.True(t, hf.IsHandshakeAllowed(g2, "", []netip.Addr{netip.Addr{}}, "", "")) + assert.Len(t, hf.AllowedGroupsCombos[0], 2) + + assert.True(t, hf.IsHandshakeAllowed([]string{"g2", "g1"}, "", []netip.Addr{netip.Addr{}}, "", "")) + assert.True(t, hf.IsHandshakeAllowed([]string{"g3", "g2", "g1"}, "", []netip.Addr{netip.Addr{}}, "", "")) + assert.False(t, hf.IsHandshakeAllowed([]string{}, "h2", aos, "", "")) + + assert.True(t, hf.IsHandshakeAllowed([]string{}, "h1", []netip.Addr{netip.Addr{}}, "", "")) + assert.False(t, hf.IsHandshakeAllowed([]string{}, "h2", []netip.Addr{netip.Addr{}}, "", "")) + assert.False(t, hf.IsHandshakeAllowed([]string{}, "", []netip.Addr{netip.Addr{}}, "", "")) + assert.False(t, hf.IsHandshakeAllowed([]string{}, "any", []netip.Addr{netip.Addr{}}, "", "")) + assert.False(t, hf.IsHandshakeAllowed([]string{"h1"}, "", aos, "", "")) + + assert.True(t, hf.IsHandshakeAllowed([]string{}, "", ais, "", "")) + assert.False(t, hf.IsHandshakeAllowed([]string{}, "", aos, "", "")) + assert.False(t, hf.IsHandshakeAllowed([]string{"g3"}, "h2", aos, "", "")) + + assert.True(t, hf.IsHandshakeAllowed([]string{}, "", []netip.Addr{netip.Addr{}}, "TestCA", "")) + assert.False(t, hf.IsHandshakeAllowed([]string{}, "", []netip.Addr{netip.Addr{}}, "WrongCA", "")) + + assert.True(t, hf.IsHandshakeAllowed([]string{}, "", []netip.Addr{netip.Addr{}}, "", "3fc204e4d45e8b22ed0879bcd7cb5bf93cdc1c7a309c5dcedddc03aed33a47c6")) + assert.False(t, hf.IsHandshakeAllowed([]string{}, "", []netip.Addr{netip.Addr{}}, "", "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF")) + + pAny, err := netip.ParsePrefix("0.0.0.0/0") + assert.NoError(t, err) + + hf = NewHandshakeFilter() + hf.AddRule([]string{"any"}, "", netip.Prefix{}, "", "") + hf.AddRule([]string{}, "any", netip.Prefix{}, "", "") + hf.AddRule([]string{}, "", pAny, "", "") + + assert.True(t, hf.IsHandshakeAllowed([]string{}, "", ais, "", "")) + assert.True(t, hf.IsHandshakeAllowed([]string{}, "h3", []netip.Addr{netip.Addr{}}, "", "")) + assert.True(t, hf.IsHandshakeAllowed([]string{"g4"}, "", []netip.Addr{netip.Addr{}}, "", "")) + assert.True(t, hf.IsHandshakeAllowed([]string{"g4", "g5"}, "", []netip.Addr{netip.Addr{}}, "", "")) + + hf = NewHandshakeFilter() + hf.AddRule([]string{}, "any", netip.Prefix{}, "", "") + assert.True(t, hf.IsHandshakeAllowed([]string{}, "", ais, "", "")) + assert.True(t, hf.IsHandshakeAllowed([]string{}, "h3", []netip.Addr{netip.Addr{}}, "", "")) + assert.True(t, hf.IsHandshakeAllowed([]string{"g4"}, "", []netip.Addr{netip.Addr{}}, "", "")) + assert.True(t, hf.IsHandshakeAllowed([]string{"g4", "g5"}, "", []netip.Addr{netip.Addr{}}, "", "")) + + hf = NewHandshakeFilter() + hf.AddRule([]string{}, "", pAny, "", "") + assert.True(t, hf.IsHandshakeAllowed([]string{}, "", ais, "", "")) + assert.False(t, hf.IsHandshakeAllowed([]string{}, "h3", []netip.Addr{netip.Addr{}}, "", "")) + assert.False(t, hf.IsHandshakeAllowed([]string{"g4"}, "", []netip.Addr{netip.Addr{}}, "", "")) + assert.False(t, hf.IsHandshakeAllowed([]string{"g4", "g5"}, "", []netip.Addr{netip.Addr{}}, "", "")) +} + +func Test_isSubset(t *testing.T) { + subset := make(map[string]struct{}, 2) + subset["g1"] = struct{}{} + subset["g2"] = struct{}{} + assert.Len(t, subset, 2) + + superset := []string{"g0", "g1", "g", "g2", "g3", "g1", "g2", "g4"} + assert.True(t, isSubset(subset, superset)) + superset = []string{"g0", "g1", "g3"} + assert.False(t, isSubset(subset, superset)) + superset = []string{"g0", "g1", "g1", "g3"} + assert.False(t, isSubset(subset, superset)) +} + +func Test_containsMap(t *testing.T) { + slice := make([]map[string]struct{}, 2) + + m1 := make(map[string]struct{}, 2) + m1["g1"] = struct{}{} + m1["g2"] = struct{}{} + assert.Len(t, m1, 2) + + m2 := make(map[string]struct{}, 2) + m2["g2"] = struct{}{} + m2["g3"] = struct{}{} + assert.Len(t, m1, 2) + + slice[0] = m1 + slice[1] = m2 + assert.Len(t, slice, 2) + + assert.True(t, containsMap(slice, m1)) + assert.True(t, containsMap(slice, m2)) + + c1 := make(map[string]struct{}, 2) + c1["g1"] = struct{}{} + c1["g3"] = struct{}{} + assert.Len(t, c1, 2) + assert.False(t, containsMap(slice, c1)) +} + func BenchmarkLookup(b *testing.B) { ml := func(m map[string]struct{}, a [][]string) { for n := 0; n < b.N; n++ { @@ -632,53 +849,53 @@ func TestNewFirewallFromConfig(t *testing.T) { conf := config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} - _, err = NewFirewallFromConfig(l, cs, conf) + _, _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} - _, err = NewFirewallFromConfig(l, cs, conf) + _, _, err = NewFirewallFromConfig(l, cs, conf) require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } @@ -688,28 +905,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf := config.NewC(l) mf := &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf, nil)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf, nil)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf, nil)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf, nil)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with cidr @@ -717,49 +934,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf, nil)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with local_cidr conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf, nil)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf, nil)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf, nil)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf, nil)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf, nil)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} - require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf, nil)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test Add error @@ -767,7 +984,7 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { mf = &mockFirewall{} mf.nextCallReturn = errors.New("test error") conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} - require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") + require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf, nil), "firewall.inbound rule #0; `test error`") } func TestFirewall_convertRule(t *testing.T) { diff --git a/handshake_ix.go b/handshake_ix.go index daea526cb..1cbb67f11 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -218,6 +218,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } + if f.lightHouse.incomingHandshakeFiltering.Load() { + if !f.lightHouse.hf.IsHandshakeAllowed(remoteCert.Certificate.Groups(), certName, vpnAddrs, issuer, fingerprint) { + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Warn("handshake filtering denied incoming handshake") + f.lightHouse.metricFilteredHandshakes.Inc(1) + return + } + } + if addr.IsValid() { // addr can be invalid when the tunnel is being relayed. // We only want to apply the remote allow list for direct tunnels here diff --git a/interface.go b/interface.go index 21e198cf1..9d9212410 100644 --- a/interface.go +++ b/interface.go @@ -327,12 +327,14 @@ func (f *Interface) reloadFirewall(c *config.C) { return } - fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) + fw, hf, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return } + f.lightHouse.hf = hf + oldFw := f.firewall conntrack := oldFw.Conntrack conntrack.Lock() diff --git a/lighthouse.go b/lighthouse.go index ce37023e2..e15bcb665 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -40,6 +40,12 @@ type LightHouse struct { // map of vpn addr to answers addrMap map[netip.Addr]*RemoteList + // Controls incoming handshake filtering based on firewall rules. + incomingHandshakeFiltering atomic.Bool + + // Filters incoming handshakes according to the specified firewall rules. + hf *HandshakeFilter + // filters remote addresses allowed for each host // - When we are a lighthouse, this filters what addresses we store and // respond with. @@ -72,14 +78,15 @@ type LightHouse struct { calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote - metrics *MessageMetrics - metricHolepunchTx metrics.Counter - l *logrus.Logger + metrics *MessageMetrics + metricHolepunchTx metrics.Counter + metricFilteredHandshakes metrics.Counter + l *logrus.Logger } // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy, hf *HandshakeFilter) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -105,18 +112,22 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, punchConn: pc, punchy: p, queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), + hf: hf, l: l, } lighthouses := make(map[netip.Addr]struct{}) h.lighthouses.Store(&lighthouses) staticList := make(map[netip.Addr]struct{}) h.staticList.Store(&staticList) + h.incomingHandshakeFiltering.Store(false) if c.GetBool("stats.lighthouse_metrics", false) { h.metrics = newLighthouseMetrics() h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil) + h.metricFilteredHandshakes = metrics.GetOrRegisterCounter("handshakes.filtered", nil) } else { h.metricHolepunchTx = metrics.NilCounter{} + h.metricFilteredHandshakes = metrics.NilCounter{} } err := h.reload(c, true) @@ -233,6 +244,13 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } } + if initial || c.HasChanged("lighthouse.incoming_handshake_filtering") { + lh.incomingHandshakeFiltering.Store(c.GetBool("lighthouse.incoming_handshake_filtering", false)) + if lh.incomingHandshakeFiltering.Load() { + lh.l.Info("Incoming handshake filtering enabled") + } + } + if initial || c.HasChanged("lighthouse.remote_allow_list") || c.HasChanged("lighthouse.remote_allow_ranges") { ral, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges") if err != nil { diff --git a/lighthouse_test.go b/lighthouse_test.go index 3b1295a61..582effa85 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -42,14 +42,14 @@ func Test_lhStaticMapping(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} - _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil, nil) require.NoError(t, err) lh2 := "10.128.0.3" c = config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} - _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil, nil) require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } @@ -71,7 +71,7 @@ func TestReloadLighthouseInterval(t *testing.T) { } c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil, nil) require.NoError(t, err) lh.ifce = &mockEncWriter{} @@ -99,7 +99,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { } c := config.NewC(l) - lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil, nil) require.NoError(b, err) hAddr := netip.MustParseAddrPort("4.5.6.7:12345") @@ -202,7 +202,7 @@ func TestLighthouse_Memory(t *testing.T) { myVpnNetworks: []netip.Prefix{myVpnNet}, myVpnNetworksTable: nt, } - lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil, nil) lh.ifce = &mockEncWriter{} require.NoError(t, err) lhh := lh.NewRequestHandler() @@ -288,7 +288,7 @@ func TestLighthouse_reload(t *testing.T) { myVpnNetworksTable: nt, } - lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil, nil) require.NoError(t, err) nc := map[interface{}]interface{}{ diff --git a/main.go b/main.go index 7e94c32e0..ffc50e6be 100644 --- a/main.go +++ b/main.go @@ -60,7 +60,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) } - fw, err := NewFirewallFromConfig(l, pki.getCertState(), c) + fw, hf, err := NewFirewallFromConfig(l, pki.getCertState(), c) if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } @@ -185,7 +185,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg hostMap := NewHostMapFromConfig(l, c) punchy := NewPunchyFromConfig(l, c) - lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy) + lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy, hf) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err) } From 028adbf4731b73c87f1043c66c44a9c5837ef2c7 Mon Sep 17 00:00:00 2001 From: Daniel Jampen Date: Wed, 19 Mar 2025 18:48:26 +0100 Subject: [PATCH 2/8] use MustParse* functions in firewall tests --- firewall_test.go | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/firewall_test.go b/firewall_test.go index c6ad76854..0ef6f022d 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -548,8 +548,7 @@ func TestHandshakeFilter_AddRuleToHandshakeFilter(t *testing.T) { assert.Empty(t, hf.AllowedCAShas) hf = NewHandshakeFilter() - ti, err := netip.ParsePrefix("1.2.3.4/32") - assert.NoError(t, err) + ti := netip.MustParsePrefix("1.2.3.4/32") hf.AddRule([]string{}, "", ti, "", "") assert.Contains(t, hf.AllowedCidrs, ti) assert.Empty(t, hf.AllowedGroups) @@ -604,23 +603,16 @@ func TestHandshakeFilter_IsHandshakeAllowed(t *testing.T) { assert.NotNil(t, hf.AllowedGroupsCombos) assert.NotNil(t, hf.AllowedCidrs) - ti, err := netip.ParsePrefix("1.2.3.0/24") - assert.NoError(t, err) - ais := make([]netip.Addr, 2) - ai, err := netip.ParseAddr("1.2.3.5") - assert.NoError(t, err) - ai2, err := netip.ParseAddr("1.1.1.2") - assert.NoError(t, err) - ais[0] = ai - ais[1] = ai2 - - aos := make([]netip.Addr, 2) - ao, err := netip.ParseAddr("1.2.0.1") - assert.NoError(t, err) - ao2, err := netip.ParseAddr("1.10.0.1") - assert.NoError(t, err) - aos[0] = ao - aos[1] = ao2 + ti := netip.MustParsePrefix("1.2.3.0/24") + ais := []netip.Addr{ + netip.MustParseAddr("1.2.3.5"), + netip.MustParseAddr("1.1.1.2"), + } + + aos := []netip.Addr{ + netip.MustParseAddr("1.2.0.1"), + netip.MustParseAddr("1.10.0.1"), + } hf.AddRule([]string{"g1"}, "", netip.Prefix{}, "", "") hf.AddRule([]string{}, "h1", netip.Prefix{}, "", "") @@ -658,8 +650,7 @@ func TestHandshakeFilter_IsHandshakeAllowed(t *testing.T) { assert.True(t, hf.IsHandshakeAllowed([]string{}, "", []netip.Addr{netip.Addr{}}, "", "3fc204e4d45e8b22ed0879bcd7cb5bf93cdc1c7a309c5dcedddc03aed33a47c6")) assert.False(t, hf.IsHandshakeAllowed([]string{}, "", []netip.Addr{netip.Addr{}}, "", "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF")) - pAny, err := netip.ParsePrefix("0.0.0.0/0") - assert.NoError(t, err) + pAny := netip.MustParsePrefix("0.0.0.0/0") hf = NewHandshakeFilter() hf.AddRule([]string{"any"}, "", netip.Prefix{}, "", "") From f7f3ff337439b4c36a03535f3a247ce259679794 Mon Sep 17 00:00:00 2001 From: Daniel Jampen Date: Sun, 2 Mar 2025 10:04:29 +0100 Subject: [PATCH 3/8] implement host query protection --- examples/config.yml | 5 + firewall.go | 106 +++++- firewall_test.go | 181 +++++++++ interface.go | 2 + lighthouse.go | 132 ++++++- lighthouse_test.go | 14 +- nebula.pb.go | 887 ++++++++++++++++++++++++++++++++++++++++---- nebula.proto | 15 + outside.go | 2 +- remote_list.go | 15 + 10 files changed, 1266 insertions(+), 93 deletions(-) diff --git a/examples/config.yml b/examples/config.yml index 07eefb4ca..3b2d3dd70 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -71,6 +71,11 @@ lighthouse: # for relaying to function. #incoming_handshake_filtering: false + # This setting on a lighthouse determines whether to enforce the host query protection + # whitelist received from a node. On a node, this setting controls whether the node + # sends its handshake filtering whitelist to the lighthouses at all, not reloadable. + #enable_host_query_protection: false + # remote_allow_list allows you to control ip ranges that this node will # consider when handshaking to another node. By default, any remote IPs are # allowed. You can provide CIDRs here with `true` to allow and `false` to diff --git a/firewall.go b/firewall.go index 6814ed587..d1b285614 100644 --- a/firewall.go +++ b/firewall.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/gaissmai/bart" @@ -89,6 +90,9 @@ type HandshakeFilter struct { AllowedCidrs []netip.Prefix AllowedCANames map[string]struct{} AllowedCAShas map[string]struct{} + + IsEmtpy atomic.Bool + IsModifiedSinceLastMashalling atomic.Bool } // FirewallTable is the entry point for a rule, the evaluation order is: @@ -1039,7 +1043,7 @@ func parsePort(s string) (startPort, endPort int32, err error) { } func NewHandshakeFilter() *HandshakeFilter { - return &HandshakeFilter{ + hf := &HandshakeFilter{ AllowedHosts: make(map[string]struct{}), AllowedGroups: make(map[string]struct{}), AllowedGroupsCombos: make([]map[string]struct{}, 0), @@ -1047,11 +1051,18 @@ func NewHandshakeFilter() *HandshakeFilter { AllowedCANames: make(map[string]struct{}), AllowedCAShas: make(map[string]struct{}), } + + hf.IsModifiedSinceLastMashalling.Store(false) + hf.IsEmtpy.Store(true) + + return hf } func (hfws *HandshakeFilter) AddRule(groups []string, host string, localIp netip.Prefix, CAName string, CASha string) { + ruleAdded := false if host != "" { hfws.AllowedHosts[host] = struct{}{} + ruleAdded = true } if len(groups) > 1 { @@ -1066,8 +1077,10 @@ func (hfws *HandshakeFilter) AddRule(groups []string, host string, localIp netip gs, ) } + ruleAdded = true } else if len(groups) == 1 { hfws.AllowedGroups[groups[0]] = struct{}{} + ruleAdded = true } if localIp.IsValid() { @@ -1075,14 +1088,23 @@ func (hfws *HandshakeFilter) AddRule(groups []string, host string, localIp netip hfws.AllowedCidrs, localIp, ) + ruleAdded = true } if CAName != "" { hfws.AllowedCANames[CAName] = struct{}{} + ruleAdded = true } if CASha != "" { hfws.AllowedCAShas[CASha] = struct{}{} + ruleAdded = true + } + + hfws.IsModifiedSinceLastMashalling.Store(ruleAdded) + + if ruleAdded { + hfws.IsEmtpy.Store(false) } } @@ -1134,6 +1156,88 @@ func (hfws *HandshakeFilter) IsHandshakeAllowed(groups []string, host string, vp return false } +func (hfws *HandshakeFilter) MarshalToHfw() *HandshakeFilteringWhitelist { + hfw := &HandshakeFilteringWhitelist{ + AllowedHosts: make([]string, len(hfws.AllowedHosts)), + AllowedGroups: make([]string, len(hfws.AllowedGroups)), + AllowedGroupsCombos: make([]*GroupsCombos, len(hfws.AllowedGroupsCombos)), + AllowedCidrs: make([]string, len(hfws.AllowedCidrs)), + AllowedCANames: make([]string, len(hfws.AllowedCANames)), + AllowedCAShas: make([]string, len(hfws.AllowedCAShas)), + SetEmpty: hfws.IsEmtpy.Load(), + } + + for host := range hfws.AllowedHosts { + hfw.AllowedHosts = append(hfw.AllowedHosts, host) + } + + for group := range hfws.AllowedGroups { + hfw.AllowedGroups = append(hfw.AllowedGroups, group) + } + + for i, groupCombo := range hfws.AllowedGroupsCombos { + gc := &GroupsCombos{ + Group: make([]string, len(groupCombo)), + } + j := 0 + for group := range groupCombo { + gc.Group[j] = group + j += 1 + } + hfw.AllowedGroupsCombos[i] = gc + } + + for i, cidr := range hfws.AllowedCidrs { + hfw.AllowedCidrs[i] = cidr.String() + } + + for ca := range hfws.AllowedCANames { + hfw.AllowedCANames = append(hfw.AllowedCANames, ca) + } + + for fp := range hfws.AllowedCAShas { + hfw.AllowedCAShas = append(hfw.AllowedCAShas, fp) + } + + hfws.IsModifiedSinceLastMashalling.Store(false) + + return hfw +} + +func (hfws *HandshakeFilter) UnmarshalFromHfw(hfw *HandshakeFilteringWhitelist) { + if hfw == nil { + return + } + + for _, h := range hfw.AllowedHosts { + hfws.AddRule(nil, h, netip.Prefix{}, "", "") + } + + for _, g := range hfw.AllowedGroups { + hfws.AddRule([]string{g}, "", netip.Prefix{}, "", "") + } + + for _, gc := range hfw.AllowedGroupsCombos { + hfws.AddRule(gc.Group, "", netip.Prefix{}, "", "") + } + + for _, cs := range hfw.AllowedCidrs { + c, err := netip.ParsePrefix(cs) + if err != nil { + continue + } + hfws.AddRule(nil, "", c, "", "") + } + + for _, ca := range hfw.AllowedCANames { + hfws.AddRule(nil, "", netip.Prefix{}, ca, "") + } + + for _, sha := range hfw.AllowedCAShas { + hfws.AddRule(nil, "", netip.Prefix{}, "", sha) + } +} + func isSubset(subset map[string]struct{}, superset []string) bool { ls := len(subset) s := make(map[string]struct{}, ls) diff --git a/firewall_test.go b/firewall_test.go index 0ef6f022d..c848c6531 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -519,11 +519,15 @@ func TestHandshakeFilter_AddRuleToHandshakeFilter(t *testing.T) { assert.Empty(t, hf.AllowedCANames) assert.Empty(t, hf.AllowedCAShas) + assert.True(t, hf.IsEmtpy.Load()) + assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) hf.AddRule([]string{}, "", netip.Prefix{}, "", "") assert.Empty(t, hf.AllowedCidrs) assert.Empty(t, hf.AllowedGroups) assert.Empty(t, hf.AllowedGroupsCombos) assert.Empty(t, hf.AllowedHosts) + assert.True(t, hf.IsEmtpy.Load()) + assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) assert.Empty(t, hf.AllowedCANames) assert.Empty(t, hf.AllowedCAShas) @@ -536,6 +540,8 @@ func TestHandshakeFilter_AddRuleToHandshakeFilter(t *testing.T) { assert.Empty(t, hf.AllowedHosts) assert.Empty(t, hf.AllowedCANames) assert.Empty(t, hf.AllowedCAShas) + assert.False(t, hf.IsEmtpy.Load()) + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) hf = NewHandshakeFilter() h := "h1" @@ -546,6 +552,8 @@ func TestHandshakeFilter_AddRuleToHandshakeFilter(t *testing.T) { assert.Contains(t, hf.AllowedHosts, h) assert.Empty(t, hf.AllowedCANames) assert.Empty(t, hf.AllowedCAShas) + assert.False(t, hf.IsEmtpy.Load()) + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) hf = NewHandshakeFilter() ti := netip.MustParsePrefix("1.2.3.4/32") @@ -556,6 +564,8 @@ func TestHandshakeFilter_AddRuleToHandshakeFilter(t *testing.T) { assert.Empty(t, hf.AllowedHosts) assert.Empty(t, hf.AllowedCANames) assert.Empty(t, hf.AllowedCAShas) + assert.False(t, hf.IsEmtpy.Load()) + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) hf = NewHandshakeFilter() groups := []string{"g1", "g2"} @@ -594,6 +604,8 @@ func TestHandshakeFilter_AddRuleToHandshakeFilter(t *testing.T) { assert.Empty(t, hf.AllowedHosts) assert.Empty(t, hf.AllowedCANames) assert.Contains(t, hf.AllowedCAShas, s) + assert.False(t, hf.IsEmtpy.Load()) + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) } func TestHandshakeFilter_IsHandshakeAllowed(t *testing.T) { @@ -677,6 +689,175 @@ func TestHandshakeFilter_IsHandshakeAllowed(t *testing.T) { assert.False(t, hf.IsHandshakeAllowed([]string{"g4", "g5"}, "", []netip.Addr{netip.Addr{}}, "", "")) } +func TestHandshakeFilter_Marshalling(t *testing.T) { + hf := NewHandshakeFilter() + assert.NotNil(t, hf.AllowedHosts) + assert.NotNil(t, hf.AllowedGroups) + assert.NotNil(t, hf.AllowedGroupsCombos) + assert.NotNil(t, hf.AllowedCidrs) + assert.Empty(t, hf.AllowedCANames) + assert.Empty(t, hf.AllowedCAShas) + assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) + assert.True(t, hf.IsEmtpy.Load()) + hfw := hf.MarshalToHfw() + assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) + assert.Empty(t, hfw.AllowedGroups) + assert.Empty(t, hfw.AllowedGroupsCombos) + assert.Empty(t, hfw.AllowedHosts) + assert.Empty(t, hfw.AllowedCidrs) + assert.Empty(t, hfw.AllowedCANames) + assert.Empty(t, hfw.AllowedCAShas) + assert.True(t, hfw.SetEmpty) + + hf = NewHandshakeFilter() + g := "g1" + hf.AddRule([]string{g}, "", netip.Prefix{}, "", "") + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) + hfw = hf.MarshalToHfw() + assert.Contains(t, hfw.AllowedGroups, g) + assert.Empty(t, hfw.AllowedGroupsCombos) + assert.Empty(t, hfw.AllowedHosts) + assert.Empty(t, hfw.AllowedCidrs) + assert.Empty(t, hfw.AllowedCANames) + assert.Empty(t, hfw.AllowedCAShas) + assert.False(t, hfw.SetEmpty) + assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) + hf = NewHandshakeFilter() + hf.UnmarshalFromHfw(hfw) + assert.Empty(t, hf.AllowedGroupsCombos) + assert.Empty(t, hf.AllowedHosts) + assert.Empty(t, hf.AllowedCidrs) + assert.Empty(t, hf.AllowedCANames) + assert.Empty(t, hf.AllowedCAShas) + assert.Contains(t, hf.AllowedGroups, g) + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) + + hf = NewHandshakeFilter() + gc := []string{"g1", "g2"} + hf.AddRule(gc, "", netip.Prefix{}, "", "") + assert.Len(t, hf.AllowedGroupsCombos, 1) + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) + hfw = hf.MarshalToHfw() + assert.Empty(t, hfw.AllowedGroups) + assert.Len(t, hfw.AllowedGroupsCombos, 1) + for _, g := range gc { + assert.Contains(t, hfw.AllowedGroupsCombos[0].Group, g) + } + assert.Empty(t, hfw.AllowedHosts) + assert.Empty(t, hfw.AllowedCidrs) + assert.Empty(t, hfw.AllowedCANames) + assert.Empty(t, hfw.AllowedCAShas) + assert.False(t, hfw.SetEmpty) + assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) + hf = NewHandshakeFilter() + hf.UnmarshalFromHfw(hfw) + assert.Empty(t, hf.AllowedGroups) + gs := make(map[string]struct{}) + for _, g := range hfw.AllowedGroupsCombos[0].Group { + gs[g] = struct{}{} + } + for _, g := range gc { + assert.Contains(t, hf.AllowedGroupsCombos[0], g) + } + assert.Empty(t, hf.AllowedHosts) + assert.Empty(t, hf.AllowedCidrs) + assert.Empty(t, hf.AllowedCANames) + assert.Empty(t, hf.AllowedCAShas) + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) + + hf = NewHandshakeFilter() + h := "h1" + hf.AddRule(nil, h, netip.Prefix{}, "", "") + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) + hfw = hf.MarshalToHfw() + assert.Empty(t, hfw.AllowedGroups) + assert.Empty(t, hfw.AllowedGroupsCombos) + assert.Contains(t, hfw.AllowedHosts, h) + assert.Empty(t, hfw.AllowedCidrs) + assert.Empty(t, hfw.AllowedCANames) + assert.Empty(t, hfw.AllowedCAShas) + assert.False(t, hfw.SetEmpty) + assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) + hf = NewHandshakeFilter() + hf.UnmarshalFromHfw(hfw) + assert.Empty(t, hf.AllowedGroups) + assert.Empty(t, hf.AllowedGroupsCombos) + assert.Contains(t, hf.AllowedHosts, h) + assert.Empty(t, hf.AllowedCidrs) + assert.Empty(t, hf.AllowedCANames) + assert.Empty(t, hf.AllowedCAShas) + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) + + hf = NewHandshakeFilter() + p, _ := netip.ParsePrefix("10.1.1.1/32") + hf.AddRule(nil, "", p, "", "") + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) + hfw = hf.MarshalToHfw() + assert.Empty(t, hfw.AllowedGroups) + assert.Empty(t, hfw.AllowedGroupsCombos) + assert.Empty(t, hfw.AllowedHosts) + assert.Equal(t, hfw.AllowedCidrs[0], p.String()) + assert.Empty(t, hfw.AllowedCANames) + assert.Empty(t, hfw.AllowedCAShas) + assert.False(t, hfw.SetEmpty) + assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) + hf = NewHandshakeFilter() + hf.UnmarshalFromHfw(hfw) + assert.Empty(t, hf.AllowedGroups) + assert.Empty(t, hf.AllowedGroupsCombos) + assert.Empty(t, hf.AllowedHosts) + assert.Contains(t, hf.AllowedCidrs, p) + assert.Empty(t, hf.AllowedCANames) + assert.Empty(t, hf.AllowedCAShas) + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) + + hf = NewHandshakeFilter() + ca := "TestCA" + hf.AddRule(nil, "", netip.Prefix{}, ca, "") + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) + hfw = hf.MarshalToHfw() + assert.Empty(t, hfw.AllowedGroups) + assert.Empty(t, hfw.AllowedGroupsCombos) + assert.Empty(t, hfw.AllowedHosts) + assert.Empty(t, hfw.AllowedCidrs) + assert.Contains(t, hfw.AllowedCANames, ca) + assert.Empty(t, hfw.AllowedCAShas) + assert.False(t, hfw.SetEmpty) + assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) + hf = NewHandshakeFilter() + hf.UnmarshalFromHfw(hfw) + assert.Empty(t, hf.AllowedGroups) + assert.Empty(t, hf.AllowedGroupsCombos) + assert.Empty(t, hf.AllowedHosts) + assert.Empty(t, hf.AllowedCidrs) + assert.Contains(t, hf.AllowedCANames, ca) + assert.Empty(t, hf.AllowedCAShas) + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) + + hf = NewHandshakeFilter() + fp := "3fc204e4d45e8b22ed0879bcd7cb5bf93cdc1c7a309c5dcedddc03aed33a47c6" + hf.AddRule(nil, "", netip.Prefix{}, "", fp) + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) + hfw = hf.MarshalToHfw() + assert.Empty(t, hfw.AllowedGroups) + assert.Empty(t, hfw.AllowedGroupsCombos) + assert.Empty(t, hfw.AllowedHosts) + assert.Empty(t, hfw.AllowedCidrs) + assert.Empty(t, hfw.AllowedCANames) + assert.Contains(t, hfw.AllowedCAShas, fp) + assert.False(t, hfw.SetEmpty) + assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) + hf = NewHandshakeFilter() + hf.UnmarshalFromHfw(hfw) + assert.Empty(t, hf.AllowedGroups) + assert.Empty(t, hf.AllowedGroupsCombos) + assert.Empty(t, hf.AllowedHosts, h) + assert.Empty(t, hf.AllowedCidrs) + assert.Empty(t, hfw.AllowedCANames) + assert.Contains(t, hf.AllowedCAShas, fp) + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) +} + func Test_isSubset(t *testing.T) { subset := make(map[string]struct{}, 2) subset["g1"] = struct{}{} diff --git a/interface.go b/interface.go index 9d9212410..0b7528e20 100644 --- a/interface.go +++ b/interface.go @@ -333,6 +333,8 @@ func (f *Interface) reloadFirewall(c *config.C) { return } + // Set to send updated whitelist to lh after firewall rule reload + hf.IsModifiedSinceLastMashalling.Store(true) f.lightHouse.hf = hf oldFw := f.firewall diff --git a/lighthouse.go b/lighthouse.go index e15bcb665..882467c0d 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -46,6 +46,9 @@ type LightHouse struct { // Filters incoming handshakes according to the specified firewall rules. hf *HandshakeFilter + // Controls weather to send the handshake white list rules to the lighthouses. + enableHostQueryProtection atomic.Bool + // filters remote addresses allowed for each host // - When we are a lighthouse, this filters what addresses we store and // respond with. @@ -78,10 +81,11 @@ type LightHouse struct { calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote - metrics *MessageMetrics - metricHolepunchTx metrics.Counter - metricFilteredHandshakes metrics.Counter - l *logrus.Logger + metrics *MessageMetrics + metricHolepunchTx metrics.Counter + metricFilteredHostQueries metrics.Counter + metricFilteredHandshakes metrics.Counter + l *logrus.Logger } // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object @@ -120,13 +124,18 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, staticList := make(map[netip.Addr]struct{}) h.staticList.Store(&staticList) h.incomingHandshakeFiltering.Store(false) + h.enableHostQueryProtection.Store(false) if c.GetBool("stats.lighthouse_metrics", false) { h.metrics = newLighthouseMetrics() h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil) h.metricFilteredHandshakes = metrics.GetOrRegisterCounter("handshakes.filtered", nil) + if amLighthouse { + h.metricFilteredHostQueries = metrics.GetOrRegisterCounter("lighthouse.hostqueries.filtered", nil) + } } else { h.metricHolepunchTx = metrics.NilCounter{} + h.metricFilteredHostQueries = metrics.NilCounter{} h.metricFilteredHandshakes = metrics.NilCounter{} } @@ -251,6 +260,13 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } } + if initial { + lh.enableHostQueryProtection.Store(c.GetBool("lighthouse.enable_host_query_protection", false)) + if lh.enableHostQueryProtection.Load() { + lh.l.Info("Host query protection enabled") + } + } + if initial || c.HasChanged("lighthouse.remote_allow_list") || c.HasChanged("lighthouse.remote_allow_ranges") { ral, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges") if err != nil { @@ -539,6 +555,22 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in return false, 0, nil } +func (lh *LightHouse) IsHostQueryAllowed(targetAddr netip.Addr, groups []string, host string, queryAddrs []netip.Addr, CAName string, CASha string) bool { + lh.RLock() + // Do we have an entry in the main cache? + if v, ok := lh.addrMap[targetAddr]; ok { + // Swap lh lock for remote list lock + v.RLock() + defer v.RUnlock() + + lh.RUnlock() + + return v.hf.IsEmtpy.Load() || v.hf.IsHandshakeAllowed(groups, host, queryAddrs, CAName, CASha) + } + lh.RUnlock() + return true +} + func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) { // First we check the static mapping // and do nothing if it is there @@ -890,11 +922,18 @@ func (lh *LightHouse) SendUpdate() { nb := make([]byte, 12, 12) out := make([]byte, mtu) - var v1Update, v2Update []byte + // cache for v1Update/v2Update with or without hfwl + updateMessageCache := map[cert.Version]map[bool][]byte{ + cert.Version1: map[bool][]byte{}, + cert.Version2: map[bool][]byte{}, + } + + sendHfw := lh.enableHostQueryProtection.Load() && + lh.hf.IsModifiedSinceLastMashalling.Load() + var err error updated := 0 lighthouses := lh.GetLighthouses() - for lhVpnAddr := range lighthouses { var v cert.Version hi := lh.ifce.GetHostInfo(lhVpnAddr) @@ -903,8 +942,10 @@ func (lh *LightHouse) SendUpdate() { } else { v = lh.ifce.GetCertState().defaultVersion } + + sendHfwToLh := hi == nil || sendHfw if v == cert.Version1 { - if v1Update == nil { + if _, ok := updateMessageCache[v][sendHfwToLh]; !ok { if !lh.myVpnNetworks[0].Addr().Is4() { lh.l.WithField("lighthouseAddr", lhVpnAddr). Warn("cannot update lighthouse using v1 protocol without an IPv4 address") @@ -929,7 +970,20 @@ func (lh *LightHouse) SendUpdate() { }, } - v1Update, err = msg.Marshal() + if sendHfwToLh { + msg.Details.HandshakeFilteringWhitelist = lh.hf.MarshalToHfw() + if msg.Details.HandshakeFilteringWhitelist != nil && lh.l.Level >= logrus.DebugLevel { + lh.l.WithField("hosts", msg.Details.HandshakeFilteringWhitelist.AllowedHosts). + WithField("groups", msg.Details.HandshakeFilteringWhitelist.AllowedGroups). + WithField("groupcombos", msg.Details.HandshakeFilteringWhitelist.AllowedGroupsCombos). + WithField("cidrs", msg.Details.HandshakeFilteringWhitelist.AllowedCidrs). + WithField("canames", msg.Details.HandshakeFilteringWhitelist.AllowedCANames). + WithField("cashas", msg.Details.HandshakeFilteringWhitelist.AllowedCAShas). + Debug("Sending handshake filtering whitelist to lighthouse") + } + } + + updateMessageCache[v][sendHfwToLh], err = msg.Marshal() if err != nil { lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). Error("Error while marshaling for lighthouse v1 update") @@ -937,11 +991,11 @@ func (lh *LightHouse) SendUpdate() { } } - lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Update, nb, out) + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, updateMessageCache[v][sendHfwToLh], nb, out) updated++ } else if v == cert.Version2 { - if v2Update == nil { + if _, ok := updateMessageCache[v][sendHfwToLh]; !ok { var relays []*Addr for _, r := range lh.GetRelaysForMe() { relays = append(relays, netAddrToProtoAddr(r)) @@ -957,7 +1011,22 @@ func (lh *LightHouse) SendUpdate() { }, } - v2Update, err = msg.Marshal() + if sendHfwToLh { + msg.Details.HandshakeFilteringWhitelist = lh.hf.MarshalToHfw() + if lh.l.Level >= logrus.DebugLevel { + if msg.Details.HandshakeFilteringWhitelist != nil && lh.l.Level >= logrus.DebugLevel { + lh.l.WithField("hosts", msg.Details.HandshakeFilteringWhitelist.AllowedHosts). + WithField("groups", msg.Details.HandshakeFilteringWhitelist.AllowedGroups). + WithField("groupcombos", msg.Details.HandshakeFilteringWhitelist.AllowedGroupsCombos). + WithField("cidrs", msg.Details.HandshakeFilteringWhitelist.AllowedCidrs). + WithField("canames", msg.Details.HandshakeFilteringWhitelist.AllowedCANames). + WithField("cashas", msg.Details.HandshakeFilteringWhitelist.AllowedCAShas). + Debug("Sending handshake filtering whitelist to lighthouse") + } + } + } + + updateMessageCache[v][sendHfwToLh], err = msg.Marshal() if err != nil { lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). Error("Error while marshaling for lighthouse v2 update") @@ -965,7 +1034,7 @@ func (lh *LightHouse) SendUpdate() { } } - lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Update, nb, out) + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, updateMessageCache[v][sendHfwToLh], nb, out) updated++ } else { @@ -1023,12 +1092,15 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { details.OldRelayVpnAddrs = details.OldRelayVpnAddrs[:0] details.OldVpnAddr = 0 details.VpnAddr = nil + details.HandshakeFilteringWhitelist = nil lhh.meta.Details = details return lhh.meta } -func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs []netip.Addr, p []byte, w EncWriter) { +func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, hostInfo *HostInfo, p []byte, w EncWriter) { + fromVpnAddrs := hostInfo.vpnAddrs + n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { @@ -1047,7 +1119,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [ switch n.Type { case NebulaMeta_HostQuery: - lhh.handleHostQuery(n, fromVpnAddrs, rAddr, w) + lhh.handleHostQuery(n, hostInfo, rAddr, w) case NebulaMeta_HostQueryReply: lhh.handleHostQueryReply(n, fromVpnAddrs) @@ -1064,7 +1136,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [ } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, hostInfo *HostInfo, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -1073,6 +1145,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti return } + fromVpnAddrs := hostInfo.vpnAddrs useVersion := cert.Version1 var queryVpnAddr netip.Addr if n.Details.OldVpnAddr != 0 { @@ -1090,6 +1163,21 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti return } + if lhh.lh.enableHostQueryProtection.Load() { + c := hostInfo.ConnectionState.peerCert.Certificate + fp, err := c.Fingerprint() + if err != nil { + lhh.l.WithField("CAName", c.Name()).WithError(err).Warn("could not calculate fingerprint for provided CA") + return + } + + if !lhh.lh.IsHostQueryAllowed(queryVpnAddr, c.Groups(), c.Name(), fromVpnAddrs, c.Issuer(), fp) { + lhh.l.WithField("from", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).Warn("Preventing query due to host query protection") + lhh.lh.metricFilteredHostQueries.Inc(1) + return + } + } + found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnAddr, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostQueryReply @@ -1287,6 +1375,7 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp } relays := n.Details.GetRelays() + hfws := n.Details.GetHandshakeFilteringWhitelist() lhh.lh.Lock() am := lhh.lh.unlockedGetRemoteList(fromVpnAddrs) @@ -1296,8 +1385,21 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) am.unlockedSetRelay(fromVpnAddrs[0], relays) + am.unlockedSetHandshakeFilteringWhitelist(hfws) am.Unlock() + if hfws != nil && lhh.l.Level >= logrus.DebugLevel { + lhh.l.WithField("vpnAddrs", fromVpnAddrs). + WithField("hosts", hfws.AllowedHosts). + WithField("groups", hfws.AllowedGroups). + WithField("groupcombos", hfws.AllowedGroupsCombos). + WithField("cidrs", hfws.AllowedCidrs). + WithField("canames", hfws.AllowedCANames). + WithField("cashas", hfws.AllowedCAShas). + WithField("setempty", hfws.SetEmpty). + Debug("Received host query filter") + } + n = lhh.resetMeta() n.Type = NebulaMeta_HostUpdateNotificationAck diff --git a/lighthouse_test.go b/lighthouse_test.go index 582effa85..86bfbf2f2 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -133,7 +133,9 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { mw := &mockEncWriter{} - hi := []netip.Addr{vpnIp2} + hi := &HostInfo{ + vpnAddrs: []netip.Addr{vpnIp2}, + } b.Run("notfound", func(b *testing.B) { lhh := lh.NewRequestHandler() req := &NebulaMeta{ @@ -326,7 +328,10 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l w := &testEncWriter{ metaFilter: &filter, } - lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w) + hi := &HostInfo{ + vpnAddrs: []netip.Addr{myVpnIp}, + } + lhh.HandleRequest(fromAddr, hi, b, w) return w.lastReply } @@ -357,7 +362,10 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad } w := &testEncWriter{} - lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w) + hi := &HostInfo{ + vpnAddrs: []netip.Addr{vpnIp}, + } + lhh.HandleRequest(fromAddr, hi, b, w) } type testLhReply struct { diff --git a/nebula.pb.go b/nebula.pb.go index 2fd2ff665..70877fd8b 100644 --- a/nebula.pb.go +++ b/nebula.pb.go @@ -96,7 +96,7 @@ func (x NebulaPing_MessageType) String() string { } func (NebulaPing_MessageType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{5, 0} + return fileDescriptor_2d65afa7693df5ef, []int{7, 0} } type NebulaControl_MessageType int32 @@ -124,7 +124,7 @@ func (x NebulaControl_MessageType) String() string { } func (NebulaControl_MessageType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{8, 0} + return fileDescriptor_2d65afa7693df5ef, []int{10, 0} } type NebulaMeta struct { @@ -180,13 +180,14 @@ func (m *NebulaMeta) GetDetails() *NebulaMetaDetails { } type NebulaMetaDetails struct { - OldVpnAddr uint32 `protobuf:"varint,1,opt,name=OldVpnAddr,proto3" json:"OldVpnAddr,omitempty"` // Deprecated: Do not use. - VpnAddr *Addr `protobuf:"bytes,6,opt,name=VpnAddr,proto3" json:"VpnAddr,omitempty"` - OldRelayVpnAddrs []uint32 `protobuf:"varint,5,rep,packed,name=OldRelayVpnAddrs,proto3" json:"OldRelayVpnAddrs,omitempty"` // Deprecated: Do not use. - RelayVpnAddrs []*Addr `protobuf:"bytes,7,rep,name=RelayVpnAddrs,proto3" json:"RelayVpnAddrs,omitempty"` - V4AddrPorts []*V4AddrPort `protobuf:"bytes,2,rep,name=V4AddrPorts,proto3" json:"V4AddrPorts,omitempty"` - V6AddrPorts []*V6AddrPort `protobuf:"bytes,4,rep,name=V6AddrPorts,proto3" json:"V6AddrPorts,omitempty"` - Counter uint32 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` + OldVpnAddr uint32 `protobuf:"varint,1,opt,name=OldVpnAddr,proto3" json:"OldVpnAddr,omitempty"` // Deprecated: Do not use. + VpnAddr *Addr `protobuf:"bytes,6,opt,name=VpnAddr,proto3" json:"VpnAddr,omitempty"` + OldRelayVpnAddrs []uint32 `protobuf:"varint,5,rep,packed,name=OldRelayVpnAddrs,proto3" json:"OldRelayVpnAddrs,omitempty"` // Deprecated: Do not use. + RelayVpnAddrs []*Addr `protobuf:"bytes,7,rep,name=RelayVpnAddrs,proto3" json:"RelayVpnAddrs,omitempty"` + HandshakeFilteringWhitelist *HandshakeFilteringWhitelist `protobuf:"bytes,8,opt,name=HandshakeFilteringWhitelist,proto3" json:"HandshakeFilteringWhitelist,omitempty"` + V4AddrPorts []*V4AddrPort `protobuf:"bytes,2,rep,name=V4AddrPorts,proto3" json:"V4AddrPorts,omitempty"` + V6AddrPorts []*V6AddrPort `protobuf:"bytes,4,rep,name=V6AddrPorts,proto3" json:"V6AddrPorts,omitempty"` + Counter uint32 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` } func (m *NebulaMetaDetails) Reset() { *m = NebulaMetaDetails{} } @@ -252,6 +253,13 @@ func (m *NebulaMetaDetails) GetRelayVpnAddrs() []*Addr { return nil } +func (m *NebulaMetaDetails) GetHandshakeFilteringWhitelist() *HandshakeFilteringWhitelist { + if m != nil { + return m.HandshakeFilteringWhitelist + } + return nil +} + func (m *NebulaMetaDetails) GetV4AddrPorts() []*V4AddrPort { if m != nil { return m.V4AddrPorts @@ -437,6 +445,142 @@ func (m *V6AddrPort) GetPort() uint32 { return 0 } +type HandshakeFilteringWhitelist struct { + AllowedHosts []string `protobuf:"bytes,1,rep,name=AllowedHosts,proto3" json:"AllowedHosts,omitempty"` + AllowedGroups []string `protobuf:"bytes,2,rep,name=AllowedGroups,proto3" json:"AllowedGroups,omitempty"` + AllowedGroupsCombos []*GroupsCombos `protobuf:"bytes,3,rep,name=AllowedGroupsCombos,proto3" json:"AllowedGroupsCombos,omitempty"` + AllowedCidrs []string `protobuf:"bytes,4,rep,name=AllowedCidrs,proto3" json:"AllowedCidrs,omitempty"` + AllowedCANames []string `protobuf:"bytes,5,rep,name=AllowedCANames,proto3" json:"AllowedCANames,omitempty"` + AllowedCAShas []string `protobuf:"bytes,6,rep,name=AllowedCAShas,proto3" json:"AllowedCAShas,omitempty"` + SetEmpty bool `protobuf:"varint,7,opt,name=SetEmpty,proto3" json:"SetEmpty,omitempty"` +} + +func (m *HandshakeFilteringWhitelist) Reset() { *m = HandshakeFilteringWhitelist{} } +func (m *HandshakeFilteringWhitelist) String() string { return proto.CompactTextString(m) } +func (*HandshakeFilteringWhitelist) ProtoMessage() {} +func (*HandshakeFilteringWhitelist) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{5} +} +func (m *HandshakeFilteringWhitelist) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *HandshakeFilteringWhitelist) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_HandshakeFilteringWhitelist.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *HandshakeFilteringWhitelist) XXX_Merge(src proto.Message) { + xxx_messageInfo_HandshakeFilteringWhitelist.Merge(m, src) +} +func (m *HandshakeFilteringWhitelist) XXX_Size() int { + return m.Size() +} +func (m *HandshakeFilteringWhitelist) XXX_DiscardUnknown() { + xxx_messageInfo_HandshakeFilteringWhitelist.DiscardUnknown(m) +} + +var xxx_messageInfo_HandshakeFilteringWhitelist proto.InternalMessageInfo + +func (m *HandshakeFilteringWhitelist) GetAllowedHosts() []string { + if m != nil { + return m.AllowedHosts + } + return nil +} + +func (m *HandshakeFilteringWhitelist) GetAllowedGroups() []string { + if m != nil { + return m.AllowedGroups + } + return nil +} + +func (m *HandshakeFilteringWhitelist) GetAllowedGroupsCombos() []*GroupsCombos { + if m != nil { + return m.AllowedGroupsCombos + } + return nil +} + +func (m *HandshakeFilteringWhitelist) GetAllowedCidrs() []string { + if m != nil { + return m.AllowedCidrs + } + return nil +} + +func (m *HandshakeFilteringWhitelist) GetAllowedCANames() []string { + if m != nil { + return m.AllowedCANames + } + return nil +} + +func (m *HandshakeFilteringWhitelist) GetAllowedCAShas() []string { + if m != nil { + return m.AllowedCAShas + } + return nil +} + +func (m *HandshakeFilteringWhitelist) GetSetEmpty() bool { + if m != nil { + return m.SetEmpty + } + return false +} + +type GroupsCombos struct { + Group []string `protobuf:"bytes,1,rep,name=Group,proto3" json:"Group,omitempty"` +} + +func (m *GroupsCombos) Reset() { *m = GroupsCombos{} } +func (m *GroupsCombos) String() string { return proto.CompactTextString(m) } +func (*GroupsCombos) ProtoMessage() {} +func (*GroupsCombos) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{6} +} +func (m *GroupsCombos) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *GroupsCombos) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_GroupsCombos.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *GroupsCombos) XXX_Merge(src proto.Message) { + xxx_messageInfo_GroupsCombos.Merge(m, src) +} +func (m *GroupsCombos) XXX_Size() int { + return m.Size() +} +func (m *GroupsCombos) XXX_DiscardUnknown() { + xxx_messageInfo_GroupsCombos.DiscardUnknown(m) +} + +var xxx_messageInfo_GroupsCombos proto.InternalMessageInfo + +func (m *GroupsCombos) GetGroup() []string { + if m != nil { + return m.Group + } + return nil +} + type NebulaPing struct { Type NebulaPing_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaPing_MessageType" json:"Type,omitempty"` Time uint64 `protobuf:"varint,2,opt,name=Time,proto3" json:"Time,omitempty"` @@ -446,7 +590,7 @@ func (m *NebulaPing) Reset() { *m = NebulaPing{} } func (m *NebulaPing) String() string { return proto.CompactTextString(m) } func (*NebulaPing) ProtoMessage() {} func (*NebulaPing) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{5} + return fileDescriptor_2d65afa7693df5ef, []int{7} } func (m *NebulaPing) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -498,7 +642,7 @@ func (m *NebulaHandshake) Reset() { *m = NebulaHandshake{} } func (m *NebulaHandshake) String() string { return proto.CompactTextString(m) } func (*NebulaHandshake) ProtoMessage() {} func (*NebulaHandshake) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{6} + return fileDescriptor_2d65afa7693df5ef, []int{8} } func (m *NebulaHandshake) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -554,7 +698,7 @@ func (m *NebulaHandshakeDetails) Reset() { *m = NebulaHandshakeDetails{} func (m *NebulaHandshakeDetails) String() string { return proto.CompactTextString(m) } func (*NebulaHandshakeDetails) ProtoMessage() {} func (*NebulaHandshakeDetails) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{7} + return fileDescriptor_2d65afa7693df5ef, []int{9} } func (m *NebulaHandshakeDetails) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -639,7 +783,7 @@ func (m *NebulaControl) Reset() { *m = NebulaControl{} } func (m *NebulaControl) String() string { return proto.CompactTextString(m) } func (*NebulaControl) ProtoMessage() {} func (*NebulaControl) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{8} + return fileDescriptor_2d65afa7693df5ef, []int{10} } func (m *NebulaControl) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -728,6 +872,8 @@ func init() { proto.RegisterType((*Addr)(nil), "nebula.Addr") proto.RegisterType((*V4AddrPort)(nil), "nebula.V4AddrPort") proto.RegisterType((*V6AddrPort)(nil), "nebula.V6AddrPort") + proto.RegisterType((*HandshakeFilteringWhitelist)(nil), "nebula.HandshakeFilteringWhitelist") + proto.RegisterType((*GroupsCombos)(nil), "nebula.GroupsCombos") proto.RegisterType((*NebulaPing)(nil), "nebula.NebulaPing") proto.RegisterType((*NebulaHandshake)(nil), "nebula.NebulaHandshake") proto.RegisterType((*NebulaHandshakeDetails)(nil), "nebula.NebulaHandshakeDetails") @@ -737,57 +883,67 @@ func init() { func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) } var fileDescriptor_2d65afa7693df5ef = []byte{ - // 785 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xcd, 0x6e, 0xeb, 0x44, - 0x14, 0x8e, 0x1d, 0x27, 0x4e, 0x4f, 0x7e, 0xae, 0x39, 0x15, 0xc1, 0x41, 0x22, 0x0a, 0x5e, 0x54, - 0x57, 0x2c, 0x72, 0x51, 0x5a, 0xae, 0x58, 0x72, 0x1b, 0x84, 0xd2, 0xaa, 0x3f, 0x61, 0x54, 0x8a, - 0xc4, 0x06, 0xb9, 0xf6, 0xd0, 0x58, 0x71, 0x3c, 0xa9, 0x3d, 0x41, 0xcd, 0x5b, 0xf0, 0x30, 0x3c, - 0x04, 0xec, 0xba, 0x42, 0x2c, 0x51, 0xbb, 0x64, 0xc9, 0x0b, 0xa0, 0x19, 0xff, 0x27, 0x86, 0xbb, - 0x9b, 0x73, 0xbe, 0xef, 0x3b, 0x73, 0xe6, 0xf3, 0x9c, 0x31, 0x74, 0x02, 0x7a, 0xb7, 0xf1, 0xed, - 0xf1, 0x3a, 0x64, 0x9c, 0x61, 0x33, 0x8e, 0xac, 0xbf, 0x55, 0x80, 0x2b, 0xb9, 0xbc, 0xa4, 0xdc, - 0xc6, 0x09, 0x68, 0x37, 0xdb, 0x35, 0x35, 0x95, 0x91, 0xf2, 0xba, 0x37, 0x19, 0x8e, 0x13, 0x4d, - 0xce, 0x18, 0x5f, 0xd2, 0x28, 0xb2, 0xef, 0xa9, 0x60, 0x11, 0xc9, 0xc5, 0x63, 0xd0, 0xbf, 0xa6, - 0xdc, 0xf6, 0xfc, 0xc8, 0x54, 0x47, 0xca, 0xeb, 0xf6, 0x64, 0xb0, 0x2f, 0x4b, 0x08, 0x24, 0x65, - 0x5a, 0xff, 0x28, 0xd0, 0x2e, 0x94, 0xc2, 0x16, 0x68, 0x57, 0x2c, 0xa0, 0x46, 0x0d, 0xbb, 0x70, - 0x30, 0x63, 0x11, 0xff, 0x76, 0x43, 0xc3, 0xad, 0xa1, 0x20, 0x42, 0x2f, 0x0b, 0x09, 0x5d, 0xfb, - 0x5b, 0x43, 0xc5, 0x8f, 0xa1, 0x2f, 0x72, 0xdf, 0xad, 0x5d, 0x9b, 0xd3, 0x2b, 0xc6, 0xbd, 0x9f, - 0x3c, 0xc7, 0xe6, 0x1e, 0x0b, 0x8c, 0x3a, 0x0e, 0xe0, 0x43, 0x81, 0x5d, 0xb2, 0x9f, 0xa9, 0x5b, - 0x82, 0xb4, 0x14, 0x9a, 0x6f, 0x02, 0x67, 0x51, 0x82, 0x1a, 0xd8, 0x03, 0x10, 0xd0, 0xf7, 0x0b, - 0x66, 0xaf, 0x3c, 0xa3, 0x89, 0x87, 0xf0, 0x2a, 0x8f, 0xe3, 0x6d, 0x75, 0xd1, 0xd9, 0xdc, 0xe6, - 0x8b, 0xe9, 0x82, 0x3a, 0x4b, 0xa3, 0x25, 0x3a, 0xcb, 0xc2, 0x98, 0x72, 0x80, 0x9f, 0xc0, 0xa0, - 0xba, 0xb3, 0x77, 0xce, 0xd2, 0x00, 0xeb, 0x77, 0x15, 0x3e, 0xd8, 0x33, 0x05, 0x2d, 0x80, 0x6b, - 0xdf, 0xbd, 0x5d, 0x07, 0xef, 0x5c, 0x37, 0x94, 0xd6, 0x77, 0x4f, 0x55, 0x53, 0x21, 0x85, 0x2c, - 0x1e, 0x81, 0x9e, 0x12, 0x9a, 0xd2, 0xe4, 0x4e, 0x6a, 0xb2, 0xc8, 0x91, 0x14, 0xc4, 0x31, 0x18, - 0xd7, 0xbe, 0x4b, 0xa8, 0x6f, 0x6f, 0x93, 0x54, 0x64, 0x36, 0x46, 0xf5, 0xa4, 0xe2, 0x1e, 0x86, - 0x13, 0xe8, 0x96, 0xc9, 0xfa, 0xa8, 0xbe, 0x57, 0xbd, 0x4c, 0xc1, 0x13, 0x68, 0xdf, 0x9e, 0x88, - 0xe5, 0x9c, 0x85, 0x5c, 0x7c, 0x74, 0xa1, 0xc0, 0x54, 0x91, 0x43, 0xa4, 0x48, 0x93, 0xaa, 0xb7, - 0xb9, 0x4a, 0xdb, 0x51, 0xbd, 0x2d, 0xa8, 0x72, 0x1a, 0x9a, 0xa0, 0x3b, 0x6c, 0x13, 0x70, 0x1a, - 0x9a, 0x75, 0x61, 0x0c, 0x49, 0x43, 0xeb, 0x08, 0x34, 0x79, 0xe2, 0x1e, 0xa8, 0x33, 0x4f, 0xba, - 0xa6, 0x11, 0x75, 0xe6, 0x89, 0xf8, 0x82, 0xc9, 0x9b, 0xa8, 0x11, 0xf5, 0x82, 0x59, 0x27, 0x00, - 0x79, 0x1b, 0x88, 0xb1, 0x2a, 0x76, 0x99, 0xc4, 0x15, 0x10, 0x34, 0x81, 0x49, 0x4d, 0x97, 0xc8, - 0xb5, 0xf5, 0x15, 0x40, 0xde, 0xc6, 0xfb, 0xf6, 0xc8, 0x2a, 0xd4, 0x0b, 0x15, 0x1e, 0xd3, 0xc1, - 0x9a, 0x7b, 0xc1, 0xfd, 0xff, 0x0f, 0x96, 0x60, 0x54, 0x0c, 0x16, 0x82, 0x76, 0xe3, 0xad, 0x68, - 0xb2, 0x8f, 0x5c, 0x5b, 0xd6, 0xde, 0xd8, 0x08, 0xb1, 0x51, 0xc3, 0x03, 0x68, 0xc4, 0x97, 0x50, - 0xb1, 0x7e, 0x84, 0x57, 0x71, 0xdd, 0x99, 0x1d, 0xb8, 0xd1, 0xc2, 0x5e, 0x52, 0xfc, 0x32, 0x9f, - 0x51, 0x45, 0x5e, 0x9f, 0x9d, 0x0e, 0x32, 0xe6, 0xee, 0xa0, 0x8a, 0x26, 0x66, 0x2b, 0xdb, 0x91, - 0x4d, 0x74, 0x88, 0x5c, 0x5b, 0x7f, 0x28, 0xd0, 0xaf, 0xd6, 0x09, 0xfa, 0x94, 0x86, 0x5c, 0xee, - 0xd2, 0x21, 0x72, 0x8d, 0x47, 0xd0, 0x3b, 0x0b, 0x3c, 0xee, 0xd9, 0x9c, 0x85, 0x67, 0x81, 0x4b, - 0x1f, 0x13, 0xa7, 0x77, 0xb2, 0x82, 0x47, 0x68, 0xb4, 0x66, 0x81, 0x4b, 0x13, 0x5e, 0xec, 0xe7, - 0x4e, 0x16, 0xfb, 0xd0, 0x9c, 0x32, 0xb6, 0xf4, 0xa8, 0xa9, 0x49, 0x67, 0x92, 0x28, 0xf3, 0xab, - 0x91, 0xfb, 0x85, 0x23, 0x68, 0x8b, 0x1e, 0x6e, 0x69, 0x18, 0x79, 0x2c, 0x30, 0x5b, 0xb2, 0x60, - 0x31, 0x75, 0xae, 0xb5, 0x9a, 0x86, 0x7e, 0xae, 0xb5, 0x74, 0xa3, 0x65, 0xfd, 0x5a, 0x87, 0x6e, - 0x7c, 0xb0, 0x29, 0x0b, 0x78, 0xc8, 0x7c, 0xfc, 0xa2, 0xf4, 0xdd, 0x3e, 0x2d, 0xbb, 0x96, 0x90, - 0x2a, 0x3e, 0xdd, 0xe7, 0x70, 0x98, 0x1d, 0x4e, 0x0e, 0x4f, 0xf1, 0xdc, 0x55, 0x90, 0x50, 0x64, - 0xc7, 0x2c, 0x28, 0x62, 0x07, 0xaa, 0x20, 0xfc, 0x0c, 0x7a, 0xe9, 0x38, 0xdf, 0x30, 0x79, 0xa9, - 0xb5, 0xec, 0xe9, 0xd8, 0x41, 0x8a, 0xcf, 0xc2, 0x37, 0x21, 0x5b, 0x49, 0x76, 0x23, 0x63, 0xef, - 0x61, 0x38, 0x86, 0x76, 0xb1, 0x70, 0xd5, 0x93, 0x53, 0x24, 0x64, 0xcf, 0x48, 0x56, 0x5c, 0xaf, - 0x50, 0x94, 0x29, 0xd6, 0xec, 0xbf, 0xfe, 0x00, 0x7d, 0xc0, 0x69, 0x48, 0x6d, 0x4e, 0x25, 0x9f, - 0xd0, 0x87, 0x0d, 0x8d, 0xb8, 0xa1, 0xe0, 0x47, 0x70, 0x58, 0xca, 0x0b, 0x4b, 0x22, 0x6a, 0xa8, - 0xa7, 0xc7, 0xbf, 0x3d, 0x0f, 0x95, 0xa7, 0xe7, 0xa1, 0xf2, 0xd7, 0xf3, 0x50, 0xf9, 0xe5, 0x65, - 0x58, 0x7b, 0x7a, 0x19, 0xd6, 0xfe, 0x7c, 0x19, 0xd6, 0x7e, 0x18, 0xdc, 0x7b, 0x7c, 0xb1, 0xb9, - 0x1b, 0x3b, 0x6c, 0xf5, 0x26, 0xf2, 0x6d, 0x67, 0xb9, 0x78, 0x78, 0x13, 0xb7, 0x74, 0xd7, 0x94, - 0x3f, 0xc2, 0xe3, 0x7f, 0x03, 0x00, 0x00, 0xff, 0xff, 0xea, 0x6f, 0xbc, 0x50, 0x18, 0x07, 0x00, - 0x00, + // 949 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x56, 0xcd, 0x72, 0x1b, 0x45, + 0x10, 0xd6, 0x4a, 0xab, 0xbf, 0xd6, 0x4f, 0x96, 0x76, 0x30, 0xeb, 0x50, 0xa8, 0xc4, 0x92, 0x72, + 0xb9, 0x38, 0x28, 0x94, 0x6d, 0x52, 0x1c, 0x51, 0x04, 0x46, 0x49, 0xc5, 0x8e, 0x99, 0x18, 0xa7, + 0x8a, 0x0b, 0xb5, 0xd6, 0x0e, 0xde, 0x29, 0xad, 0x76, 0x94, 0xdd, 0x11, 0x44, 0x6f, 0xc1, 0xa3, + 0x70, 0xe0, 0xca, 0x9d, 0x63, 0x4e, 0x14, 0x47, 0xca, 0x3e, 0x72, 0xe4, 0x05, 0xa8, 0x99, 0xfd, + 0x97, 0x16, 0xe7, 0x36, 0xdd, 0xfd, 0x75, 0xef, 0x37, 0x5f, 0xb7, 0x7a, 0x04, 0x5d, 0x9f, 0x5e, + 0xad, 0x3c, 0x7b, 0xb4, 0x0c, 0xb8, 0xe0, 0xd8, 0x88, 0x2c, 0xeb, 0x9f, 0x2a, 0xc0, 0x99, 0x3a, + 0x9e, 0x52, 0x61, 0xe3, 0x21, 0xe8, 0x17, 0xeb, 0x25, 0x35, 0xb5, 0xa1, 0x76, 0xd0, 0x3f, 0x1c, + 0x8c, 0xe2, 0x9c, 0x0c, 0x31, 0x3a, 0xa5, 0x61, 0x68, 0x5f, 0x53, 0x89, 0x22, 0x0a, 0x8b, 0x47, + 0xd0, 0xfc, 0x8a, 0x0a, 0x9b, 0x79, 0xa1, 0x59, 0x1d, 0x6a, 0x07, 0x9d, 0xc3, 0xbd, 0xed, 0xb4, + 0x18, 0x40, 0x12, 0xa4, 0xf5, 0xaf, 0x06, 0x9d, 0x5c, 0x29, 0x6c, 0x81, 0x7e, 0xc6, 0x7d, 0x6a, + 0x54, 0xb0, 0x07, 0xed, 0x29, 0x0f, 0xc5, 0xb7, 0x2b, 0x1a, 0xac, 0x0d, 0x0d, 0x11, 0xfa, 0xa9, + 0x49, 0xe8, 0xd2, 0x5b, 0x1b, 0x55, 0x7c, 0x00, 0xbb, 0xd2, 0xf7, 0xdd, 0xd2, 0xb1, 0x05, 0x3d, + 0xe3, 0x82, 0xfd, 0xc8, 0x66, 0xb6, 0x60, 0xdc, 0x37, 0x6a, 0xb8, 0x07, 0xef, 0xcb, 0xd8, 0x29, + 0xff, 0x89, 0x3a, 0x85, 0x90, 0x9e, 0x84, 0xce, 0x57, 0xfe, 0xcc, 0x2d, 0x84, 0xea, 0xd8, 0x07, + 0x90, 0xa1, 0x57, 0x2e, 0xb7, 0x17, 0xcc, 0x68, 0xe0, 0x0e, 0xdc, 0xcb, 0xec, 0xe8, 0xb3, 0x4d, + 0xc9, 0xec, 0xdc, 0x16, 0xee, 0xc4, 0xa5, 0xb3, 0xb9, 0xd1, 0x92, 0xcc, 0x52, 0x33, 0x82, 0xb4, + 0xf1, 0x23, 0xd8, 0x2b, 0x67, 0x36, 0x9e, 0xcd, 0x0d, 0xb0, 0x7e, 0xad, 0xc1, 0x7b, 0x5b, 0xa2, + 0xa0, 0x05, 0xf0, 0xc2, 0x73, 0x2e, 0x97, 0xfe, 0xd8, 0x71, 0x02, 0x25, 0x7d, 0xef, 0x49, 0xd5, + 0xd4, 0x48, 0xce, 0x8b, 0xfb, 0xd0, 0x4c, 0x00, 0x0d, 0x25, 0x72, 0x37, 0x11, 0x59, 0xfa, 0x48, + 0x12, 0xc4, 0x11, 0x18, 0x2f, 0x3c, 0x87, 0x50, 0xcf, 0x5e, 0xc7, 0xae, 0xd0, 0xac, 0x0f, 0x6b, + 0x71, 0xc5, 0xad, 0x18, 0x1e, 0x42, 0xaf, 0x08, 0x6e, 0x0e, 0x6b, 0x5b, 0xd5, 0x8b, 0x10, 0xa4, + 0xf0, 0xe1, 0xd4, 0xf6, 0x9d, 0xd0, 0xb5, 0xe7, 0xf4, 0x84, 0x79, 0x82, 0x06, 0xcc, 0xbf, 0x7e, + 0xe5, 0x32, 0x41, 0x3d, 0x16, 0x0a, 0xb3, 0xa5, 0xf8, 0x7d, 0x92, 0x54, 0xb8, 0x03, 0x4a, 0xee, + 0xaa, 0x83, 0xc7, 0xd0, 0xb9, 0x3c, 0x96, 0x5f, 0x3c, 0xe7, 0x81, 0x90, 0xb3, 0x25, 0x89, 0x61, + 0x52, 0x36, 0x0b, 0x91, 0x3c, 0x4c, 0x65, 0x3d, 0xce, 0xb2, 0xf4, 0x8d, 0xac, 0xc7, 0xb9, 0xac, + 0x0c, 0x86, 0x26, 0x34, 0x67, 0x7c, 0xe5, 0x0b, 0x1a, 0x98, 0x35, 0xa9, 0x3f, 0x49, 0x4c, 0x6b, + 0x1f, 0x74, 0x25, 0x6c, 0x1f, 0xaa, 0x53, 0xa6, 0x9a, 0xa3, 0x93, 0xea, 0x94, 0x49, 0xfb, 0x39, + 0x57, 0x03, 0xaf, 0x93, 0xea, 0x73, 0x6e, 0x1d, 0x03, 0x64, 0x34, 0x10, 0xa3, 0xac, 0xa8, 0x99, + 0x24, 0xaa, 0x80, 0xa0, 0xcb, 0x98, 0xca, 0xe9, 0x11, 0x75, 0xb6, 0xbe, 0x04, 0xc8, 0x68, 0xbc, + 0xeb, 0x1b, 0x69, 0x85, 0x5a, 0xae, 0xc2, 0xef, 0xd5, 0x3b, 0xbb, 0x81, 0x16, 0x74, 0xc7, 0x9e, + 0xc7, 0x7f, 0xa6, 0x8e, 0x1c, 0xcc, 0xd0, 0xd4, 0x86, 0xb5, 0x83, 0x36, 0x29, 0xf8, 0xf0, 0x21, + 0xf4, 0x62, 0xfb, 0x9b, 0x80, 0xaf, 0x96, 0x91, 0xd6, 0x6d, 0x52, 0x74, 0xe2, 0x09, 0xec, 0x14, + 0x1c, 0x13, 0xbe, 0xb8, 0xe2, 0xa1, 0x59, 0x53, 0x0a, 0xdf, 0x4f, 0x14, 0xce, 0xc7, 0x48, 0x59, + 0x42, 0x8e, 0xd1, 0x84, 0xc9, 0x89, 0xd3, 0x0b, 0x8c, 0x94, 0x0f, 0xf7, 0xa1, 0x9f, 0xd8, 0xe3, + 0x33, 0x7b, 0x41, 0xa3, 0x21, 0x6e, 0x93, 0x0d, 0x6f, 0x8e, 0xf9, 0x64, 0xfc, 0xd2, 0xb5, 0x43, + 0xb3, 0x51, 0x60, 0x1e, 0x39, 0xf1, 0x01, 0xb4, 0x5e, 0x52, 0xf1, 0xf5, 0x62, 0x29, 0xd6, 0x66, + 0x73, 0xa8, 0x1d, 0xb4, 0x48, 0x6a, 0x5b, 0x0f, 0xa1, 0x5b, 0x60, 0x77, 0x1f, 0xea, 0xca, 0x8e, + 0x85, 0x8a, 0x0c, 0xeb, 0x4d, 0xb2, 0x25, 0xcf, 0x99, 0x7f, 0x7d, 0xf7, 0x96, 0x94, 0x88, 0x92, + 0x2d, 0x89, 0xa0, 0x5f, 0xb0, 0x05, 0x8d, 0xbb, 0xa9, 0xce, 0x96, 0xb5, 0xb5, 0x03, 0x65, 0xb2, + 0x51, 0xc1, 0x36, 0xd4, 0xa3, 0x8d, 0xa2, 0x59, 0x3f, 0xc0, 0xbd, 0xa8, 0x6e, 0xda, 0x64, 0xfc, + 0x22, 0x5b, 0xb8, 0x9a, 0xfa, 0xad, 0x6d, 0x30, 0x48, 0x91, 0x9b, 0x5b, 0x57, 0x92, 0x98, 0x2e, + 0xec, 0x99, 0x22, 0xd1, 0x25, 0xea, 0x6c, 0xfd, 0xa9, 0xc1, 0x6e, 0x79, 0x9e, 0x84, 0x4f, 0x68, + 0x20, 0xd4, 0x57, 0xba, 0x44, 0x9d, 0x65, 0x67, 0x9e, 0xfa, 0x4c, 0x30, 0x5b, 0xf0, 0xe0, 0xa9, + 0xef, 0xd0, 0x37, 0xf1, 0x3c, 0x6f, 0x78, 0x25, 0x8e, 0xd0, 0x70, 0xc9, 0x7d, 0x87, 0xc6, 0xb8, + 0x68, 0x6a, 0x37, 0xbc, 0xb8, 0x0b, 0x8d, 0x09, 0xe7, 0x73, 0x46, 0x4d, 0x5d, 0x29, 0x13, 0x5b, + 0xa9, 0x5e, 0xf5, 0x4c, 0x2f, 0x1c, 0x42, 0x47, 0x72, 0xb8, 0xa4, 0x41, 0xc8, 0xb8, 0xaf, 0x16, + 0x4d, 0x8f, 0xe4, 0x5d, 0xcf, 0xf4, 0x56, 0xc3, 0x68, 0x3e, 0xd3, 0x5b, 0x4d, 0xa3, 0x65, 0xfd, + 0x56, 0x83, 0x5e, 0x74, 0xb1, 0x09, 0xf7, 0x45, 0xc0, 0x3d, 0xfc, 0xbc, 0xd0, 0xb7, 0x8f, 0x8b, + 0xaa, 0xc5, 0xa0, 0x92, 0xd6, 0x7d, 0x06, 0x3b, 0xe9, 0xe5, 0xd4, 0x26, 0xcc, 0xdf, 0xbb, 0x2c, + 0x24, 0x33, 0xd2, 0x6b, 0xe6, 0x32, 0x22, 0x05, 0xca, 0x42, 0xf8, 0x29, 0xf4, 0x93, 0xdd, 0x7c, + 0xc1, 0xd5, 0xea, 0xd0, 0xd3, 0x77, 0x60, 0x23, 0x92, 0xdf, 0xf1, 0x27, 0x01, 0x5f, 0x28, 0x74, + 0x3d, 0x45, 0x6f, 0xc5, 0x70, 0x04, 0x9d, 0x7c, 0xe1, 0xb2, 0xf7, 0x23, 0x0f, 0x48, 0xdf, 0x84, + 0xb4, 0x78, 0xb3, 0x24, 0xa3, 0x08, 0xb1, 0xa6, 0xff, 0xf7, 0x9c, 0xef, 0x02, 0x4e, 0x02, 0x6a, + 0x0b, 0xaa, 0xf0, 0x84, 0xbe, 0x5e, 0xd1, 0x50, 0x18, 0x1a, 0x7e, 0x00, 0x3b, 0x05, 0xbf, 0x94, + 0x24, 0xa4, 0x46, 0xf5, 0xc9, 0xd1, 0x1f, 0x37, 0x03, 0xed, 0xed, 0xcd, 0x40, 0xfb, 0xfb, 0x66, + 0xa0, 0xfd, 0x72, 0x3b, 0xa8, 0xbc, 0xbd, 0x1d, 0x54, 0xfe, 0xba, 0x1d, 0x54, 0xbe, 0xdf, 0xbb, + 0x66, 0xc2, 0x5d, 0x5d, 0x8d, 0x66, 0x7c, 0xf1, 0x28, 0xf4, 0xec, 0xd9, 0xdc, 0x7d, 0xfd, 0x28, + 0xa2, 0x74, 0xd5, 0x50, 0xff, 0x6a, 0x8e, 0xfe, 0x0b, 0x00, 0x00, 0xff, 0xff, 0x6f, 0xb9, 0x8e, + 0x82, 0xe5, 0x08, 0x00, 0x00, } func (m *NebulaMeta) Marshal() (dAtA []byte, err error) { @@ -850,6 +1006,18 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l + if m.HandshakeFilteringWhitelist != nil { + { + size, err := m.HandshakeFilteringWhitelist.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x42 + } if len(m.RelayVpnAddrs) > 0 { for iNdEx := len(m.RelayVpnAddrs) - 1; iNdEx >= 0; iNdEx-- { { @@ -877,20 +1045,20 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { dAtA[i] = 0x32 } if len(m.OldRelayVpnAddrs) > 0 { - dAtA4 := make([]byte, len(m.OldRelayVpnAddrs)*10) - var j3 int + dAtA5 := make([]byte, len(m.OldRelayVpnAddrs)*10) + var j4 int for _, num := range m.OldRelayVpnAddrs { for num >= 1<<7 { - dAtA4[j3] = uint8(uint64(num)&0x7f | 0x80) + dAtA5[j4] = uint8(uint64(num)&0x7f | 0x80) num >>= 7 - j3++ + j4++ } - dAtA4[j3] = uint8(num) - j3++ + dAtA5[j4] = uint8(num) + j4++ } - i -= j3 - copy(dAtA[i:], dAtA4[:j3]) - i = encodeVarintNebula(dAtA, i, uint64(j3)) + i -= j4 + copy(dAtA[i:], dAtA5[:j4]) + i = encodeVarintNebula(dAtA, i, uint64(j4)) i-- dAtA[i] = 0x2a } @@ -1039,6 +1207,130 @@ func (m *V6AddrPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { return len(dAtA) - i, nil } +func (m *HandshakeFilteringWhitelist) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *HandshakeFilteringWhitelist) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *HandshakeFilteringWhitelist) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.SetEmpty { + i-- + if m.SetEmpty { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i-- + dAtA[i] = 0x38 + } + if len(m.AllowedCAShas) > 0 { + for iNdEx := len(m.AllowedCAShas) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.AllowedCAShas[iNdEx]) + copy(dAtA[i:], m.AllowedCAShas[iNdEx]) + i = encodeVarintNebula(dAtA, i, uint64(len(m.AllowedCAShas[iNdEx]))) + i-- + dAtA[i] = 0x32 + } + } + if len(m.AllowedCANames) > 0 { + for iNdEx := len(m.AllowedCANames) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.AllowedCANames[iNdEx]) + copy(dAtA[i:], m.AllowedCANames[iNdEx]) + i = encodeVarintNebula(dAtA, i, uint64(len(m.AllowedCANames[iNdEx]))) + i-- + dAtA[i] = 0x2a + } + } + if len(m.AllowedCidrs) > 0 { + for iNdEx := len(m.AllowedCidrs) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.AllowedCidrs[iNdEx]) + copy(dAtA[i:], m.AllowedCidrs[iNdEx]) + i = encodeVarintNebula(dAtA, i, uint64(len(m.AllowedCidrs[iNdEx]))) + i-- + dAtA[i] = 0x22 + } + } + if len(m.AllowedGroupsCombos) > 0 { + for iNdEx := len(m.AllowedGroupsCombos) - 1; iNdEx >= 0; iNdEx-- { + { + size, err := m.AllowedGroupsCombos[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x1a + } + } + if len(m.AllowedGroups) > 0 { + for iNdEx := len(m.AllowedGroups) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.AllowedGroups[iNdEx]) + copy(dAtA[i:], m.AllowedGroups[iNdEx]) + i = encodeVarintNebula(dAtA, i, uint64(len(m.AllowedGroups[iNdEx]))) + i-- + dAtA[i] = 0x12 + } + } + if len(m.AllowedHosts) > 0 { + for iNdEx := len(m.AllowedHosts) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.AllowedHosts[iNdEx]) + copy(dAtA[i:], m.AllowedHosts[iNdEx]) + i = encodeVarintNebula(dAtA, i, uint64(len(m.AllowedHosts[iNdEx]))) + i-- + dAtA[i] = 0xa + } + } + return len(dAtA) - i, nil +} + +func (m *GroupsCombos) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *GroupsCombos) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *GroupsCombos) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.Group) > 0 { + for iNdEx := len(m.Group) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.Group[iNdEx]) + copy(dAtA[i:], m.Group[iNdEx]) + i = encodeVarintNebula(dAtA, i, uint64(len(m.Group[iNdEx]))) + i-- + dAtA[i] = 0xa + } + } + return len(dAtA) - i, nil +} + func (m *NebulaPing) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) @@ -1309,6 +1601,10 @@ func (m *NebulaMetaDetails) Size() (n int) { n += 1 + l + sovNebula(uint64(l)) } } + if m.HandshakeFilteringWhitelist != nil { + l = m.HandshakeFilteringWhitelist.Size() + n += 1 + l + sovNebula(uint64(l)) + } return n } @@ -1360,6 +1656,69 @@ func (m *V6AddrPort) Size() (n int) { return n } +func (m *HandshakeFilteringWhitelist) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if len(m.AllowedHosts) > 0 { + for _, s := range m.AllowedHosts { + l = len(s) + n += 1 + l + sovNebula(uint64(l)) + } + } + if len(m.AllowedGroups) > 0 { + for _, s := range m.AllowedGroups { + l = len(s) + n += 1 + l + sovNebula(uint64(l)) + } + } + if len(m.AllowedGroupsCombos) > 0 { + for _, e := range m.AllowedGroupsCombos { + l = e.Size() + n += 1 + l + sovNebula(uint64(l)) + } + } + if len(m.AllowedCidrs) > 0 { + for _, s := range m.AllowedCidrs { + l = len(s) + n += 1 + l + sovNebula(uint64(l)) + } + } + if len(m.AllowedCANames) > 0 { + for _, s := range m.AllowedCANames { + l = len(s) + n += 1 + l + sovNebula(uint64(l)) + } + } + if len(m.AllowedCAShas) > 0 { + for _, s := range m.AllowedCAShas { + l = len(s) + n += 1 + l + sovNebula(uint64(l)) + } + } + if m.SetEmpty { + n += 2 + } + return n +} + +func (m *GroupsCombos) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if len(m.Group) > 0 { + for _, s := range m.Group { + l = len(s) + n += 1 + l + sovNebula(uint64(l)) + } + } + return n +} + func (m *NebulaPing) Size() (n int) { if m == nil { return 0 @@ -1844,6 +2203,42 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { return err } iNdEx = postIndex + case 8: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field HandshakeFilteringWhitelist", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.HandshakeFilteringWhitelist == nil { + m.HandshakeFilteringWhitelist = &HandshakeFilteringWhitelist{} + } + if err := m.HandshakeFilteringWhitelist.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) @@ -2148,6 +2543,352 @@ func (m *V6AddrPort) Unmarshal(dAtA []byte) error { } return nil } +func (m *HandshakeFilteringWhitelist) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: HandshakeFilteringWhitelist: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: HandshakeFilteringWhitelist: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field AllowedHosts", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.AllowedHosts = append(m.AllowedHosts, string(dAtA[iNdEx:postIndex])) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field AllowedGroups", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.AllowedGroups = append(m.AllowedGroups, string(dAtA[iNdEx:postIndex])) + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field AllowedGroupsCombos", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.AllowedGroupsCombos = append(m.AllowedGroupsCombos, &GroupsCombos{}) + if err := m.AllowedGroupsCombos[len(m.AllowedGroupsCombos)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field AllowedCidrs", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.AllowedCidrs = append(m.AllowedCidrs, string(dAtA[iNdEx:postIndex])) + iNdEx = postIndex + case 5: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field AllowedCANames", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.AllowedCANames = append(m.AllowedCANames, string(dAtA[iNdEx:postIndex])) + iNdEx = postIndex + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field AllowedCAShas", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.AllowedCAShas = append(m.AllowedCAShas, string(dAtA[iNdEx:postIndex])) + iNdEx = postIndex + case 7: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field SetEmpty", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + m.SetEmpty = bool(v != 0) + default: + iNdEx = preIndex + skippy, err := skipNebula(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthNebula + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *GroupsCombos) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: GroupsCombos: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: GroupsCombos: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Group", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Group = append(m.Group, string(dAtA[iNdEx:postIndex])) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipNebula(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthNebula + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} func (m *NebulaPing) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 diff --git a/nebula.proto b/nebula.proto index ea1023348..d68d7050d 100644 --- a/nebula.proto +++ b/nebula.proto @@ -28,6 +28,7 @@ message NebulaMetaDetails { repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true]; repeated Addr RelayVpnAddrs = 7; + HandshakeFilteringWhitelist HandshakeFilteringWhitelist = 8; repeated V4AddrPort V4AddrPorts = 2; repeated V6AddrPort V6AddrPorts = 4; @@ -50,6 +51,20 @@ message V6AddrPort { uint32 Port = 3; } +message HandshakeFilteringWhitelist { + repeated string AllowedHosts = 1; + repeated string AllowedGroups = 2; + repeated GroupsCombos AllowedGroupsCombos = 3; + repeated string AllowedCidrs = 4; + repeated string AllowedCANames = 5; + repeated string AllowedCAShas = 6; + bool SetEmpty = 7; +} + +message GroupsCombos { + repeated string Group = 1; +} + message NebulaPing { enum MessageType { Ping = 0; diff --git a/outside.go b/outside.go index 1e9cde16b..71ab59d41 100644 --- a/outside.go +++ b/outside.go @@ -139,7 +139,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] return } - lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f) + lhf.HandleRequest(ip, hostinfo, d, f) // Fallthrough to the bottom to record incoming traffic diff --git a/remote_list.go b/remote_list.go index 6baed29b2..c8baed9ad 100644 --- a/remote_list.go +++ b/remote_list.go @@ -196,6 +196,9 @@ type RemoteList struct { // A set of relay addresses. VpnIp addresses that the remote identified as relays. relays []netip.Addr + // Handshake filter. Used to filter host queries if a node enables host query protection. + hf *HandshakeFilter + // These are maps to store v4 and v6 addresses per lighthouse // Map key is the vpnIp of the person that told us about this the cached entries underneath. // For learned addresses, this is the vpnIp that sent the packet @@ -220,6 +223,7 @@ func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *Remo relays: make([]netip.Addr, 0), cache: make(map[netip.Addr]*cache), shouldAdd: shouldAdd, + hf: NewHandshakeFilter(), } copy(r.vpnAddrs, vpnAddrs) return r @@ -436,6 +440,17 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp netip.Addr, to []netip.Addr) { c.relay = append(c.relay, to[:minInt(len(to), MaxRemotes)]...) } +func (r *RemoteList) unlockedSetHandshakeFilteringWhitelist(hfwl *HandshakeFilteringWhitelist) { + if hfwl == nil { + return + } + + r.hf = NewHandshakeFilter() + if !hfwl.GetSetEmpty() { + r.hf.UnmarshalFromHfw(hfwl) + } +} + // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *V4AddrPort) { From 66222be258214bbd8d0b4dedac421c9b6396bb8e Mon Sep 17 00:00:00 2001 From: Daniel Jampen Date: Wed, 19 Mar 2025 23:07:47 +0100 Subject: [PATCH 4/8] always create handshake filter instance --- lighthouse.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lighthouse.go b/lighthouse.go index 882467c0d..4087b43ff 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -106,6 +106,10 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, nebulaPort = uint32(uPort.Port()) } + if hf == nil { + hf = NewHandshakeFilter() + } + h := LightHouse{ ctx: ctx, amLighthouse: amLighthouse, From cc96b90058cb0ffb3f62b6ca29d3a4191c0f638b Mon Sep 17 00:00:00 2001 From: Daniel Jampen Date: Sun, 30 Mar 2025 12:57:12 +0200 Subject: [PATCH 5/8] implement HostQueryWhitelist message - send HostQueryWhitelist when NebulaMeta_HostUpdateNotificationAck is handled - support multiple HostQueryWhitelist messages based on node mtu - handle lost udp packet using NebulaMeta_HostQueryWhitelistAck message - improved testing --- firewall.go | 91 +++++++++++---- firewall_test.go | 213 +++++++++++++++++++++++----------- handshake_ix.go | 2 +- handshake_manager.go | 1 + hostmap.go | 3 + lighthouse.go | 258 ++++++++++++++++++++++++++++++++---------- nebula.pb.go | 226 ++++++++++++++++++++++++------------ nebula.proto | 8 +- overlay/route.go | 2 +- overlay/route_test.go | 26 ++--- overlay/tun.go | 2 +- remote_list.go | 5 +- 12 files changed, 594 insertions(+), 243 deletions(-) diff --git a/firewall.go b/firewall.go index d1b285614..67f772b5e 100644 --- a/firewall.go +++ b/firewall.go @@ -1156,57 +1156,98 @@ func (hfws *HandshakeFilter) IsHandshakeAllowed(groups []string, host string, vp return false } -func (hfws *HandshakeFilter) MarshalToHfw() *HandshakeFilteringWhitelist { - hfw := &HandshakeFilteringWhitelist{ - AllowedHosts: make([]string, len(hfws.AllowedHosts)), - AllowedGroups: make([]string, len(hfws.AllowedGroups)), - AllowedGroupsCombos: make([]*GroupsCombos, len(hfws.AllowedGroupsCombos)), - AllowedCidrs: make([]string, len(hfws.AllowedCidrs)), - AllowedCANames: make([]string, len(hfws.AllowedCANames)), - AllowedCAShas: make([]string, len(hfws.AllowedCAShas)), - SetEmpty: hfws.IsEmtpy.Load(), +func (hfws *HandshakeFilter) MarshalToHfwList(maxMessageSize int) []*HandshakeFilteringWhitelist { + hfwList := make([]*HandshakeFilteringWhitelist, 0) + maxMessageSize = maxMessageSize - 10 // account for potential overhead + + appendOldAndGetNewHfw := func(old *HandshakeFilteringWhitelist) *HandshakeFilteringWhitelist { + if old != nil { + hfwList = append(hfwList, old) + } + return &HandshakeFilteringWhitelist{ + AllowedHosts: make([]string, 0), + AllowedGroups: make([]string, 0), + AllowedGroupsCombos: make([]*GroupsCombos, 0), + AllowedCidrs: make([]string, 0), + AllowedCANames: make([]string, 0), + AllowedCAShas: make([]string, 0), + Append: old != nil, + } + } + + type appendToHfw func(hfw *HandshakeFilteringWhitelist, e string) + addIteratableToHfw := func(hfw *HandshakeFilteringWhitelist, e string, appendToHfwFunc appendToHfw) *HandshakeFilteringWhitelist { + if hfw.Size()+len(e) > maxMessageSize { + hfw = appendOldAndGetNewHfw(hfw) + } + appendToHfwFunc(hfw, e) + return hfw } - for host := range hfws.AllowedHosts { - hfw.AllowedHosts = append(hfw.AllowedHosts, host) + hfw := appendOldAndGetNewHfw(nil) + + for e := range hfws.AllowedHosts { + hfw = addIteratableToHfw(hfw, e, func(h *HandshakeFilteringWhitelist, e string) { + h.AllowedHosts = append(h.AllowedHosts, e) + }) } - for group := range hfws.AllowedGroups { - hfw.AllowedGroups = append(hfw.AllowedGroups, group) + for e := range hfws.AllowedGroups { + hfw = addIteratableToHfw(hfw, e, func(h *HandshakeFilteringWhitelist, e string) { + h.AllowedGroups = append(h.AllowedGroups, e) + }) } - for i, groupCombo := range hfws.AllowedGroupsCombos { + for _, groupCombo := range hfws.AllowedGroupsCombos { gc := &GroupsCombos{ Group: make([]string, len(groupCombo)), } + j := 0 + groupBytes := 0 for group := range groupCombo { gc.Group[j] = group + groupBytes += len(group) j += 1 } - hfw.AllowedGroupsCombos[i] = gc + + if hfw.Size()+groupBytes > maxMessageSize { + hfw = appendOldAndGetNewHfw(hfw) + } + hfw.AllowedGroupsCombos = append(hfw.AllowedGroupsCombos, gc) } - for i, cidr := range hfws.AllowedCidrs { - hfw.AllowedCidrs[i] = cidr.String() + for _, e := range hfws.AllowedCidrs { + hfw = addIteratableToHfw(hfw, e.String(), func(h *HandshakeFilteringWhitelist, e string) { + h.AllowedCidrs = append(h.AllowedCidrs, e) + }) } - for ca := range hfws.AllowedCANames { - hfw.AllowedCANames = append(hfw.AllowedCANames, ca) + for e := range hfws.AllowedCANames { + hfw = addIteratableToHfw(hfw, e, func(h *HandshakeFilteringWhitelist, e string) { + h.AllowedCANames = append(h.AllowedCANames, e) + }) } - for fp := range hfws.AllowedCAShas { - hfw.AllowedCAShas = append(hfw.AllowedCAShas, fp) + for e := range hfws.AllowedCAShas { + hfw = addIteratableToHfw(hfw, e, func(h *HandshakeFilteringWhitelist, e string) { + h.AllowedCAShas = append(h.AllowedCAShas, e) + }) } hfws.IsModifiedSinceLastMashalling.Store(false) - return hfw + hfwList = append(hfwList, hfw) + return hfwList } -func (hfws *HandshakeFilter) UnmarshalFromHfw(hfw *HandshakeFilteringWhitelist) { +func (hfws *HandshakeFilter) UnmarshalFromHfw(hfw *HandshakeFilteringWhitelist) *HandshakeFilter { if hfw == nil { - return + return hfws + } + + if !hfw.Append { + hfws = NewHandshakeFilter() } for _, h := range hfw.AllowedHosts { @@ -1236,6 +1277,8 @@ func (hfws *HandshakeFilter) UnmarshalFromHfw(hfw *HandshakeFilteringWhitelist) for _, sha := range hfw.AllowedCAShas { hfws.AddRule(nil, "", netip.Prefix{}, "", sha) } + + return hfws } func isSubset(subset map[string]struct{}, superset []string) bool { diff --git a/firewall_test.go b/firewall_test.go index c848c6531..4c2046cc7 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -5,6 +5,7 @@ import ( "errors" "math" "net/netip" + "strconv" "testing" "time" @@ -699,31 +700,33 @@ func TestHandshakeFilter_Marshalling(t *testing.T) { assert.Empty(t, hf.AllowedCAShas) assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) assert.True(t, hf.IsEmtpy.Load()) - hfw := hf.MarshalToHfw() + hfwl := hf.MarshalToHfwList(1300) assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) - assert.Empty(t, hfw.AllowedGroups) - assert.Empty(t, hfw.AllowedGroupsCombos) - assert.Empty(t, hfw.AllowedHosts) - assert.Empty(t, hfw.AllowedCidrs) - assert.Empty(t, hfw.AllowedCANames) - assert.Empty(t, hfw.AllowedCAShas) - assert.True(t, hfw.SetEmpty) + assert.Equal(t, 1, len(hfwl)) + assert.Empty(t, hfwl[0].AllowedGroups) + assert.Empty(t, hfwl[0].AllowedGroupsCombos) + assert.Empty(t, hfwl[0].AllowedHosts) + assert.Empty(t, hfwl[0].AllowedCidrs) + assert.Empty(t, hfwl[0].AllowedCANames) + assert.Empty(t, hfwl[0].AllowedCAShas) + assert.False(t, hfwl[0].Append) hf = NewHandshakeFilter() g := "g1" hf.AddRule([]string{g}, "", netip.Prefix{}, "", "") assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) - hfw = hf.MarshalToHfw() - assert.Contains(t, hfw.AllowedGroups, g) - assert.Empty(t, hfw.AllowedGroupsCombos) - assert.Empty(t, hfw.AllowedHosts) - assert.Empty(t, hfw.AllowedCidrs) - assert.Empty(t, hfw.AllowedCANames) - assert.Empty(t, hfw.AllowedCAShas) - assert.False(t, hfw.SetEmpty) + hfwl = hf.MarshalToHfwList(1300) + assert.Equal(t, 1, len(hfwl)) + assert.Contains(t, hfwl[0].AllowedGroups, g) + assert.Empty(t, hfwl[0].AllowedGroupsCombos) + assert.Empty(t, hfwl[0].AllowedHosts) + assert.Empty(t, hfwl[0].AllowedCidrs) + assert.Empty(t, hfwl[0].AllowedCANames) + assert.Empty(t, hfwl[0].AllowedCAShas) + assert.False(t, hfwl[0].Append) assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) hf = NewHandshakeFilter() - hf.UnmarshalFromHfw(hfw) + hf = hf.UnmarshalFromHfw(hfwl[0]) assert.Empty(t, hf.AllowedGroupsCombos) assert.Empty(t, hf.AllowedHosts) assert.Empty(t, hf.AllowedCidrs) @@ -737,23 +740,24 @@ func TestHandshakeFilter_Marshalling(t *testing.T) { hf.AddRule(gc, "", netip.Prefix{}, "", "") assert.Len(t, hf.AllowedGroupsCombos, 1) assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) - hfw = hf.MarshalToHfw() - assert.Empty(t, hfw.AllowedGroups) - assert.Len(t, hfw.AllowedGroupsCombos, 1) + hfwl = hf.MarshalToHfwList(1300) + assert.Equal(t, 1, len(hfwl)) + assert.Empty(t, hfwl[0].AllowedGroups) + assert.Len(t, hfwl[0].AllowedGroupsCombos, 1) for _, g := range gc { - assert.Contains(t, hfw.AllowedGroupsCombos[0].Group, g) + assert.Contains(t, hfwl[0].AllowedGroupsCombos[0].Group, g) } - assert.Empty(t, hfw.AllowedHosts) - assert.Empty(t, hfw.AllowedCidrs) - assert.Empty(t, hfw.AllowedCANames) - assert.Empty(t, hfw.AllowedCAShas) - assert.False(t, hfw.SetEmpty) + assert.Empty(t, hfwl[0].AllowedHosts) + assert.Empty(t, hfwl[0].AllowedCidrs) + assert.Empty(t, hfwl[0].AllowedCANames) + assert.Empty(t, hfwl[0].AllowedCAShas) + assert.False(t, hfwl[0].Append) assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) hf = NewHandshakeFilter() - hf.UnmarshalFromHfw(hfw) + hf = hf.UnmarshalFromHfw(hfwl[0]) assert.Empty(t, hf.AllowedGroups) gs := make(map[string]struct{}) - for _, g := range hfw.AllowedGroupsCombos[0].Group { + for _, g := range hfwl[0].AllowedGroupsCombos[0].Group { gs[g] = struct{}{} } for _, g := range gc { @@ -769,17 +773,18 @@ func TestHandshakeFilter_Marshalling(t *testing.T) { h := "h1" hf.AddRule(nil, h, netip.Prefix{}, "", "") assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) - hfw = hf.MarshalToHfw() - assert.Empty(t, hfw.AllowedGroups) - assert.Empty(t, hfw.AllowedGroupsCombos) - assert.Contains(t, hfw.AllowedHosts, h) - assert.Empty(t, hfw.AllowedCidrs) - assert.Empty(t, hfw.AllowedCANames) - assert.Empty(t, hfw.AllowedCAShas) - assert.False(t, hfw.SetEmpty) + hfwl = hf.MarshalToHfwList(1300) + assert.Equal(t, 1, len(hfwl)) + assert.Empty(t, hfwl[0].AllowedGroups) + assert.Empty(t, hfwl[0].AllowedGroupsCombos) + assert.Contains(t, hfwl[0].AllowedHosts, h) + assert.Empty(t, hfwl[0].AllowedCidrs) + assert.Empty(t, hfwl[0].AllowedCANames) + assert.Empty(t, hfwl[0].AllowedCAShas) + assert.False(t, hfwl[0].Append) assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) hf = NewHandshakeFilter() - hf.UnmarshalFromHfw(hfw) + hf = hf.UnmarshalFromHfw(hfwl[0]) assert.Empty(t, hf.AllowedGroups) assert.Empty(t, hf.AllowedGroupsCombos) assert.Contains(t, hf.AllowedHosts, h) @@ -792,17 +797,18 @@ func TestHandshakeFilter_Marshalling(t *testing.T) { p, _ := netip.ParsePrefix("10.1.1.1/32") hf.AddRule(nil, "", p, "", "") assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) - hfw = hf.MarshalToHfw() - assert.Empty(t, hfw.AllowedGroups) - assert.Empty(t, hfw.AllowedGroupsCombos) - assert.Empty(t, hfw.AllowedHosts) - assert.Equal(t, hfw.AllowedCidrs[0], p.String()) - assert.Empty(t, hfw.AllowedCANames) - assert.Empty(t, hfw.AllowedCAShas) - assert.False(t, hfw.SetEmpty) + hfwl = hf.MarshalToHfwList(1300) + assert.Equal(t, 1, len(hfwl)) + assert.Empty(t, hfwl[0].AllowedGroups) + assert.Empty(t, hfwl[0].AllowedGroupsCombos) + assert.Empty(t, hfwl[0].AllowedHosts) + assert.Equal(t, hfwl[0].AllowedCidrs[0], p.String()) + assert.Empty(t, hfwl[0].AllowedCANames) + assert.Empty(t, hfwl[0].AllowedCAShas) + assert.False(t, hfwl[0].Append) assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) hf = NewHandshakeFilter() - hf.UnmarshalFromHfw(hfw) + hf = hf.UnmarshalFromHfw(hfwl[0]) assert.Empty(t, hf.AllowedGroups) assert.Empty(t, hf.AllowedGroupsCombos) assert.Empty(t, hf.AllowedHosts) @@ -815,17 +821,18 @@ func TestHandshakeFilter_Marshalling(t *testing.T) { ca := "TestCA" hf.AddRule(nil, "", netip.Prefix{}, ca, "") assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) - hfw = hf.MarshalToHfw() - assert.Empty(t, hfw.AllowedGroups) - assert.Empty(t, hfw.AllowedGroupsCombos) - assert.Empty(t, hfw.AllowedHosts) - assert.Empty(t, hfw.AllowedCidrs) - assert.Contains(t, hfw.AllowedCANames, ca) - assert.Empty(t, hfw.AllowedCAShas) - assert.False(t, hfw.SetEmpty) + hfwl = hf.MarshalToHfwList(1300) + assert.Equal(t, 1, len(hfwl)) + assert.Empty(t, hfwl[0].AllowedGroups) + assert.Empty(t, hfwl[0].AllowedGroupsCombos) + assert.Empty(t, hfwl[0].AllowedHosts) + assert.Empty(t, hfwl[0].AllowedCidrs) + assert.Contains(t, hfwl[0].AllowedCANames, ca) + assert.Empty(t, hfwl[0].AllowedCAShas) + assert.False(t, hfwl[0].Append) assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) hf = NewHandshakeFilter() - hf.UnmarshalFromHfw(hfw) + hf = hf.UnmarshalFromHfw(hfwl[0]) assert.Empty(t, hf.AllowedGroups) assert.Empty(t, hf.AllowedGroupsCombos) assert.Empty(t, hf.AllowedHosts) @@ -838,26 +845,102 @@ func TestHandshakeFilter_Marshalling(t *testing.T) { fp := "3fc204e4d45e8b22ed0879bcd7cb5bf93cdc1c7a309c5dcedddc03aed33a47c6" hf.AddRule(nil, "", netip.Prefix{}, "", fp) assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) - hfw = hf.MarshalToHfw() - assert.Empty(t, hfw.AllowedGroups) - assert.Empty(t, hfw.AllowedGroupsCombos) - assert.Empty(t, hfw.AllowedHosts) - assert.Empty(t, hfw.AllowedCidrs) - assert.Empty(t, hfw.AllowedCANames) - assert.Contains(t, hfw.AllowedCAShas, fp) - assert.False(t, hfw.SetEmpty) + hfwl = hf.MarshalToHfwList(1300) + assert.Equal(t, 1, len(hfwl)) + assert.Empty(t, hfwl[0].AllowedGroups) + assert.Empty(t, hfwl[0].AllowedGroupsCombos) + assert.Empty(t, hfwl[0].AllowedHosts) + assert.Empty(t, hfwl[0].AllowedCidrs) + assert.Empty(t, hfwl[0].AllowedCANames) + assert.Contains(t, hfwl[0].AllowedCAShas, fp) + assert.False(t, hfwl[0].Append) assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) hf = NewHandshakeFilter() - hf.UnmarshalFromHfw(hfw) + hf = hf.UnmarshalFromHfw(hfwl[0]) assert.Empty(t, hf.AllowedGroups) assert.Empty(t, hf.AllowedGroupsCombos) assert.Empty(t, hf.AllowedHosts, h) assert.Empty(t, hf.AllowedCidrs) - assert.Empty(t, hfw.AllowedCANames) + assert.Empty(t, hf.AllowedCANames) assert.Contains(t, hf.AllowedCAShas, fp) assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) } +func TestHandshakeFilter_MarshallingMultiPacket(t *testing.T) { + mtu := 1100 + hf := NewHandshakeFilter() + h := "h" + for i := 0; i <= 100; i++ { + hf.AddRule(nil, h+strconv.Itoa(i), netip.Prefix{}, "", "") + } + + g := "g" + for i := 0; i <= 100; i++ { + hf.AddRule([]string{g + strconv.Itoa(i)}, "", netip.Prefix{}, "", "") + } + + for i := 0; i <= 100; i++ { + gc := []string{"g1", "g2", "g" + strconv.Itoa(i)} + hf.AddRule(gc, "", netip.Prefix{}, "", "") + } + + for i := 0; i <= 100; i++ { + p, _ := netip.ParsePrefix("10.1.1." + strconv.Itoa(i) + "/32") + hf.AddRule(nil, "", p, "", "") + } + + ca := "TestCA" + for i := 0; i <= 100; i++ { + hf.AddRule(nil, "", netip.Prefix{}, ca+strconv.Itoa(i), "") + } + + fp := "3fc204e4d45e8b22ed0879bcd7cb5bf93cdc1c7a309c5dcedddc03aed33a47c6" + for i := 0; i <= 100; i++ { + hf.AddRule(nil, "", netip.Prefix{}, "", fp+strconv.Itoa(i)) + } + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) + hfwl := hf.MarshalToHfwList(mtu) + for _, hfw := range hfwl { + assert.LessOrEqual(t, hfw.Size(), mtu) + } + assert.False(t, hfwl[0].Append) + for i := 1; i < len(hfwl); i++ { + assert.True(t, hfwl[i].Append) + } + assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) + hf = NewHandshakeFilter() + for _, hfw := range hfwl { + hf = hf.UnmarshalFromHfw(hfw) + } + for i := 0; i <= 100; i++ { + assert.Contains(t, hf.AllowedGroups, g+strconv.Itoa(i)) + } + for i := 0; i <= 100; i++ { + gc := map[string]struct{}{ + "g1": {}, + "g2": {}, + "g" + strconv.Itoa(i): {}, + } + + assert.Contains(t, hf.AllowedGroupsCombos, gc) + } + for i := 0; i <= 100; i++ { + assert.Contains(t, hf.AllowedHosts, h+strconv.Itoa(i)) + } + for i := 0; i <= 100; i++ { + p, _ := netip.ParsePrefix("10.1.1." + strconv.Itoa(i) + "/32") + assert.Contains(t, hf.AllowedCidrs, p) + } + for i := 0; i <= 100; i++ { + assert.Contains(t, hf.AllowedCANames, ca+strconv.Itoa(i)) + } + for i := 0; i <= 100; i++ { + assert.Contains(t, hf.AllowedCAShas, fp+strconv.Itoa(i)) + } + + assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) +} + func Test_isSubset(t *testing.T) { subset := make(map[string]struct{}, 2) subset["g1"] = struct{}{} diff --git a/handshake_ix.go b/handshake_ix.go index 1cbb67f11..0639121c8 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -244,7 +244,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to generate index") return } - hostinfo := &HostInfo{ ConnectionState: ci, localIndexId: myIndex, @@ -252,6 +251,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet vpnAddrs: vpnAddrs, HandshakePacket: make(map[uint8][]byte, 0), lastHandshakeTime: hs.Details.Time, + hfwMessagesAckd: make(map[uint8]bool, 0), relayState: RelayState{ relays: map[netip.Addr]struct{}{}, relayForByAddr: map[netip.Addr]*Relay{}, diff --git a/handshake_manager.go b/handshake_manager.go index 6f954021f..87614e856 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -450,6 +450,7 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han hostinfo := &HostInfo{ vpnAddrs: []netip.Addr{vpnAddr}, HandshakePacket: make(map[uint8][]byte, 0), + hfwMessagesAckd: make(map[uint8]bool, 0), relayState: RelayState{ relays: map[netip.Addr]struct{}{}, relayForByAddr: map[netip.Addr]*Relay{}, diff --git a/hostmap.go b/hostmap.go index f9e3c4e50..681b19147 100644 --- a/hostmap.go +++ b/hostmap.go @@ -247,6 +247,9 @@ type HostInfo struct { lastRoam time.Time lastRoamRemote netip.AddrPort + // Tracks if the lh ack'd all our hfw messages. if not, the missing message will be resent + hfwMessagesAckd map[uint8]bool + // Used to track other hostinfos for this vpn ip since only 1 can be primary // Synchronised via hostmap lock and not the hostinfo lock. next, prev *HostInfo diff --git a/lighthouse.go b/lighthouse.go index 4087b43ff..550890949 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -19,6 +19,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" ) @@ -49,6 +50,10 @@ type LightHouse struct { // Controls weather to send the handshake white list rules to the lighthouses. enableHostQueryProtection atomic.Bool + // Routing information used for hfwl packet generation + routes []overlay.Route + defaultMTU int + // filters remote addresses allowed for each host // - When we are a lighthouse, this filters what addresses we store and // respond with. @@ -129,6 +134,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, h.staticList.Store(&staticList) h.incomingHandshakeFiltering.Store(false) h.enableHostQueryProtection.Store(false) + h.routes = make([]overlay.Route, 0) if c.GetBool("stats.lighthouse_metrics", false) { h.metrics = newLighthouseMetrics() @@ -195,6 +201,15 @@ func (lh *LightHouse) GetUpdateInterval() int64 { return lh.interval.Load() } +func (lh *LightHouse) GetMTUForAddr(addr netip.Addr) int { + for _, r := range lh.routes { + if r.Cidr.Contains(addr) { + return r.MTU + } + } + return lh.defaultMTU +} + func (lh *LightHouse) reload(c *config.C, initial bool) error { if initial || c.HasChanged("lighthouse.advertise_addrs") { rawAdvAddrs := c.GetStringSlice("lighthouse.advertise_addrs", []string{}) @@ -271,6 +286,19 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } } + if initial || c.HasChanged("tun.routes") { + routes, err := overlay.ParseRoutes(c, lh.myVpnNetworks) + if err != nil { + return util.NewContextualError("Could not parse tun.routes", nil, err) + } + + lh.routes = routes + } + + if initial || c.HasChanged("tun.mtu") { + lh.defaultMTU = c.GetInt("tun.mtu", overlay.DefaultMTU) + } + if initial || c.HasChanged("lighthouse.remote_allow_list") || c.HasChanged("lighthouse.remote_allow_ranges") { ral, err := NewRemoteAllowListFromConfig(c, "lighthouse.remote_allow_list", "lighthouse.remote_allow_ranges") if err != nil { @@ -569,7 +597,7 @@ func (lh *LightHouse) IsHostQueryAllowed(targetAddr netip.Addr, groups []string, lh.RUnlock() - return v.hf.IsEmtpy.Load() || v.hf.IsHandshakeAllowed(groups, host, queryAddrs, CAName, CASha) + return v.hf.IsHandshakeAllowed(groups, host, queryAddrs, CAName, CASha) } lh.RUnlock() return true @@ -926,18 +954,11 @@ func (lh *LightHouse) SendUpdate() { nb := make([]byte, 12, 12) out := make([]byte, mtu) - // cache for v1Update/v2Update with or without hfwl - updateMessageCache := map[cert.Version]map[bool][]byte{ - cert.Version1: map[bool][]byte{}, - cert.Version2: map[bool][]byte{}, - } - - sendHfw := lh.enableHostQueryProtection.Load() && - lh.hf.IsModifiedSinceLastMashalling.Load() - + var v1Update, v2Update []byte var err error updated := 0 lighthouses := lh.GetLighthouses() + for lhVpnAddr := range lighthouses { var v cert.Version hi := lh.ifce.GetHostInfo(lhVpnAddr) @@ -946,10 +967,8 @@ func (lh *LightHouse) SendUpdate() { } else { v = lh.ifce.GetCertState().defaultVersion } - - sendHfwToLh := hi == nil || sendHfw if v == cert.Version1 { - if _, ok := updateMessageCache[v][sendHfwToLh]; !ok { + if v1Update == nil { if !lh.myVpnNetworks[0].Addr().Is4() { lh.l.WithField("lighthouseAddr", lhVpnAddr). Warn("cannot update lighthouse using v1 protocol without an IPv4 address") @@ -974,20 +993,7 @@ func (lh *LightHouse) SendUpdate() { }, } - if sendHfwToLh { - msg.Details.HandshakeFilteringWhitelist = lh.hf.MarshalToHfw() - if msg.Details.HandshakeFilteringWhitelist != nil && lh.l.Level >= logrus.DebugLevel { - lh.l.WithField("hosts", msg.Details.HandshakeFilteringWhitelist.AllowedHosts). - WithField("groups", msg.Details.HandshakeFilteringWhitelist.AllowedGroups). - WithField("groupcombos", msg.Details.HandshakeFilteringWhitelist.AllowedGroupsCombos). - WithField("cidrs", msg.Details.HandshakeFilteringWhitelist.AllowedCidrs). - WithField("canames", msg.Details.HandshakeFilteringWhitelist.AllowedCANames). - WithField("cashas", msg.Details.HandshakeFilteringWhitelist.AllowedCAShas). - Debug("Sending handshake filtering whitelist to lighthouse") - } - } - - updateMessageCache[v][sendHfwToLh], err = msg.Marshal() + v1Update, err = msg.Marshal() if err != nil { lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). Error("Error while marshaling for lighthouse v1 update") @@ -995,11 +1001,11 @@ func (lh *LightHouse) SendUpdate() { } } - lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, updateMessageCache[v][sendHfwToLh], nb, out) + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Update, nb, out) updated++ } else if v == cert.Version2 { - if _, ok := updateMessageCache[v][sendHfwToLh]; !ok { + if v2Update == nil { var relays []*Addr for _, r := range lh.GetRelaysForMe() { relays = append(relays, netAddrToProtoAddr(r)) @@ -1015,22 +1021,7 @@ func (lh *LightHouse) SendUpdate() { }, } - if sendHfwToLh { - msg.Details.HandshakeFilteringWhitelist = lh.hf.MarshalToHfw() - if lh.l.Level >= logrus.DebugLevel { - if msg.Details.HandshakeFilteringWhitelist != nil && lh.l.Level >= logrus.DebugLevel { - lh.l.WithField("hosts", msg.Details.HandshakeFilteringWhitelist.AllowedHosts). - WithField("groups", msg.Details.HandshakeFilteringWhitelist.AllowedGroups). - WithField("groupcombos", msg.Details.HandshakeFilteringWhitelist.AllowedGroupsCombos). - WithField("cidrs", msg.Details.HandshakeFilteringWhitelist.AllowedCidrs). - WithField("canames", msg.Details.HandshakeFilteringWhitelist.AllowedCANames). - WithField("cashas", msg.Details.HandshakeFilteringWhitelist.AllowedCAShas). - Debug("Sending handshake filtering whitelist to lighthouse") - } - } - } - - updateMessageCache[v][sendHfwToLh], err = msg.Marshal() + v2Update, err = msg.Marshal() if err != nil { lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). Error("Error while marshaling for lighthouse v2 update") @@ -1038,7 +1029,7 @@ func (lh *LightHouse) SendUpdate() { } } - lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, updateMessageCache[v][sendHfwToLh], nb, out) + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Update, nb, out) updated++ } else { @@ -1136,7 +1127,13 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, hostInfo *Host lhh.handleHostPunchNotification(n, fromVpnAddrs, w) case NebulaMeta_HostUpdateNotificationAck: - // noop + lhh.handleHostUpdateNotificationAck(n, fromVpnAddrs, w) + + case NebulaMeta_HostQueryWhitelist: + lhh.handleHostQueryWhitelist(n, fromVpnAddrs, w) + + case NebulaMeta_HostQueryWhitelistAck: + lhh.handleHostQueryWhitelistAck(n, fromVpnAddrs, w) } } @@ -1379,7 +1376,6 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp } relays := n.Details.GetRelays() - hfws := n.Details.GetHandshakeFilteringWhitelist() lhh.lh.Lock() am := lhh.lh.unlockedGetRemoteList(fromVpnAddrs) @@ -1389,21 +1385,8 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) am.unlockedSetRelay(fromVpnAddrs[0], relays) - am.unlockedSetHandshakeFilteringWhitelist(hfws) am.Unlock() - if hfws != nil && lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnAddrs", fromVpnAddrs). - WithField("hosts", hfws.AllowedHosts). - WithField("groups", hfws.AllowedGroups). - WithField("groupcombos", hfws.AllowedGroupsCombos). - WithField("cidrs", hfws.AllowedCidrs). - WithField("canames", hfws.AllowedCANames). - WithField("cashas", hfws.AllowedCAShas). - WithField("setempty", hfws.SetEmpty). - Debug("Received host query filter") - } - n = lhh.resetMeta() n.Type = NebulaMeta_HostUpdateNotificationAck @@ -1497,6 +1480,159 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn }() } } +func (lhh *LightHouseHandler) handleHostUpdateNotificationAck(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { + if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) { + return + } + + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + + // make sure to send hfwl to all lighthouses if there were changes + if lhh.lh.enableHostQueryProtection.Load() && lhh.lh.hf.IsModifiedSinceLastMashalling.Load() { + lighthouses := lhh.lh.GetLighthouses() + for lhVpnAddr := range lighthouses { + hi := lhh.lh.ifce.GetHostInfo(lhVpnAddr) + if hi == nil { + continue + } + + hi.hfwMessagesAckd = make(map[uint8]bool, 0) + } + } + + for _, lhVpnAddr := range fromVpnAddrs { + hi := lhh.lh.ifce.GetHostInfo(lhVpnAddr) + if hi == nil { + continue + } + + nonAckedHfwMessageCounter := 0 + for _, value := range hi.hfwMessagesAckd { + if !value { + nonAckedHfwMessageCounter++ + } + } + + if len(hi.hfwMessagesAckd) == 0 || nonAckedHfwMessageCounter != 0 { + lhMtu := lhh.lh.GetMTUForAddr(lhVpnAddr) + hfwList := lhh.lh.hf.MarshalToHfwList(lhMtu) + + for i, hfw := range hfwList { + if _, ok := hi.hfwMessagesAckd[uint8(i)]; ok { + if hi.hfwMessagesAckd[uint8(i)] { + // skip as already ack'd + continue + } + } + + msg := NebulaMeta{ + Type: NebulaMeta_HostQueryWhitelist, + Details: &NebulaMetaDetails{ + HandshakeFilteringWhitelist: hfw, + }, + } + + msg.Details.HandshakeFilteringWhitelist.MessageId = uint32(i) + + if msg.Details.HandshakeFilteringWhitelist != nil && lhh.lh.l.Level >= logrus.DebugLevel { + lhh.lh.l.WithField("hosts", msg.Details.HandshakeFilteringWhitelist.AllowedHosts). + WithField("groups", msg.Details.HandshakeFilteringWhitelist.AllowedGroups). + WithField("groupcombos", msg.Details.HandshakeFilteringWhitelist.AllowedGroupsCombos). + WithField("cidrs", msg.Details.HandshakeFilteringWhitelist.AllowedCidrs). + WithField("canames", msg.Details.HandshakeFilteringWhitelist.AllowedCANames). + WithField("cashas", msg.Details.HandshakeFilteringWhitelist.AllowedCAShas). + WithField("i", i). + Debug("Sending hfw message") + } + msgSerialized, err := msg.Marshal() + + if err != nil { + lhh.lh.l.WithError(err). + WithField("lighthouseAddr", lhVpnAddr). + Error("Error while marshaling for lighthouse hfw update") + break + } + + hi.hfwMessagesAckd[uint8(i)] = false + lhh.lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, msgSerialized, nb, out) + } + + return + } + } +} + +func (lhh *LightHouseHandler) handleHostQueryWhitelist(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { + if !lhh.lh.amLighthouse { + if lhh.l.Level >= logrus.DebugLevel { + lhh.l.Debugln("I am not a lighthouse, do not take host query whitelists: ", fromVpnAddrs) + } + return + } + + hfws := n.Details.GetHandshakeFilteringWhitelist() + + if hfws != nil && lhh.l.Level >= logrus.DebugLevel { + lhh.l.WithField("vpnAddrs", fromVpnAddrs). + WithField("hosts", hfws.AllowedHosts). + WithField("groups", hfws.AllowedGroups). + WithField("groupcombos", hfws.AllowedGroupsCombos). + WithField("cidrs", hfws.AllowedCidrs). + WithField("canames", hfws.AllowedCANames). + WithField("cashas", hfws.AllowedCAShas). + WithField("MessageId", hfws.MessageId). + WithField("append", hfws.Append). + Debug("Received host query filter") + } + + lhh.lh.Lock() + am := lhh.lh.unlockedGetRemoteList(fromVpnAddrs) + am.Lock() + lhh.lh.Unlock() + am.unlockedSetHandshakeFilteringWhitelist(hfws) + am.Unlock() + + msg := NebulaMeta{ + Type: NebulaMeta_HostQueryWhitelistAck, + Details: &NebulaMetaDetails{ + HandshakeFilteringWhitelist: &HandshakeFilteringWhitelist{ + MessageId: hfws.MessageId, + }, + }, + } + + msgSerialized, err := msg.Marshal() + if err != nil { + lhh.lh.l.WithError(err). + WithField("fromVpnAddrs", fromVpnAddrs). + Error("Error while marshaling for lighthouse hfw ack") + return + } + + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + lhh.lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], msgSerialized, nb, out) +} + +func (lhh *LightHouseHandler) handleHostQueryWhitelistAck(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { + if lhh.lh.amLighthouse { + if lhh.l.Level >= logrus.DebugLevel { + lhh.l.Debugln("I am a lighthouse, do not take host query whitelist acks: ", fromVpnAddrs) + } + return + } + + hi := lhh.lh.ifce.GetHostInfo(fromVpnAddrs[0]) + + if lhh.l.Level >= logrus.DebugLevel { + lhh.l.WithField("vpnAddrs", fromVpnAddrs). + WithField("MessageId", n.Details.HandshakeFilteringWhitelist.MessageId). + Debug("Received HostQueryWhitelistAck") + } + + hi.hfwMessagesAckd[uint8(n.Details.HandshakeFilteringWhitelist.MessageId)] = true +} func protoAddrToNetAddr(addr *Addr) netip.Addr { b := [16]byte{} diff --git a/nebula.pb.go b/nebula.pb.go index 70877fd8b..f3faadc37 100644 --- a/nebula.pb.go +++ b/nebula.pb.go @@ -36,6 +36,8 @@ const ( NebulaMeta_PathCheck NebulaMeta_MessageType = 8 NebulaMeta_PathCheckReply NebulaMeta_MessageType = 9 NebulaMeta_HostUpdateNotificationAck NebulaMeta_MessageType = 10 + NebulaMeta_HostQueryWhitelist NebulaMeta_MessageType = 11 + NebulaMeta_HostQueryWhitelistAck NebulaMeta_MessageType = 12 ) var NebulaMeta_MessageType_name = map[int32]string{ @@ -50,6 +52,8 @@ var NebulaMeta_MessageType_name = map[int32]string{ 8: "PathCheck", 9: "PathCheckReply", 10: "HostUpdateNotificationAck", + 11: "HostQueryWhitelist", + 12: "HostQueryWhitelistAck", } var NebulaMeta_MessageType_value = map[string]int32{ @@ -64,6 +68,8 @@ var NebulaMeta_MessageType_value = map[string]int32{ "PathCheck": 8, "PathCheckReply": 9, "HostUpdateNotificationAck": 10, + "HostQueryWhitelist": 11, + "HostQueryWhitelistAck": 12, } func (x NebulaMeta_MessageType) String() string { @@ -184,7 +190,8 @@ type NebulaMetaDetails struct { VpnAddr *Addr `protobuf:"bytes,6,opt,name=VpnAddr,proto3" json:"VpnAddr,omitempty"` OldRelayVpnAddrs []uint32 `protobuf:"varint,5,rep,packed,name=OldRelayVpnAddrs,proto3" json:"OldRelayVpnAddrs,omitempty"` // Deprecated: Do not use. RelayVpnAddrs []*Addr `protobuf:"bytes,7,rep,name=RelayVpnAddrs,proto3" json:"RelayVpnAddrs,omitempty"` - HandshakeFilteringWhitelist *HandshakeFilteringWhitelist `protobuf:"bytes,8,opt,name=HandshakeFilteringWhitelist,proto3" json:"HandshakeFilteringWhitelist,omitempty"` + EnableHostQueryFiltering bool `protobuf:"varint,8,opt,name=EnableHostQueryFiltering,proto3" json:"EnableHostQueryFiltering,omitempty"` + HandshakeFilteringWhitelist *HandshakeFilteringWhitelist `protobuf:"bytes,9,opt,name=HandshakeFilteringWhitelist,proto3" json:"HandshakeFilteringWhitelist,omitempty"` V4AddrPorts []*V4AddrPort `protobuf:"bytes,2,rep,name=V4AddrPorts,proto3" json:"V4AddrPorts,omitempty"` V6AddrPorts []*V6AddrPort `protobuf:"bytes,4,rep,name=V6AddrPorts,proto3" json:"V6AddrPorts,omitempty"` Counter uint32 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` @@ -253,6 +260,13 @@ func (m *NebulaMetaDetails) GetRelayVpnAddrs() []*Addr { return nil } +func (m *NebulaMetaDetails) GetEnableHostQueryFiltering() bool { + if m != nil { + return m.EnableHostQueryFiltering + } + return false +} + func (m *NebulaMetaDetails) GetHandshakeFilteringWhitelist() *HandshakeFilteringWhitelist { if m != nil { return m.HandshakeFilteringWhitelist @@ -452,7 +466,8 @@ type HandshakeFilteringWhitelist struct { AllowedCidrs []string `protobuf:"bytes,4,rep,name=AllowedCidrs,proto3" json:"AllowedCidrs,omitempty"` AllowedCANames []string `protobuf:"bytes,5,rep,name=AllowedCANames,proto3" json:"AllowedCANames,omitempty"` AllowedCAShas []string `protobuf:"bytes,6,rep,name=AllowedCAShas,proto3" json:"AllowedCAShas,omitempty"` - SetEmpty bool `protobuf:"varint,7,opt,name=SetEmpty,proto3" json:"SetEmpty,omitempty"` + MessageId uint32 `protobuf:"varint,7,opt,name=MessageId,proto3" json:"MessageId,omitempty"` + Append bool `protobuf:"varint,8,opt,name=Append,proto3" json:"Append,omitempty"` } func (m *HandshakeFilteringWhitelist) Reset() { *m = HandshakeFilteringWhitelist{} } @@ -530,9 +545,16 @@ func (m *HandshakeFilteringWhitelist) GetAllowedCAShas() []string { return nil } -func (m *HandshakeFilteringWhitelist) GetSetEmpty() bool { +func (m *HandshakeFilteringWhitelist) GetMessageId() uint32 { if m != nil { - return m.SetEmpty + return m.MessageId + } + return 0 +} + +func (m *HandshakeFilteringWhitelist) GetAppend() bool { + if m != nil { + return m.Append } return false } @@ -883,67 +905,69 @@ func init() { func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) } var fileDescriptor_2d65afa7693df5ef = []byte{ - // 949 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x56, 0xcd, 0x72, 0x1b, 0x45, - 0x10, 0xd6, 0x4a, 0xab, 0xbf, 0xd6, 0x4f, 0x96, 0x76, 0x30, 0xeb, 0x50, 0xa8, 0xc4, 0x92, 0x72, - 0xb9, 0x38, 0x28, 0x94, 0x6d, 0x52, 0x1c, 0x51, 0x04, 0x46, 0x49, 0xc5, 0x8e, 0x99, 0x18, 0xa7, - 0x8a, 0x0b, 0xb5, 0xd6, 0x0e, 0xde, 0x29, 0xad, 0x76, 0x94, 0xdd, 0x11, 0x44, 0x6f, 0xc1, 0xa3, - 0x70, 0xe0, 0xca, 0x9d, 0x63, 0x4e, 0x14, 0x47, 0xca, 0x3e, 0x72, 0xe4, 0x05, 0xa8, 0x99, 0xfd, - 0x97, 0x16, 0xe7, 0x36, 0xdd, 0xfd, 0x75, 0xef, 0x37, 0x5f, 0xb7, 0x7a, 0x04, 0x5d, 0x9f, 0x5e, - 0xad, 0x3c, 0x7b, 0xb4, 0x0c, 0xb8, 0xe0, 0xd8, 0x88, 0x2c, 0xeb, 0x9f, 0x2a, 0xc0, 0x99, 0x3a, - 0x9e, 0x52, 0x61, 0xe3, 0x21, 0xe8, 0x17, 0xeb, 0x25, 0x35, 0xb5, 0xa1, 0x76, 0xd0, 0x3f, 0x1c, - 0x8c, 0xe2, 0x9c, 0x0c, 0x31, 0x3a, 0xa5, 0x61, 0x68, 0x5f, 0x53, 0x89, 0x22, 0x0a, 0x8b, 0x47, - 0xd0, 0xfc, 0x8a, 0x0a, 0x9b, 0x79, 0xa1, 0x59, 0x1d, 0x6a, 0x07, 0x9d, 0xc3, 0xbd, 0xed, 0xb4, - 0x18, 0x40, 0x12, 0xa4, 0xf5, 0xaf, 0x06, 0x9d, 0x5c, 0x29, 0x6c, 0x81, 0x7e, 0xc6, 0x7d, 0x6a, - 0x54, 0xb0, 0x07, 0xed, 0x29, 0x0f, 0xc5, 0xb7, 0x2b, 0x1a, 0xac, 0x0d, 0x0d, 0x11, 0xfa, 0xa9, - 0x49, 0xe8, 0xd2, 0x5b, 0x1b, 0x55, 0x7c, 0x00, 0xbb, 0xd2, 0xf7, 0xdd, 0xd2, 0xb1, 0x05, 0x3d, - 0xe3, 0x82, 0xfd, 0xc8, 0x66, 0xb6, 0x60, 0xdc, 0x37, 0x6a, 0xb8, 0x07, 0xef, 0xcb, 0xd8, 0x29, - 0xff, 0x89, 0x3a, 0x85, 0x90, 0x9e, 0x84, 0xce, 0x57, 0xfe, 0xcc, 0x2d, 0x84, 0xea, 0xd8, 0x07, - 0x90, 0xa1, 0x57, 0x2e, 0xb7, 0x17, 0xcc, 0x68, 0xe0, 0x0e, 0xdc, 0xcb, 0xec, 0xe8, 0xb3, 0x4d, - 0xc9, 0xec, 0xdc, 0x16, 0xee, 0xc4, 0xa5, 0xb3, 0xb9, 0xd1, 0x92, 0xcc, 0x52, 0x33, 0x82, 0xb4, - 0xf1, 0x23, 0xd8, 0x2b, 0x67, 0x36, 0x9e, 0xcd, 0x0d, 0xb0, 0x7e, 0xad, 0xc1, 0x7b, 0x5b, 0xa2, - 0xa0, 0x05, 0xf0, 0xc2, 0x73, 0x2e, 0x97, 0xfe, 0xd8, 0x71, 0x02, 0x25, 0x7d, 0xef, 0x49, 0xd5, - 0xd4, 0x48, 0xce, 0x8b, 0xfb, 0xd0, 0x4c, 0x00, 0x0d, 0x25, 0x72, 0x37, 0x11, 0x59, 0xfa, 0x48, - 0x12, 0xc4, 0x11, 0x18, 0x2f, 0x3c, 0x87, 0x50, 0xcf, 0x5e, 0xc7, 0xae, 0xd0, 0xac, 0x0f, 0x6b, - 0x71, 0xc5, 0xad, 0x18, 0x1e, 0x42, 0xaf, 0x08, 0x6e, 0x0e, 0x6b, 0x5b, 0xd5, 0x8b, 0x10, 0xa4, - 0xf0, 0xe1, 0xd4, 0xf6, 0x9d, 0xd0, 0xb5, 0xe7, 0xf4, 0x84, 0x79, 0x82, 0x06, 0xcc, 0xbf, 0x7e, - 0xe5, 0x32, 0x41, 0x3d, 0x16, 0x0a, 0xb3, 0xa5, 0xf8, 0x7d, 0x92, 0x54, 0xb8, 0x03, 0x4a, 0xee, - 0xaa, 0x83, 0xc7, 0xd0, 0xb9, 0x3c, 0x96, 0x5f, 0x3c, 0xe7, 0x81, 0x90, 0xb3, 0x25, 0x89, 0x61, - 0x52, 0x36, 0x0b, 0x91, 0x3c, 0x4c, 0x65, 0x3d, 0xce, 0xb2, 0xf4, 0x8d, 0xac, 0xc7, 0xb9, 0xac, - 0x0c, 0x86, 0x26, 0x34, 0x67, 0x7c, 0xe5, 0x0b, 0x1a, 0x98, 0x35, 0xa9, 0x3f, 0x49, 0x4c, 0x6b, - 0x1f, 0x74, 0x25, 0x6c, 0x1f, 0xaa, 0x53, 0xa6, 0x9a, 0xa3, 0x93, 0xea, 0x94, 0x49, 0xfb, 0x39, - 0x57, 0x03, 0xaf, 0x93, 0xea, 0x73, 0x6e, 0x1d, 0x03, 0x64, 0x34, 0x10, 0xa3, 0xac, 0xa8, 0x99, - 0x24, 0xaa, 0x80, 0xa0, 0xcb, 0x98, 0xca, 0xe9, 0x11, 0x75, 0xb6, 0xbe, 0x04, 0xc8, 0x68, 0xbc, - 0xeb, 0x1b, 0x69, 0x85, 0x5a, 0xae, 0xc2, 0xef, 0xd5, 0x3b, 0xbb, 0x81, 0x16, 0x74, 0xc7, 0x9e, - 0xc7, 0x7f, 0xa6, 0x8e, 0x1c, 0xcc, 0xd0, 0xd4, 0x86, 0xb5, 0x83, 0x36, 0x29, 0xf8, 0xf0, 0x21, - 0xf4, 0x62, 0xfb, 0x9b, 0x80, 0xaf, 0x96, 0x91, 0xd6, 0x6d, 0x52, 0x74, 0xe2, 0x09, 0xec, 0x14, - 0x1c, 0x13, 0xbe, 0xb8, 0xe2, 0xa1, 0x59, 0x53, 0x0a, 0xdf, 0x4f, 0x14, 0xce, 0xc7, 0x48, 0x59, - 0x42, 0x8e, 0xd1, 0x84, 0xc9, 0x89, 0xd3, 0x0b, 0x8c, 0x94, 0x0f, 0xf7, 0xa1, 0x9f, 0xd8, 0xe3, - 0x33, 0x7b, 0x41, 0xa3, 0x21, 0x6e, 0x93, 0x0d, 0x6f, 0x8e, 0xf9, 0x64, 0xfc, 0xd2, 0xb5, 0x43, - 0xb3, 0x51, 0x60, 0x1e, 0x39, 0xf1, 0x01, 0xb4, 0x5e, 0x52, 0xf1, 0xf5, 0x62, 0x29, 0xd6, 0x66, - 0x73, 0xa8, 0x1d, 0xb4, 0x48, 0x6a, 0x5b, 0x0f, 0xa1, 0x5b, 0x60, 0x77, 0x1f, 0xea, 0xca, 0x8e, - 0x85, 0x8a, 0x0c, 0xeb, 0x4d, 0xb2, 0x25, 0xcf, 0x99, 0x7f, 0x7d, 0xf7, 0x96, 0x94, 0x88, 0x92, - 0x2d, 0x89, 0xa0, 0x5f, 0xb0, 0x05, 0x8d, 0xbb, 0xa9, 0xce, 0x96, 0xb5, 0xb5, 0x03, 0x65, 0xb2, - 0x51, 0xc1, 0x36, 0xd4, 0xa3, 0x8d, 0xa2, 0x59, 0x3f, 0xc0, 0xbd, 0xa8, 0x6e, 0xda, 0x64, 0xfc, - 0x22, 0x5b, 0xb8, 0x9a, 0xfa, 0xad, 0x6d, 0x30, 0x48, 0x91, 0x9b, 0x5b, 0x57, 0x92, 0x98, 0x2e, - 0xec, 0x99, 0x22, 0xd1, 0x25, 0xea, 0x6c, 0xfd, 0xa9, 0xc1, 0x6e, 0x79, 0x9e, 0x84, 0x4f, 0x68, - 0x20, 0xd4, 0x57, 0xba, 0x44, 0x9d, 0x65, 0x67, 0x9e, 0xfa, 0x4c, 0x30, 0x5b, 0xf0, 0xe0, 0xa9, - 0xef, 0xd0, 0x37, 0xf1, 0x3c, 0x6f, 0x78, 0x25, 0x8e, 0xd0, 0x70, 0xc9, 0x7d, 0x87, 0xc6, 0xb8, - 0x68, 0x6a, 0x37, 0xbc, 0xb8, 0x0b, 0x8d, 0x09, 0xe7, 0x73, 0x46, 0x4d, 0x5d, 0x29, 0x13, 0x5b, - 0xa9, 0x5e, 0xf5, 0x4c, 0x2f, 0x1c, 0x42, 0x47, 0x72, 0xb8, 0xa4, 0x41, 0xc8, 0xb8, 0xaf, 0x16, - 0x4d, 0x8f, 0xe4, 0x5d, 0xcf, 0xf4, 0x56, 0xc3, 0x68, 0x3e, 0xd3, 0x5b, 0x4d, 0xa3, 0x65, 0xfd, - 0x56, 0x83, 0x5e, 0x74, 0xb1, 0x09, 0xf7, 0x45, 0xc0, 0x3d, 0xfc, 0xbc, 0xd0, 0xb7, 0x8f, 0x8b, - 0xaa, 0xc5, 0xa0, 0x92, 0xd6, 0x7d, 0x06, 0x3b, 0xe9, 0xe5, 0xd4, 0x26, 0xcc, 0xdf, 0xbb, 0x2c, - 0x24, 0x33, 0xd2, 0x6b, 0xe6, 0x32, 0x22, 0x05, 0xca, 0x42, 0xf8, 0x29, 0xf4, 0x93, 0xdd, 0x7c, - 0xc1, 0xd5, 0xea, 0xd0, 0xd3, 0x77, 0x60, 0x23, 0x92, 0xdf, 0xf1, 0x27, 0x01, 0x5f, 0x28, 0x74, - 0x3d, 0x45, 0x6f, 0xc5, 0x70, 0x04, 0x9d, 0x7c, 0xe1, 0xb2, 0xf7, 0x23, 0x0f, 0x48, 0xdf, 0x84, - 0xb4, 0x78, 0xb3, 0x24, 0xa3, 0x08, 0xb1, 0xa6, 0xff, 0xf7, 0x9c, 0xef, 0x02, 0x4e, 0x02, 0x6a, - 0x0b, 0xaa, 0xf0, 0x84, 0xbe, 0x5e, 0xd1, 0x50, 0x18, 0x1a, 0x7e, 0x00, 0x3b, 0x05, 0xbf, 0x94, - 0x24, 0xa4, 0x46, 0xf5, 0xc9, 0xd1, 0x1f, 0x37, 0x03, 0xed, 0xed, 0xcd, 0x40, 0xfb, 0xfb, 0x66, - 0xa0, 0xfd, 0x72, 0x3b, 0xa8, 0xbc, 0xbd, 0x1d, 0x54, 0xfe, 0xba, 0x1d, 0x54, 0xbe, 0xdf, 0xbb, - 0x66, 0xc2, 0x5d, 0x5d, 0x8d, 0x66, 0x7c, 0xf1, 0x28, 0xf4, 0xec, 0xd9, 0xdc, 0x7d, 0xfd, 0x28, - 0xa2, 0x74, 0xd5, 0x50, 0xff, 0x6a, 0x8e, 0xfe, 0x0b, 0x00, 0x00, 0xff, 0xff, 0x6f, 0xb9, 0x8e, - 0x82, 0xe5, 0x08, 0x00, 0x00, + // 991 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x56, 0xcd, 0x72, 0xe3, 0x44, + 0x10, 0xb6, 0x2c, 0xf9, 0xaf, 0xfd, 0xb3, 0xa2, 0xb3, 0x18, 0x85, 0x1f, 0x97, 0x11, 0x5b, 0xa9, + 0x14, 0x07, 0x2f, 0x95, 0x84, 0x2d, 0x8a, 0x13, 0x5e, 0x43, 0x70, 0xb6, 0x36, 0xd9, 0x30, 0x84, + 0x6c, 0x15, 0x17, 0x4a, 0xb6, 0x86, 0x78, 0xca, 0xb2, 0xc6, 0x2b, 0x8d, 0x61, 0xf3, 0x16, 0x3c, + 0x01, 0x47, 0x9e, 0x80, 0x87, 0xe0, 0xb8, 0x27, 0xe0, 0x48, 0x25, 0x57, 0x1e, 0x82, 0x9a, 0xd1, + 0xbf, 0xed, 0x84, 0xdb, 0xf4, 0xd7, 0x5f, 0xb7, 0xbe, 0xee, 0x1e, 0xb5, 0x04, 0x2d, 0x9f, 0x4e, + 0x56, 0x9e, 0x33, 0x58, 0x06, 0x5c, 0x70, 0xac, 0x46, 0x96, 0xfd, 0xab, 0x0e, 0x70, 0xa6, 0x8e, + 0xa7, 0x54, 0x38, 0x78, 0x00, 0xc6, 0xc5, 0xf5, 0x92, 0x5a, 0x5a, 0x5f, 0xdb, 0xef, 0x1c, 0xf4, + 0x06, 0x71, 0x4c, 0xc6, 0x18, 0x9c, 0xd2, 0x30, 0x74, 0xae, 0xa8, 0x64, 0x11, 0xc5, 0xc5, 0x43, + 0xa8, 0x7d, 0x49, 0x85, 0xc3, 0xbc, 0xd0, 0x2a, 0xf7, 0xb5, 0xfd, 0xe6, 0xc1, 0xee, 0x66, 0x58, + 0x4c, 0x20, 0x09, 0xd3, 0xfe, 0xad, 0x0c, 0xcd, 0x5c, 0x2a, 0xac, 0x83, 0x71, 0xc6, 0x7d, 0x6a, + 0x96, 0xb0, 0x0d, 0x8d, 0x31, 0x0f, 0xc5, 0x37, 0x2b, 0x1a, 0x5c, 0x9b, 0x1a, 0x22, 0x74, 0x52, + 0x93, 0xd0, 0xa5, 0x77, 0x6d, 0x96, 0xf1, 0x5d, 0xe8, 0x4a, 0xec, 0xbb, 0xa5, 0xeb, 0x08, 0x7a, + 0xc6, 0x05, 0xfb, 0x91, 0x4d, 0x1d, 0xc1, 0xb8, 0x6f, 0xea, 0xb8, 0x0b, 0x6f, 0x4b, 0xdf, 0x29, + 0xff, 0x89, 0xba, 0x05, 0x97, 0x91, 0xb8, 0xce, 0x57, 0xfe, 0x74, 0x56, 0x70, 0x55, 0xb0, 0x03, + 0x20, 0x5d, 0x2f, 0x67, 0xdc, 0x59, 0x30, 0xb3, 0x8a, 0x3b, 0xf0, 0x20, 0xb3, 0xa3, 0xc7, 0xd6, + 0xa4, 0xb2, 0x73, 0x47, 0xcc, 0x46, 0x33, 0x3a, 0x9d, 0x9b, 0x75, 0xa9, 0x2c, 0x35, 0x23, 0x4a, + 0x03, 0x3f, 0x80, 0xdd, 0xed, 0xca, 0x86, 0xd3, 0xb9, 0x09, 0xd8, 0x05, 0x4c, 0x8b, 0x79, 0x39, + 0x63, 0x82, 0x7a, 0x2c, 0x14, 0x66, 0x33, 0x51, 0x56, 0xc4, 0x65, 0x48, 0xcb, 0xfe, 0x57, 0x87, + 0xb7, 0x36, 0xfa, 0x88, 0x36, 0xc0, 0x0b, 0xcf, 0xbd, 0x5c, 0xfa, 0x43, 0xd7, 0x0d, 0xd4, 0xb4, + 0xda, 0x4f, 0xcb, 0x96, 0x46, 0x72, 0x28, 0xee, 0x41, 0x2d, 0x21, 0x54, 0xd5, 0x5c, 0x5a, 0xc9, + 0x5c, 0x24, 0x46, 0x12, 0x27, 0x0e, 0xc0, 0x7c, 0xe1, 0xb9, 0x84, 0x7a, 0xce, 0x75, 0x0c, 0x85, + 0x56, 0xa5, 0xaf, 0xc7, 0x19, 0x37, 0x7c, 0x78, 0x00, 0xed, 0x22, 0xb9, 0xd6, 0xd7, 0x37, 0xb2, + 0x17, 0x29, 0xf8, 0x39, 0x58, 0x5f, 0xf9, 0xce, 0xc4, 0xa3, 0x69, 0x99, 0xc7, 0xcc, 0x13, 0x34, + 0x60, 0xfe, 0x95, 0x55, 0xef, 0x6b, 0xfb, 0x75, 0x72, 0xa7, 0x1f, 0x29, 0xbc, 0x37, 0x76, 0x7c, + 0x37, 0x9c, 0x39, 0x73, 0x9a, 0xa2, 0x69, 0x97, 0xac, 0x86, 0xaa, 0xed, 0xa3, 0xe4, 0xe9, 0xf7, + 0x50, 0xc9, 0x7d, 0x79, 0xf0, 0x08, 0x9a, 0x97, 0x47, 0x52, 0xed, 0x39, 0x0f, 0x84, 0xbc, 0xca, + 0xb2, 0x28, 0x4c, 0xd2, 0x66, 0x2e, 0x92, 0xa7, 0xa9, 0xa8, 0x27, 0x59, 0x94, 0xb1, 0x16, 0xf5, + 0x24, 0x17, 0x95, 0xd1, 0xd0, 0x82, 0xda, 0x94, 0xaf, 0x7c, 0x41, 0x03, 0x4b, 0x97, 0xb3, 0x23, + 0x89, 0x69, 0xef, 0x81, 0xa1, 0x86, 0xd2, 0x81, 0xf2, 0x98, 0xa9, 0xc1, 0x1a, 0xa4, 0x3c, 0x66, + 0xd2, 0x7e, 0xce, 0xd5, 0xfb, 0x65, 0x90, 0xf2, 0x73, 0x6e, 0x1f, 0x01, 0x64, 0x32, 0x10, 0xa3, + 0xa8, 0xe8, 0x22, 0x90, 0x28, 0x03, 0x82, 0x21, 0x7d, 0x2a, 0xa6, 0x4d, 0xd4, 0xd9, 0xfe, 0x02, + 0x20, 0x93, 0xf1, 0x7f, 0xcf, 0x48, 0x33, 0xe8, 0xb9, 0x0c, 0x7f, 0x95, 0xef, 0x9d, 0x06, 0xda, + 0xd0, 0x1a, 0x7a, 0x1e, 0xff, 0x99, 0xba, 0x72, 0x92, 0xa1, 0xa5, 0xf5, 0xf5, 0xfd, 0x06, 0x29, + 0x60, 0xf8, 0x08, 0xda, 0xb1, 0xfd, 0x75, 0xc0, 0x57, 0xcb, 0xa8, 0xd7, 0x0d, 0x52, 0x04, 0xf1, + 0x18, 0x76, 0x0a, 0xc0, 0x88, 0x2f, 0x26, 0x3c, 0xb4, 0x74, 0xd5, 0xe1, 0x87, 0x49, 0x87, 0xf3, + 0x3e, 0xb2, 0x2d, 0x20, 0xa7, 0x68, 0xc4, 0xe4, 0x6d, 0x35, 0x0a, 0x8a, 0x14, 0x86, 0x7b, 0xd0, + 0x49, 0xec, 0xe1, 0x99, 0xb3, 0xa0, 0xd1, 0x0b, 0xd0, 0x20, 0x6b, 0x68, 0x4e, 0xf9, 0x68, 0xf8, + 0xed, 0xcc, 0x09, 0xad, 0x6a, 0x41, 0x79, 0x04, 0xe2, 0xfb, 0xd0, 0x88, 0x57, 0xdb, 0x89, 0x6b, + 0xd5, 0x54, 0xf3, 0x32, 0x00, 0xbb, 0x50, 0x1d, 0x2e, 0x97, 0xd4, 0x77, 0xe3, 0x8b, 0x1f, 0x5b, + 0xf6, 0x23, 0x68, 0x15, 0x74, 0x3f, 0x84, 0x8a, 0xb2, 0xe3, 0x16, 0x46, 0x86, 0xfd, 0x3a, 0x59, + 0xd7, 0xe7, 0xf2, 0xd5, 0xb8, 0x77, 0x5d, 0x4b, 0xc6, 0x96, 0x75, 0x8d, 0x60, 0x5c, 0xb0, 0x05, + 0x8d, 0xe7, 0xac, 0xce, 0xb6, 0xbd, 0xb1, 0x8c, 0x65, 0xb0, 0x59, 0xc2, 0x06, 0x54, 0xa2, 0xd5, + 0xa6, 0xd9, 0x3f, 0xc0, 0x83, 0x28, 0x6f, 0x3a, 0x7e, 0xfc, 0x2c, 0xdb, 0xfc, 0x9a, 0x7a, 0x0b, + 0xd7, 0x14, 0xa4, 0xcc, 0xf5, 0xf5, 0x2f, 0x45, 0x8c, 0x17, 0xce, 0x54, 0x89, 0x68, 0x11, 0x75, + 0xb6, 0xff, 0xd4, 0xa0, 0xbb, 0x3d, 0x4e, 0xd2, 0x47, 0x34, 0x10, 0xea, 0x29, 0x2d, 0xa2, 0xce, + 0x72, 0x66, 0x27, 0x3e, 0x13, 0xcc, 0x11, 0x3c, 0x38, 0xf1, 0x5d, 0xfa, 0x3a, 0xbe, 0xe9, 0x6b, + 0xa8, 0xe4, 0x11, 0x1a, 0x2e, 0xb9, 0xef, 0xd2, 0x98, 0x17, 0xdd, 0xe7, 0x35, 0x54, 0xce, 0x65, + 0xc4, 0xf9, 0x9c, 0x51, 0xcb, 0x50, 0x9d, 0x89, 0xad, 0xb4, 0x5f, 0x95, 0xac, 0x5f, 0xd8, 0x87, + 0xa6, 0xd4, 0x70, 0x49, 0x83, 0x90, 0x71, 0x5f, 0x0d, 0xb2, 0x4d, 0xf2, 0xd0, 0x33, 0xa3, 0x5e, + 0x35, 0x6b, 0xcf, 0x8c, 0x7a, 0xcd, 0xac, 0xdb, 0xbf, 0xeb, 0xd0, 0x8e, 0x0a, 0x1b, 0x71, 0x5f, + 0x04, 0xdc, 0xc3, 0x4f, 0x0b, 0x73, 0xfb, 0xb0, 0xd8, 0xb5, 0x98, 0xb4, 0x65, 0x74, 0x9f, 0xc0, + 0x4e, 0x5a, 0x9c, 0xda, 0xaf, 0xf9, 0xba, 0xb7, 0xb9, 0x64, 0x44, 0x5a, 0x66, 0x2e, 0x22, 0xea, + 0xc0, 0x36, 0x17, 0x7e, 0x0c, 0x9d, 0x64, 0xe3, 0x5f, 0x70, 0xb5, 0x54, 0x8c, 0xf4, 0xeb, 0xb2, + 0xe6, 0xc9, 0x7f, 0x39, 0x8e, 0x03, 0xbe, 0x50, 0xec, 0x4a, 0xca, 0xde, 0xf0, 0xe1, 0x00, 0x9a, + 0xf9, 0xc4, 0xdb, 0xbe, 0x4a, 0x79, 0x42, 0xfa, 0xa5, 0x49, 0x93, 0xd7, 0xb6, 0x44, 0x14, 0x29, + 0xf6, 0xf8, 0xae, 0xff, 0x8a, 0x2e, 0xe0, 0x28, 0xa0, 0x8e, 0xa0, 0x8a, 0x4f, 0xe8, 0xab, 0x15, + 0x0d, 0x85, 0xa9, 0xe1, 0x3b, 0xb0, 0x53, 0xc0, 0x65, 0x4b, 0x42, 0x6a, 0x96, 0x9f, 0x1e, 0xfe, + 0x71, 0xd3, 0xd3, 0xde, 0xdc, 0xf4, 0xb4, 0x7f, 0x6e, 0x7a, 0xda, 0x2f, 0xb7, 0xbd, 0xd2, 0x9b, + 0xdb, 0x5e, 0xe9, 0xef, 0xdb, 0x5e, 0xe9, 0xfb, 0xdd, 0x2b, 0x26, 0x66, 0xab, 0xc9, 0x60, 0xca, + 0x17, 0x8f, 0x43, 0xcf, 0x99, 0xce, 0x67, 0xaf, 0x1e, 0x47, 0x92, 0x26, 0x55, 0xf5, 0x7b, 0x75, + 0xf8, 0x5f, 0x00, 0x00, 0x00, 0xff, 0xff, 0x13, 0xbd, 0x10, 0x52, 0x6e, 0x09, 0x00, 0x00, } func (m *NebulaMeta) Marshal() (dAtA []byte, err error) { @@ -1016,7 +1040,17 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { i = encodeVarintNebula(dAtA, i, uint64(size)) } i-- - dAtA[i] = 0x42 + dAtA[i] = 0x4a + } + if m.EnableHostQueryFiltering { + i-- + if m.EnableHostQueryFiltering { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i-- + dAtA[i] = 0x40 } if len(m.RelayVpnAddrs) > 0 { for iNdEx := len(m.RelayVpnAddrs) - 1; iNdEx >= 0; iNdEx-- { @@ -1227,14 +1261,19 @@ func (m *HandshakeFilteringWhitelist) MarshalToSizedBuffer(dAtA []byte) (int, er _ = i var l int _ = l - if m.SetEmpty { + if m.Append { i-- - if m.SetEmpty { + if m.Append { dAtA[i] = 1 } else { dAtA[i] = 0 } i-- + dAtA[i] = 0x40 + } + if m.MessageId != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.MessageId)) + i-- dAtA[i] = 0x38 } if len(m.AllowedCAShas) > 0 { @@ -1601,6 +1640,9 @@ func (m *NebulaMetaDetails) Size() (n int) { n += 1 + l + sovNebula(uint64(l)) } } + if m.EnableHostQueryFiltering { + n += 2 + } if m.HandshakeFilteringWhitelist != nil { l = m.HandshakeFilteringWhitelist.Size() n += 1 + l + sovNebula(uint64(l)) @@ -1698,7 +1740,10 @@ func (m *HandshakeFilteringWhitelist) Size() (n int) { n += 1 + l + sovNebula(uint64(l)) } } - if m.SetEmpty { + if m.MessageId != 0 { + n += 1 + sovNebula(uint64(m.MessageId)) + } + if m.Append { n += 2 } return n @@ -2204,6 +2249,26 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } iNdEx = postIndex case 8: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field EnableHostQueryFiltering", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + m.EnableHostQueryFiltering = bool(v != 0) + case 9: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field HandshakeFilteringWhitelist", wireType) } @@ -2768,7 +2833,26 @@ func (m *HandshakeFilteringWhitelist) Unmarshal(dAtA []byte) error { iNdEx = postIndex case 7: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field SetEmpty", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field MessageId", wireType) + } + m.MessageId = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.MessageId |= uint32(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 8: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Append", wireType) } var v int for shift := uint(0); ; shift += 7 { @@ -2785,7 +2869,7 @@ func (m *HandshakeFilteringWhitelist) Unmarshal(dAtA []byte) error { break } } - m.SetEmpty = bool(v != 0) + m.Append = bool(v != 0) default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) diff --git a/nebula.proto b/nebula.proto index d68d7050d..b82bba9d6 100644 --- a/nebula.proto +++ b/nebula.proto @@ -16,6 +16,8 @@ message NebulaMeta { PathCheck = 8; PathCheckReply = 9; HostUpdateNotificationAck = 10; + HostQueryWhitelist = 11; + HostQueryWhitelistAck = 12; } MessageType Type = 1; @@ -28,7 +30,8 @@ message NebulaMetaDetails { repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true]; repeated Addr RelayVpnAddrs = 7; - HandshakeFilteringWhitelist HandshakeFilteringWhitelist = 8; + bool EnableHostQueryFiltering = 8; + HandshakeFilteringWhitelist HandshakeFilteringWhitelist = 9; repeated V4AddrPort V4AddrPorts = 2; repeated V6AddrPort V6AddrPorts = 4; @@ -58,7 +61,8 @@ message HandshakeFilteringWhitelist { repeated string AllowedCidrs = 4; repeated string AllowedCANames = 5; repeated string AllowedCAShas = 6; - bool SetEmpty = 7; + uint32 MessageId = 7; + bool Append = 8; } message GroupsCombos { diff --git a/overlay/route.go b/overlay/route.go index 687cc11b8..bb4d3e783 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -61,7 +61,7 @@ func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table return routeTree, nil } -func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { +func ParseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.routes") diff --git a/overlay/route_test.go b/overlay/route_test.go index 8f2c094ac..8ee7131e2 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -18,73 +18,73 @@ func Test_parseRoutes(t *testing.T) { require.NoError(t, err) // test no routes config - routes, err := parseRoutes(c, []netip.Prefix{n}) + routes, err := ParseRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Empty(t, routes) // not an array c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} - routes, err = parseRoutes(c, []netip.Prefix{n}) + routes, err = ParseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "tun.routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} - routes, err = parseRoutes(c, []netip.Prefix{n}) + routes, err = ParseRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Empty(t, routes) // weird route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} - routes, err = parseRoutes(c, []netip.Prefix{n}) + routes, err = ParseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1 in tun.routes is invalid") // no mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} - routes, err = parseRoutes(c, []netip.Prefix{n}) + routes, err = ParseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.routes is not present") // bad mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} - routes, err = parseRoutes(c, []netip.Prefix{n}) + routes, err = ParseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} - routes, err = parseRoutes(c, []netip.Prefix{n}) + routes, err = ParseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") // missing route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} - routes, err = parseRoutes(c, []netip.Prefix{n}) + routes, err = ParseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} - routes, err = parseRoutes(c, []netip.Prefix{n}) + routes, err = ParseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // below network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} - routes, err = parseRoutes(c, []netip.Prefix{n}) + routes, err = ParseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") // above network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} - routes, err = parseRoutes(c, []netip.Prefix{n}) + routes, err = ParseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") // Not in multiple ranges c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}} - routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")}) + routes, err = ParseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")}) assert.Nil(t, routes) require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") @@ -93,7 +93,7 @@ func Test_parseRoutes(t *testing.T) { map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"}, map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, }} - routes, err = parseRoutes(c, []netip.Prefix{n}) + routes, err = ParseRoutes(c, []netip.Prefix{n}) require.NoError(t, err) assert.Len(t, routes, 2) diff --git a/overlay/tun.go b/overlay/tun.go index 4a6377d2a..596afe235 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -35,7 +35,7 @@ func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial boo return false, nil, nil } - routes, err := parseRoutes(c, vpnNetworks) + routes, err := ParseRoutes(c, vpnNetworks) if err != nil { return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err) } diff --git a/remote_list.go b/remote_list.go index c8baed9ad..c46b378f2 100644 --- a/remote_list.go +++ b/remote_list.go @@ -445,10 +445,7 @@ func (r *RemoteList) unlockedSetHandshakeFilteringWhitelist(hfwl *HandshakeFilte return } - r.hf = NewHandshakeFilter() - if !hfwl.GetSetEmpty() { - r.hf.UnmarshalFromHfw(hfwl) - } + r.hf.UnmarshalFromHfw(hfwl) } // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner From 4e577dda773dd31c05c1d9ed10943b6590bbc7eb Mon Sep 17 00:00:00 2001 From: Daniel Jampen Date: Sun, 30 Mar 2025 13:31:20 +0200 Subject: [PATCH 6/8] fix HostQueryWhitelist message sent even if enableHostQueryProtection was false --- lighthouse.go | 118 +++++++++++++++++++++++++------------------------- 1 file changed, 60 insertions(+), 58 deletions(-) diff --git a/lighthouse.go b/lighthouse.go index 550890949..9114aa7be 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -1485,80 +1485,82 @@ func (lhh *LightHouseHandler) handleHostUpdateNotificationAck(n *NebulaMeta, fro return } - nb := make([]byte, 12, 12) - out := make([]byte, mtu) + if lhh.lh.enableHostQueryProtection.Load() { + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + + // make sure to send hfwl to all lighthouses if there were changes + if lhh.lh.hf.IsModifiedSinceLastMashalling.Load() { + lighthouses := lhh.lh.GetLighthouses() + for lhVpnAddr := range lighthouses { + hi := lhh.lh.ifce.GetHostInfo(lhVpnAddr) + if hi == nil { + continue + } + + hi.hfwMessagesAckd = make(map[uint8]bool, 0) + } + } - // make sure to send hfwl to all lighthouses if there were changes - if lhh.lh.enableHostQueryProtection.Load() && lhh.lh.hf.IsModifiedSinceLastMashalling.Load() { - lighthouses := lhh.lh.GetLighthouses() - for lhVpnAddr := range lighthouses { + for _, lhVpnAddr := range fromVpnAddrs { hi := lhh.lh.ifce.GetHostInfo(lhVpnAddr) if hi == nil { continue } - hi.hfwMessagesAckd = make(map[uint8]bool, 0) - } - } - - for _, lhVpnAddr := range fromVpnAddrs { - hi := lhh.lh.ifce.GetHostInfo(lhVpnAddr) - if hi == nil { - continue - } - - nonAckedHfwMessageCounter := 0 - for _, value := range hi.hfwMessagesAckd { - if !value { - nonAckedHfwMessageCounter++ + nonAckedHfwMessageCounter := 0 + for _, value := range hi.hfwMessagesAckd { + if !value { + nonAckedHfwMessageCounter++ + } } - } - if len(hi.hfwMessagesAckd) == 0 || nonAckedHfwMessageCounter != 0 { - lhMtu := lhh.lh.GetMTUForAddr(lhVpnAddr) - hfwList := lhh.lh.hf.MarshalToHfwList(lhMtu) + if len(hi.hfwMessagesAckd) == 0 || nonAckedHfwMessageCounter != 0 { + lhMtu := lhh.lh.GetMTUForAddr(lhVpnAddr) + hfwList := lhh.lh.hf.MarshalToHfwList(lhMtu) - for i, hfw := range hfwList { - if _, ok := hi.hfwMessagesAckd[uint8(i)]; ok { - if hi.hfwMessagesAckd[uint8(i)] { - // skip as already ack'd - continue + for i, hfw := range hfwList { + if _, ok := hi.hfwMessagesAckd[uint8(i)]; ok { + if hi.hfwMessagesAckd[uint8(i)] { + // skip as already ack'd + continue + } } - } - msg := NebulaMeta{ - Type: NebulaMeta_HostQueryWhitelist, - Details: &NebulaMetaDetails{ - HandshakeFilteringWhitelist: hfw, - }, - } + msg := NebulaMeta{ + Type: NebulaMeta_HostQueryWhitelist, + Details: &NebulaMetaDetails{ + HandshakeFilteringWhitelist: hfw, + }, + } - msg.Details.HandshakeFilteringWhitelist.MessageId = uint32(i) - - if msg.Details.HandshakeFilteringWhitelist != nil && lhh.lh.l.Level >= logrus.DebugLevel { - lhh.lh.l.WithField("hosts", msg.Details.HandshakeFilteringWhitelist.AllowedHosts). - WithField("groups", msg.Details.HandshakeFilteringWhitelist.AllowedGroups). - WithField("groupcombos", msg.Details.HandshakeFilteringWhitelist.AllowedGroupsCombos). - WithField("cidrs", msg.Details.HandshakeFilteringWhitelist.AllowedCidrs). - WithField("canames", msg.Details.HandshakeFilteringWhitelist.AllowedCANames). - WithField("cashas", msg.Details.HandshakeFilteringWhitelist.AllowedCAShas). - WithField("i", i). - Debug("Sending hfw message") - } - msgSerialized, err := msg.Marshal() + msg.Details.HandshakeFilteringWhitelist.MessageId = uint32(i) + + if msg.Details.HandshakeFilteringWhitelist != nil && lhh.lh.l.Level >= logrus.DebugLevel { + lhh.lh.l.WithField("hosts", msg.Details.HandshakeFilteringWhitelist.AllowedHosts). + WithField("groups", msg.Details.HandshakeFilteringWhitelist.AllowedGroups). + WithField("groupcombos", msg.Details.HandshakeFilteringWhitelist.AllowedGroupsCombos). + WithField("cidrs", msg.Details.HandshakeFilteringWhitelist.AllowedCidrs). + WithField("canames", msg.Details.HandshakeFilteringWhitelist.AllowedCANames). + WithField("cashas", msg.Details.HandshakeFilteringWhitelist.AllowedCAShas). + WithField("i", i). + Debug("Sending hfw message") + } + msgSerialized, err := msg.Marshal() - if err != nil { - lhh.lh.l.WithError(err). - WithField("lighthouseAddr", lhVpnAddr). - Error("Error while marshaling for lighthouse hfw update") - break + if err != nil { + lhh.lh.l.WithError(err). + WithField("lighthouseAddr", lhVpnAddr). + Error("Error while marshaling for lighthouse hfw update") + break + } + + hi.hfwMessagesAckd[uint8(i)] = false + lhh.lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, msgSerialized, nb, out) } - hi.hfwMessagesAckd[uint8(i)] = false - lhh.lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, msgSerialized, nb, out) + return } - - return } } } From fb681dd278d875f8948253f9f4d305f228d7bebe Mon Sep 17 00:00:00 2001 From: Daniel Jampen Date: Sun, 30 Mar 2025 13:33:14 +0200 Subject: [PATCH 7/8] use assert.Len inHandshakeFilter marshalling tests --- firewall_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/firewall_test.go b/firewall_test.go index 4c2046cc7..d0dcec73c 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -702,7 +702,7 @@ func TestHandshakeFilter_Marshalling(t *testing.T) { assert.True(t, hf.IsEmtpy.Load()) hfwl := hf.MarshalToHfwList(1300) assert.False(t, hf.IsModifiedSinceLastMashalling.Load()) - assert.Equal(t, 1, len(hfwl)) + assert.Len(t, hfwl, 1) assert.Empty(t, hfwl[0].AllowedGroups) assert.Empty(t, hfwl[0].AllowedGroupsCombos) assert.Empty(t, hfwl[0].AllowedHosts) @@ -716,7 +716,7 @@ func TestHandshakeFilter_Marshalling(t *testing.T) { hf.AddRule([]string{g}, "", netip.Prefix{}, "", "") assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) hfwl = hf.MarshalToHfwList(1300) - assert.Equal(t, 1, len(hfwl)) + assert.Len(t, hfwl, 1) assert.Contains(t, hfwl[0].AllowedGroups, g) assert.Empty(t, hfwl[0].AllowedGroupsCombos) assert.Empty(t, hfwl[0].AllowedHosts) @@ -741,7 +741,7 @@ func TestHandshakeFilter_Marshalling(t *testing.T) { assert.Len(t, hf.AllowedGroupsCombos, 1) assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) hfwl = hf.MarshalToHfwList(1300) - assert.Equal(t, 1, len(hfwl)) + assert.Len(t, hfwl, 1) assert.Empty(t, hfwl[0].AllowedGroups) assert.Len(t, hfwl[0].AllowedGroupsCombos, 1) for _, g := range gc { From a39d9b36d3bc7f7fa601f001a9526c63b076a280 Mon Sep 17 00:00:00 2001 From: Daniel Jampen Date: Sun, 30 Mar 2025 13:36:51 +0200 Subject: [PATCH 8/8] use assert.Len for all len tests in TestHandshakeFilter_Marshalling --- firewall_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/firewall_test.go b/firewall_test.go index d0dcec73c..093ebcf5a 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -774,7 +774,7 @@ func TestHandshakeFilter_Marshalling(t *testing.T) { hf.AddRule(nil, h, netip.Prefix{}, "", "") assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) hfwl = hf.MarshalToHfwList(1300) - assert.Equal(t, 1, len(hfwl)) + assert.Len(t, hfwl, 1) assert.Empty(t, hfwl[0].AllowedGroups) assert.Empty(t, hfwl[0].AllowedGroupsCombos) assert.Contains(t, hfwl[0].AllowedHosts, h) @@ -798,7 +798,7 @@ func TestHandshakeFilter_Marshalling(t *testing.T) { hf.AddRule(nil, "", p, "", "") assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) hfwl = hf.MarshalToHfwList(1300) - assert.Equal(t, 1, len(hfwl)) + assert.Len(t, hfwl, 1) assert.Empty(t, hfwl[0].AllowedGroups) assert.Empty(t, hfwl[0].AllowedGroupsCombos) assert.Empty(t, hfwl[0].AllowedHosts) @@ -822,7 +822,7 @@ func TestHandshakeFilter_Marshalling(t *testing.T) { hf.AddRule(nil, "", netip.Prefix{}, ca, "") assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) hfwl = hf.MarshalToHfwList(1300) - assert.Equal(t, 1, len(hfwl)) + assert.Len(t, hfwl, 1) assert.Empty(t, hfwl[0].AllowedGroups) assert.Empty(t, hfwl[0].AllowedGroupsCombos) assert.Empty(t, hfwl[0].AllowedHosts) @@ -846,7 +846,7 @@ func TestHandshakeFilter_Marshalling(t *testing.T) { hf.AddRule(nil, "", netip.Prefix{}, "", fp) assert.True(t, hf.IsModifiedSinceLastMashalling.Load()) hfwl = hf.MarshalToHfwList(1300) - assert.Equal(t, 1, len(hfwl)) + assert.Len(t, hfwl, 1) assert.Empty(t, hfwl[0].AllowedGroups) assert.Empty(t, hfwl[0].AllowedGroupsCombos) assert.Empty(t, hfwl[0].AllowedHosts)