Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion examples/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,21 @@ lighthouse:
hosts:
- "192.168.100.1"

# This feature allows handshakes only from hosts whose hostname, groups, or IP address are specified in
# any inbound firewall rule. By doing so, it prevents unauthorized nodes from establishing a tunnel, conserving
# resources on the host and adding an additional security layer that reduces the potential attack surface.
# - If enabled on a lighthouse, this will prevent nodes from accessing lighthouse features unless there is an
# incoming rule allowing access. Use allow rules with `port: nebula` to grant access to nodes.
# - Similar considerations should be made before enabling the feature on relay nodes, as it can restrict
# access to certain nodes only. Both ends of the relayed tunnel must be allowed to communicate with this node
# for relaying to function.
#incoming_handshake_filtering: false

# This setting on a lighthouse determines whether to enforce the host query protection
# whitelist received from a node. On a node, this setting controls whether the node
# sends its handshake filtering whitelist to the lighthouses at all, not reloadable.
#enable_host_query_protection: false

# remote_allow_list allows you to control ip ranges that this node will
# consider when handshaking to another node. By default, any remote IPs are
# allowed. You can provide CIDRs here with `true` to allow and `false` to
Expand Down Expand Up @@ -340,7 +355,8 @@ firewall:
# The firewall is default deny. There is no way to write a deny rule.
# Rules are comprised of a protocol, port, and one or more of host, group, or CIDR
# Logical evaluation is roughly: port AND proto AND (ca_sha OR ca_name) AND (host OR group OR groups OR cidr) AND (local cidr)
# - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, or `fragment` to match second and further fragments of fragmented packets (since there is no port available).
# - port: Takes `0` or `any` as any, a single number `80`, a range `200-901`, `fragment` to match second and further fragments of fragmented
# packets (since there is no port available) or `nebula` to create nebula access rules (udp) for the handshake filtering feature.
# code: same as port but makes more sense when talking about ICMP, TODO: this is not currently implemented in a way that works, use `any`
# proto: `any`, `tcp`, `udp`, or `icmp`
# host: `any` or a literal hostname, ie `test-host`
Expand Down
296 changes: 289 additions & 7 deletions firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/gaissmai/bart"
Expand Down Expand Up @@ -82,6 +83,18 @@ type FirewallConntrack struct {
TimerWheel *TimerWheel[firewall.Packet]
}

type HandshakeFilter struct {
AllowedHosts map[string]struct{}
AllowedGroups map[string]struct{}
AllowedGroupsCombos []map[string]struct{}
AllowedCidrs []netip.Prefix
AllowedCANames map[string]struct{}
AllowedCAShas map[string]struct{}

IsEmtpy atomic.Bool
IsModifiedSinceLastMashalling atomic.Bool
}

// FirewallTable is the entry point for a rule, the evaluation order is:
// Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR)
type FirewallTable struct {
Expand Down Expand Up @@ -190,7 +203,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
}
}

func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, *HandshakeFilter, error) {
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
certificate = cs.getCertificate(cert.Version1)
Expand All @@ -209,6 +222,8 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
//TODO: max_connections
)

hf := NewHandshakeFilter()

fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)

inboundAction := c.GetString("firewall.inbound_action", "drop")
Expand All @@ -233,17 +248,17 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
fw.OutSendReject = false
}

err := AddFirewallRulesFromConfig(l, false, c, fw)
err := AddFirewallRulesFromConfig(l, false, c, fw, nil)
if err != nil {
return nil, err
return nil, nil, err
}

err = AddFirewallRulesFromConfig(l, true, c, fw)
err = AddFirewallRulesFromConfig(l, true, c, fw, hf)
if err != nil {
return nil, err
return nil, nil, err
}

return fw, nil
return fw, hf, nil
}

// AddRule properly creates the in memory rule structure for a firewall table.
Expand Down Expand Up @@ -318,7 +333,7 @@ func (f *Firewall) GetRuleHashes() string {
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
}

