diff --git a/conntrack_linux.go b/conntrack_linux.go index ff20869b..43fed108 100644 --- a/conntrack_linux.go +++ b/conntrack_linux.go @@ -71,6 +71,12 @@ func ConntrackUpdate(table ConntrackTableType, family InetFamily, flow *Conntrac return pkgHandle.ConntrackUpdate(table, family, flow) } +// ConntrackDelete deletes an existing conntrack flow in the desired table using the handle +// conntrack -D [table] Delete conntrack flow +func ConntrackDelete(table ConntrackTableType, family InetFamily, flow *ConntrackFlow) error { + return pkgHandle.ConntrackDelete(table, family, flow) +} + // ConntrackDeleteFilter deletes entries on the specified table on the base of the filter // conntrack -D [table] parameters Delete conntrack or expectation // @@ -148,6 +154,23 @@ func (h *Handle) ConntrackUpdate(table ConntrackTableType, family InetFamily, fl return err } +// ConntrackDelete deletes an existing conntrack flow in the desired table using the handle +// conntrack -D [table] Delete a conntrack +func (h *Handle) ConntrackDelete(table ConntrackTableType, family InetFamily, flow *ConntrackFlow) error { + req := h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_DELETE, unix.NLM_F_ACK) + attr, err := flow.toNlData() + if err != nil { + return err + } + + for _, a := range attr { + req.AddData(a) + } + + _, err = req.Execute(unix.NETLINK_NETFILTER, 0) + return err +} + // ConntrackDeleteFilter deletes entries on the specified table on the base of the filter using the netlink handle passed // conntrack -D [table] parameters Delete conntrack or expectation // diff --git a/conntrack_test.go b/conntrack_test.go index 48e5c4a1..de850e17 100644 --- a/conntrack_test.go +++ b/conntrack_test.go @@ -1511,6 +1511,226 @@ func TestConntrackCreateV6(t *testing.T) { checkProtoInfosEqual(t, flow.ProtoInfo, match.ProtoInfo) } +// TestConntrackDeleteV4 creates an IPv4 conntrack entry, verifies it exists, +// deletes it via the package-level wrapper ConntrackDelete (which uses pkgHandle), +// and verifies it was removed. +func TestConntrackDeleteV4(t *testing.T) { + // Print timestamps in UTC + os.Setenv("TZ", "") + + requiredModules := []string{"nf_conntrack", "nf_conntrack_netlink"} + k, m, err := KernelVersion() + if err != nil { + t.Fatal(err) + } + // Conntrack l3proto was unified since 4.19 + // https://github.com/torvalds/linux/commit/a0ae2562c6c4b2721d9fddba63b7286c13517d9f + if k < 4 || k == 4 && m < 19 { + requiredModules = append(requiredModules, "nf_conntrack_ipv4") + } + // Implicitly skips test if not root: + nsStr, teardown := setUpNamedNetlinkTestWithKModule(t, requiredModules...) + t.Cleanup(teardown) + + ns, err := netns.GetFromName(nsStr) + if err != nil { + t.Fatalf("couldn't get handle to generated namespace: %s", err) + } + + h, err := NewHandleAt(ns, nl.FAMILY_V4) + if err != nil { + t.Fatalf("failed to create netlink handle: %s", err) + } + + // Point pkgHandle to the namespaced handle so the package-level wrapper acts in this ns. + orig := pkgHandle + pkgHandle = h + defer func() { pkgHandle = orig }() + + flow := ConntrackFlow{ + FamilyType: FAMILY_V4, + Forward: IPTuple{ + SrcIP: net.IP{234, 234, 234, 234}, + DstIP: net.IP{123, 123, 123, 123}, + SrcPort: 48385, + DstPort: 53, + Protocol: unix.IPPROTO_TCP, + }, + Reverse: IPTuple{ + SrcIP: net.IP{123, 123, 123, 123}, + DstIP: net.IP{234, 234, 234, 234}, + SrcPort: 53, + DstPort: 48385, + Protocol: unix.IPPROTO_TCP, + }, + TimeOut: 100, + Mark: 12, + ProtoInfo: &ProtoInfoTCP{ + State: nl.TCP_CONNTRACK_ESTABLISHED, + }, + } + + // Create the entry using the handle + if err := h.ConntrackCreate(ConntrackTable, nl.FAMILY_V4, &flow); err != nil { + t.Fatalf("failed to insert conntrack: %s", err) + } + + // Verify it exists + flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) + if err != nil { + t.Fatalf("failed to list conntracks following successful insert: %s", err) + } + filter := ConntrackFilter{ + ipNetFilter: map[ConntrackFilterType]*net.IPNet{ + ConntrackOrigSrcIP: NewIPNet(flow.Forward.SrcIP), + ConntrackOrigDstIP: NewIPNet(flow.Forward.DstIP), + ConntrackReplySrcIP: NewIPNet(flow.Reverse.SrcIP), + ConntrackReplyDstIP: NewIPNet(flow.Reverse.DstIP), + }, + portFilter: map[ConntrackFilterType]uint16{ + ConntrackOrigSrcPort: flow.Forward.SrcPort, + ConntrackOrigDstPort: flow.Forward.DstPort, + }, + protoFilter: unix.IPPROTO_TCP, + } + var match *ConntrackFlow + for _, f := range flows { + if filter.MatchConntrackFlow(f) { + match = f + break + } + } + if match == nil { + t.Fatalf("didn't find any matching conntrack entries for original flow: %+v\n Filter used: %+v", flow, filter) + } + + // Delete using the handler + if err := h.ConntrackDelete(ConntrackTable, InetFamily(nl.FAMILY_V4), &flow); err != nil { + t.Fatalf("failed to delete conntrack via handler: %s", err) + } + + // Verify it's gone + flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4) + if err != nil { + t.Fatalf("failed to list conntracks following delete: %s", err) + } + for _, f := range flows { + if filter.MatchConntrackFlow(f) { + t.Fatalf("found flow after delete: %+v", f) + } + } +} + +// TestConntrackDeleteV6 creates an IPv6 conntrack entry, verifies it exists, +// deletes it via the package-level wrapper ConntrackDelete (which uses pkgHandle), +// and verifies it was removed. +func TestConntrackDeleteV6(t *testing.T) { + // Print timestamps in UTC + os.Setenv("TZ", "") + + requiredModules := []string{"nf_conntrack", "nf_conntrack_netlink"} + k, m, err := KernelVersion() + if err != nil { + t.Fatal(err) + } + // Conntrack l3proto was unified since 4.19 + // https://github.com/torvalds/linux/commit/a0ae2562c6c4b2721d9fddba63b7286c13517d9f + if k < 4 || k == 4 && m < 19 { + requiredModules = append(requiredModules, "nf_conntrack_ipv4") + } + // Implicitly skips test if not root: + nsStr, teardown := setUpNamedNetlinkTestWithKModule(t, requiredModules...) + t.Cleanup(teardown) + + ns, err := netns.GetFromName(nsStr) + if err != nil { + t.Fatalf("couldn't get handle to generated namespace: %s", err) + } + + h, err := NewHandleAt(ns, nl.FAMILY_V6) + if err != nil { + t.Fatalf("failed to create netlink handle: %s", err) + } + + // Point pkgHandle to the namespaced handle so the package-level wrapper acts in this ns. + orig := pkgHandle + pkgHandle = h + defer func() { pkgHandle = orig }() + + flow := ConntrackFlow{ + FamilyType: FAMILY_V6, + Forward: IPTuple{ + SrcIP: net.ParseIP("2001:db8::68"), + DstIP: net.ParseIP("2001:db9::32"), + SrcPort: 48385, + DstPort: 53, + Protocol: unix.IPPROTO_TCP, + }, + Reverse: IPTuple{ + SrcIP: net.ParseIP("2001:db9::32"), + DstIP: net.ParseIP("2001:db8::68"), + SrcPort: 53, + DstPort: 48385, + Protocol: unix.IPPROTO_TCP, + }, + TimeOut: 100, + Mark: 12, + ProtoInfo: &ProtoInfoTCP{ + State: nl.TCP_CONNTRACK_ESTABLISHED, + }, + } + + // Create the entry using the handle + if err := h.ConntrackCreate(ConntrackTable, nl.FAMILY_V6, &flow); err != nil { + t.Fatalf("failed to insert conntrack: %s", err) + } + + // Verify it exists + flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6) + if err != nil { + t.Fatalf("failed to list conntracks following successful insert: %s", err) + } + filter := ConntrackFilter{ + ipNetFilter: map[ConntrackFilterType]*net.IPNet{ + ConntrackOrigSrcIP: NewIPNet(flow.Forward.SrcIP), + ConntrackOrigDstIP: NewIPNet(flow.Forward.DstIP), + ConntrackReplySrcIP: NewIPNet(flow.Reverse.SrcIP), + ConntrackReplyDstIP: NewIPNet(flow.Reverse.DstIP), + }, + portFilter: map[ConntrackFilterType]uint16{ + ConntrackOrigSrcPort: flow.Forward.SrcPort, + ConntrackOrigDstPort: flow.Forward.DstPort, + }, + protoFilter: unix.IPPROTO_TCP, + } + var match *ConntrackFlow + for _, f := range flows { + if filter.MatchConntrackFlow(f) { + match = f + break + } + } + if match == nil { + t.Fatalf("didn't find any matching conntrack entries for original flow: %+v\n Filter used: %+v", flow, filter) + } + + // Delete using the handler + if err := h.ConntrackDelete(ConntrackTable, InetFamily(nl.FAMILY_V6), &flow); err != nil { + t.Fatalf("failed to delete conntrack via handler: %s", err) + } + + // Verify it's gone + flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6) + if err != nil { + t.Fatalf("failed to list conntracks following delete: %s", err) + } + for _, f := range flows { + if filter.MatchConntrackFlow(f) { + t.Fatalf("found flow after delete: %+v", f) + } + } +} + // TestConntrackLabels test the conntrack table labels // Creates some flows and then checks the labels associated func TestConntrackLabels(t *testing.T) {