func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface, hf *HandshakeFilter) error {
var table string
if inbound {
table = "firewall.inbound"
Expand Down Expand Up @@ -412,6 +427,10 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw
if err != nil {
return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
}

if hf != nil {
hf.AddRule(groups, r.Host, cidr, r.CAName, r.CASha)
}
}

return nil
Expand Down Expand Up @@ -981,6 +1000,10 @@ func parsePort(s string) (startPort, endPort int32, err error) {
startPort = firewall.PortFragment
endPort = firewall.PortFragment

} else if s == "nebula" {
startPort = firewall.PortNebula
endPort = firewall.PortNebula

} else if strings.Contains(s, `-`) {
sPorts := strings.SplitN(s, `-`, 2)
sPorts[0] = strings.Trim(sPorts[0], " ")
Expand Down Expand Up @@ -1018,3 +1041,262 @@ func parsePort(s string) (startPort, endPort int32, err error) {

return
}

func NewHandshakeFilter() *HandshakeFilter {
hf := &HandshakeFilter{
AllowedHosts: make(map[string]struct{}),
AllowedGroups: make(map[string]struct{}),
AllowedGroupsCombos: make([]map[string]struct{}, 0),
AllowedCidrs: make([]netip.Prefix, 0),
AllowedCANames: make(map[string]struct{}),
AllowedCAShas: make(map[string]struct{}),
}

hf.IsModifiedSinceLastMashalling.Store(false)
hf.IsEmtpy.Store(true)

return hf
}

func (hfws *HandshakeFilter) AddRule(groups []string, host string, localIp netip.Prefix, CAName string, CASha string) {
ruleAdded := false
if host != "" {
hfws.AllowedHosts[host] = struct{}{}
ruleAdded = true
}

if len(groups) > 1 {
gs := make(map[string]struct{}, len(groups))
for i := range groups {
gs[groups[i]] = struct{}{}
}

if !containsMap(hfws.AllowedGroupsCombos, gs) {
hfws.AllowedGroupsCombos = append(
hfws.AllowedGroupsCombos,
gs,
)
}
ruleAdded = true
} else if len(groups) == 1 {
hfws.AllowedGroups[groups[0]] = struct{}{}
ruleAdded = true
}

if localIp.IsValid() {
hfws.AllowedCidrs = append(
hfws.AllowedCidrs,
localIp,
)
ruleAdded = true
}

if CAName != "" {
hfws.AllowedCANames[CAName] = struct{}{}
ruleAdded = true
}

if CASha != "" {
hfws.AllowedCAShas[CASha] = struct{}{}
ruleAdded = true
}

hfws.IsModifiedSinceLastMashalling.Store(ruleAdded)

if ruleAdded {
hfws.IsEmtpy.Store(false)
}
}

func (hfws *HandshakeFilter) IsHandshakeAllowed(groups []string, host string, vpnAddrs []netip.Addr, CAName string, CASha string) bool {
if _, ok := hfws.AllowedHosts["any"]; ok {
return true
}
if _, ok := hfws.AllowedHosts[host]; ok {
return true
}

if _, ok := hfws.AllowedCANames[CAName]; ok {
return true
}

if _, ok := hfws.AllowedCAShas[CASha]; ok {
return true
}

if len(groups) != 0 {
if _, ok := hfws.AllowedGroups["any"]; ok {
return true
}
}
for _, g := range groups {
if _, ok := hfws.AllowedGroups[g]; ok {
return true
}
}

for _, c := range hfws.AllowedCidrs {
for _, a := range vpnAddrs {
if c.Contains(a) {
return true
}
}
}

for _, gc := range hfws.AllowedGroupsCombos {
if len(groups) < len(gc) {
continue
}

if isSubset(gc, groups) {
return true
}
}

return false
}

func (hfws *HandshakeFilter) MarshalToHfwList(maxMessageSize int) []*HandshakeFilteringWhitelist {
hfwList := make([]*HandshakeFilteringWhitelist, 0)
maxMessageSize = maxMessageSize - 10 // account for potential overhead

appendOldAndGetNewHfw := func(old *HandshakeFilteringWhitelist) *HandshakeFilteringWhitelist {
if old != nil {
hfwList = append(hfwList, old)
}
return &HandshakeFilteringWhitelist{
AllowedHosts: make([]string, 0),
AllowedGroups: make([]string, 0),
AllowedGroupsCombos: make([]*GroupsCombos, 0),
AllowedCidrs: make([]string, 0),
AllowedCANames: make([]string, 0),
AllowedCAShas: make([]string, 0),
Append: old != nil,
}
}

type appendToHfw func(hfw *HandshakeFilteringWhitelist, e string)
addIteratableToHfw := func(hfw *HandshakeFilteringWhitelist, e string, appendToHfwFunc appendToHfw) *HandshakeFilteringWhitelist {
if hfw.Size()+len(e) > maxMessageSize {
hfw = appendOldAndGetNewHfw(hfw)
}
appendToHfwFunc(hfw, e)
return hfw
}

hfw := appendOldAndGetNewHfw(nil)

for e := range hfws.AllowedHosts {
hfw = addIteratableToHfw(hfw, e, func(h *HandshakeFilteringWhitelist, e string) {
h.AllowedHosts = append(h.AllowedHosts, e)
})
}

for e := range hfws.AllowedGroups {
hfw = addIteratableToHfw(hfw, e, func(h *HandshakeFilteringWhitelist, e string) {
h.AllowedGroups = append(h.AllowedGroups, e)
})
}

for _, groupCombo := range hfws.AllowedGroupsCombos {
gc := &GroupsCombos{
Group: make([]string, len(groupCombo)),
}

j := 0
groupBytes := 0
for group := range groupCombo {
gc.Group[j] = group
groupBytes += len(group)
j += 1
}

if hfw.Size()+groupBytes > maxMessageSize {
hfw = appendOldAndGetNewHfw(hfw)
}
hfw.AllowedGroupsCombos = append(hfw.AllowedGroupsCombos, gc)
}

for _, e := range hfws.AllowedCidrs {
hfw = addIteratableToHfw(hfw, e.String(), func(h *HandshakeFilteringWhitelist, e string) {
h.AllowedCidrs = append(h.AllowedCidrs, e)
})
}

for e := range hfws.AllowedCANames {
hfw = addIteratableToHfw(hfw, e, func(h *HandshakeFilteringWhitelist, e string) {
h.AllowedCANames = append(h.AllowedCANames, e)
})
}

for e := range hfws.AllowedCAShas {
hfw = addIteratableToHfw(hfw, e, func(h *HandshakeFilteringWhitelist, e string) {
h.AllowedCAShas = append(h.AllowedCAShas, e)
})
}

hfws.IsModifiedSinceLastMashalling.Store(false)

hfwList = append(hfwList, hfw)
return hfwList
}

func (hfws *HandshakeFilter) UnmarshalFromHfw(hfw *HandshakeFilteringWhitelist) *HandshakeFilter {
if hfw == nil {
return hfws
}

if !hfw.Append {
hfws = NewHandshakeFilter()
}

for _, h := range hfw.AllowedHosts {
hfws.AddRule(nil, h, netip.Prefix{}, "", "")
}

for _, g := range hfw.AllowedGroups {
hfws.AddRule([]string{g}, "", netip.Prefix{}, "", "")
}

for _, gc := range hfw.AllowedGroupsCombos {
hfws.AddRule(gc.Group, "", netip.Prefix{}, "", "")
}

for _, cs := range hfw.AllowedCidrs {
c, err := netip.ParsePrefix(cs)
if err != nil {
continue
}
hfws.AddRule(nil, "", c, "", "")
}

for _, ca := range hfw.AllowedCANames {
hfws.AddRule(nil, "", netip.Prefix{}, ca, "")
}

for _, sha := range hfw.AllowedCAShas {
hfws.AddRule(nil, "", netip.Prefix{}, "", sha)
}

return hfws
}

func isSubset(subset map[string]struct{}, superset []string) bool {
ls := len(subset)
s := make(map[string]struct{}, ls)
for _, value := range superset {
if _, ok := subset[value]; ok {
s[value] = struct{}{}
}
}
return len(s) == ls
}

func containsMap(slice []map[string]struct{}, target map[string]struct{}) bool {
for _, m := range slice {
if reflect.DeepEqual(m, target) {
return true
}
}
return false
}
Loading
Loading