diff --git a/Makefile b/Makefile index 80b0ab0..c104057 100644 --- a/Makefile +++ b/Makefile @@ -1,17 +1,17 @@ bin: govendor sync - go build + go build cmd/*.go test: govendor sync - go test -v + go test -v ./... test-cov-html: go test -coverprofile=coverage.out go tool cover -html=coverage.out bench: - go test -bench=. + go test -bench=. ./... bench-cpu: go test -bench=. -benchtime=5s -cpuprofile=cpu.pprof diff --git a/client.go b/client.go index 7538cb1..d83e919 100644 --- a/client.go +++ b/client.go @@ -1,13 +1,13 @@ -package main +package audit import ( "bytes" "encoding/binary" "errors" + "fmt" "sync/atomic" "syscall" "time" - "fmt" ) // Endianness is an alias for what we assume is the current machine endianness @@ -63,13 +63,13 @@ func NewNetlinkClient(recvSize int) (*NetlinkClient, error) { // Set the buffer size if we were asked if recvSize > 0 { if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, recvSize); err != nil { - el.Println("Failed to set receive buffer size") + Stderr.Println("Failed to set receive buffer size") } } // Print the current receive buffer size if v, err := syscall.GetsockoptInt(n.fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF); err == nil { - l.Println("Socket receive buffer size:", v) + Std.Println("Socket receive buffer size:", v) } go func() { @@ -151,6 +151,6 @@ func (n *NetlinkClient) KeepConnection() { err := n.Send(packet, payload) if err != nil { - el.Println("Error occurred while trying to keep the connection:", err) + Stderr.Println("Error occurred while trying to keep the connection:", err) } } diff --git a/client_test.go b/client_test.go index 440fd23..c87b7c3 100644 --- a/client_test.go +++ b/client_test.go @@ -1,12 +1,13 @@ -package main +package audit import ( - "bytes" "encoding/binary" - "github.com/stretchr/testify/assert" "os" "syscall" "testing" + + "github.com/slackhq/go-audit/internal/test" + "github.com/stretchr/testify/assert" ) func TestNetlinkClient_KeepConnection(t *testing.T) { @@ -29,8 +30,8 @@ func TestNetlinkClient_KeepConnection(t *testing.T) { assert.EqualValues(t, msg.Data[:40], expectedData, "data was wrong") // Make sure we get errors printed - lb, elb := hookLogger() - defer resetLogger() + lb, elb := test.HookLogger(Std, Stderr) + defer test.ResetLogger(Std, Stderr) syscall.Close(n.fd) n.KeepConnection() assert.Equal(t, "", lb.String(), "Got some log lines we did not expect") @@ -88,8 +89,8 @@ func TestNetlinkClient_SendReceive(t *testing.T) { } func TestNewNetlinkClient(t *testing.T) { - lb, elb := hookLogger() - defer resetLogger() + lb, elb := test.HookLogger(Std, Stderr) + defer test.ResetLogger(Std, Stderr) n, err := NewNetlinkClient(1024) @@ -143,19 +144,3 @@ func sendReceive(t *testing.T, n *NetlinkClient, packet *NetlinkPacket, payload return msg } - -// Resets global loggers -func resetLogger() { - l.SetOutput(os.Stdout) - el.SetOutput(os.Stderr) -} - -// Hooks the global loggers writers so you can assert their contents -func hookLogger() (lb *bytes.Buffer, elb *bytes.Buffer) { - lb = &bytes.Buffer{} - l.SetOutput(lb) - - elb = &bytes.Buffer{} - el.SetOutput(elb) - return -} diff --git a/audit.go b/cmd/audit.go similarity index 77% rename from audit.go rename to cmd/audit.go index 649cb4d..fe42333 100644 --- a/audit.go +++ b/cmd/audit.go @@ -4,7 +4,6 @@ import ( "errors" "flag" "fmt" - "log" "log/syslog" "os" "os/exec" @@ -15,12 +14,10 @@ import ( "strings" "syscall" + "github.com/slackhq/go-audit" "github.com/spf13/viper" ) -var l = log.New(os.Stdout, "", 0) -var el = log.New(os.Stderr, "", 0) - type executor func(string, ...string) error func lExec(s string, a ...string) error { @@ -46,8 +43,8 @@ func loadConfig(configFile string) (*viper.Viper, error) { return nil, err } - l.SetFlags(config.GetInt("log.flags")) - el.SetFlags(config.GetInt("log.flags")) + audit.Std.SetFlags(config.GetInt("log.flags")) + audit.Stderr.SetFlags(config.GetInt("log.flags")) return config, nil } @@ -58,7 +55,7 @@ func setRules(config *viper.Viper, e executor) error { return fmt.Errorf("Failed to flush existing audit rules. Error: %s", err) } - l.Println("Flushed existing audit rules") + audit.Std.Println("Flushed existing audit rules") // Add ours in if rules := config.GetStringSlice("rules"); len(rules) != 0 { @@ -72,7 +69,7 @@ func setRules(config *viper.Viper, e executor) error { return fmt.Errorf("Failed to add rule #%d. Error: %s", i+1, err) } - l.Printf("Added audit rule #%d\n", i+1) + audit.Std.Printf("Added audit rule #%d\n", i+1) } } else { return errors.New("No audit rules found") @@ -81,8 +78,8 @@ func setRules(config *viper.Viper, e executor) error { return nil } -func createOutput(config *viper.Viper) (*AuditWriter, error) { - var writer *AuditWriter +func createOutput(config *viper.Viper) (*audit.JSONAuditWriter, error) { + var writer *audit.JSONAuditWriter var err error i := 0 @@ -123,7 +120,7 @@ func createOutput(config *viper.Viper) (*AuditWriter, error) { return writer, nil } -func createSyslogOutput(config *viper.Viper) (*AuditWriter, error) { +func createSyslogOutput(config *viper.Viper) (*audit.JSONAuditWriter, error) { attempts := config.GetInt("output.syslog.attempts") if attempts < 1 { return nil, fmt.Errorf("Output attempts for syslog must be at least 1, %v provided", attempts) @@ -140,10 +137,10 @@ func createSyslogOutput(config *viper.Viper) (*AuditWriter, error) { return nil, fmt.Errorf("Failed to open syslog writer. Error: %v", err) } - return NewAuditWriter(syslogWriter, attempts), nil + return audit.NewAuditWriter(syslogWriter, attempts), nil } -func createFileOutput(config *viper.Viper) (*AuditWriter, error) { +func createFileOutput(config *viper.Viper) (*audit.JSONAuditWriter, error) { attempts := config.GetInt("output.file.attempts") if attempts < 1 { return nil, fmt.Errorf("Output attempts for file must be at least 1, %v provided", attempts) @@ -193,10 +190,10 @@ func createFileOutput(config *viper.Viper) (*AuditWriter, error) { return nil, fmt.Errorf("Could not chown output file. Error: %s", err) } - return NewAuditWriter(f, attempts), nil + return audit.NewAuditWriter(f, attempts), nil } -func handleLogRotation(config *viper.Viper, writer *AuditWriter) { +func handleLogRotation(config *viper.Viper, writer *audit.JSONAuditWriter) { // Re-open our log file. This is triggered by a USR1 signal and is meant to be used upon log rotation sigc := make(chan os.Signal, 1) @@ -205,38 +202,37 @@ func handleLogRotation(config *viper.Viper, writer *AuditWriter) { for range sigc { newWriter, err := createFileOutput(config) if err != nil { - el.Fatalln("Error re-opening log file. Exiting.") + audit.Stderr.Fatalln("Error re-opening log file. Exiting.") } - oldFile := writer.w.(*os.File) - writer.w = newWriter.w - writer.e = newWriter.e + oldFile := writer.IOWriter().(*os.File) + writer.SetIOWriter(newWriter.IOWriter()) err = oldFile.Close() if err != nil { - el.Printf("Error closing old log file: %+v\n", err) + audit.Stderr.Printf("Error closing old log file: %+v\n", err) } } } -func createStdOutOutput(config *viper.Viper) (*AuditWriter, error) { +func createStdOutOutput(config *viper.Viper) (*audit.JSONAuditWriter, error) { attempts := config.GetInt("output.stdout.attempts") if attempts < 1 { return nil, fmt.Errorf("Output attempts for stdout must be at least 1, %v provided", attempts) } // l logger is no longer stdout - l.SetOutput(os.Stderr) + audit.Std.SetOutput(os.Stderr) - return NewAuditWriter(os.Stdout, attempts), nil + return audit.NewAuditWriter(os.Stdout, attempts), nil } -func createFilters(config *viper.Viper) ([]AuditFilter, error) { +func createFilters(config *viper.Viper) ([]audit.AuditFilter, error) { var err error var ok bool fs := config.Get("filters") - filters := []AuditFilter{} + filters := []audit.AuditFilter{} if fs == nil { return filters, nil @@ -253,7 +249,7 @@ func createFilters(config *viper.Viper) ([]AuditFilter, error) { return filters, fmt.Errorf("Could not parse filter %d; '%+v'", i+1, f) } - af := AuditFilter{} + af := audit.AuditFilter{} for k, v := range f2 { switch k { case "message_type": @@ -262,10 +258,10 @@ func createFilters(config *viper.Viper) ([]AuditFilter, error) { if err != nil { return filters, fmt.Errorf("`message_type` in filter %d could not be parsed; Value: `%+v`; Error: %s", i+1, v, err) } - af.messageType = uint16(fv) + af.MessageType = uint16(fv) } else if ev, ok := v.(int); ok { - af.messageType = uint16(ev) + af.MessageType = uint16(ev) } else { return filters, fmt.Errorf("`message_type` in filter %d could not be parsed; Value: `%+v`", i+1, v) @@ -277,31 +273,31 @@ func createFilters(config *viper.Viper) ([]AuditFilter, error) { return filters, fmt.Errorf("`regex` in filter %d could not be parsed; Value: `%+v`", i+1, v) } - if af.regex, err = regexp.Compile(re); err != nil { + if af.Regex, err = regexp.Compile(re); err != nil { return filters, fmt.Errorf("`regex` in filter %d could not be parsed; Value: `%+v`; Error: %s", i+1, v, err) } case "syscall": - if af.syscall, ok = v.(string); ok { + if af.Syscall, ok = v.(string); ok { // All is good } else if ev, ok := v.(int); ok { - af.syscall = strconv.Itoa(ev) + af.Syscall = strconv.Itoa(ev) } else { return filters, fmt.Errorf("`syscall` in filter %d could not be parsed; Value: `%+v`", i+1, v) } } } - if af.regex == nil { + if af.Regex == nil { return filters, fmt.Errorf("Filter %d is missing the `regex` entry", i+1) } - if af.messageType == 0 { + if af.MessageType == 0 { return filters, fmt.Errorf("Filter %d is missing the `message_type` entry", i+1) } filters = append(filters, af) - l.Printf("Ignoring syscall `%v` containing message type `%v` matching string `%s`\n", af.syscall, af.messageType, af.regex.String()) + audit.Std.Printf("Ignoring syscall `%v` containing message type `%v` matching string `%s`\n", af.Syscall, af.MessageType, af.Regex.String()) } return filters, nil @@ -313,37 +309,37 @@ func main() { flag.Parse() if *configFile == "" { - el.Println("A config file must be provided") + audit.Stderr.Println("A config file must be provided") flag.Usage() os.Exit(1) } config, err := loadConfig(*configFile) if err != nil { - el.Fatal(err) + audit.Stderr.Fatal(err) } // output needs to be created before anything that write to stdout writer, err := createOutput(config) if err != nil { - el.Fatal(err) + audit.Stderr.Fatal(err) } if err := setRules(config, lExec); err != nil { - el.Fatal(err) + audit.Stderr.Fatal(err) } filters, err := createFilters(config) if err != nil { - el.Fatal(err) + audit.Stderr.Fatal(err) } - nlClient, err := NewNetlinkClient(config.GetInt("socket_buffer.receive")) + nlClient, err := audit.NewNetlinkClient(config.GetInt("socket_buffer.receive")) if err != nil { - el.Fatal(err) + audit.Stderr.Fatal(err) } - marshaller := NewAuditMarshaller( + marshaller := audit.NewAuditMarshaller( writer, uint16(config.GetInt("events.min")), uint16(config.GetInt("events.max")), @@ -353,13 +349,13 @@ func main() { filters, ) - l.Printf("Started processing events in the range [%d, %d]\n", config.GetInt("events.min"), config.GetInt("events.max")) + audit.Std.Printf("Started processing events in the range [%d, %d]\n", config.GetInt("events.min"), config.GetInt("events.max")) //Main loop. Get data from netlink and send it to the json lib for processing for { msg, err := nlClient.Receive() if err != nil { - el.Printf("Error during message receive: %+v\n", err) + audit.Stderr.Printf("Error during message receive: %+v\n", err) continue } diff --git a/audit_test.go b/cmd/audit_test.go similarity index 93% rename from audit_test.go rename to cmd/audit_test.go index c6000e5..493d742 100644 --- a/audit_test.go +++ b/cmd/audit_test.go @@ -13,6 +13,8 @@ import ( "testing" "time" + "github.com/slackhq/go-audit" + "github.com/slackhq/go-audit/internal/test" "github.com/spf13/viper" "github.com/stretchr/testify/assert" ) @@ -33,8 +35,8 @@ func Test_loadConfig(t *testing.T) { assert.Equal(t, "go-audit", config.GetString("output.syslog.tag"), "output.syslog.tag should default to go-audit") assert.Equal(t, 3, config.GetInt("output.syslog.attempts"), "output.syslog.attempts should default to 3") assert.Equal(t, 0, config.GetInt("log.flags"), "log.flags should default to 0") - assert.Equal(t, 0, l.Flags(), "stdout log flags was wrong") - assert.Equal(t, 0, el.Flags(), "stderr log flags was wrong") + assert.Equal(t, 0, audit.Std.Flags(), "stdout log flags was wrong") + assert.Equal(t, 0, audit.Stderr.Flags(), "stderr log flags was wrong") assert.Nil(t, err) // parse error @@ -45,7 +47,7 @@ func Test_loadConfig(t *testing.T) { } func Test_setRules(t *testing.T) { - defer resetLogger() + defer test.ResetLogger(audit.Std, audit.Stderr) // fail to flush rules config := viper.New() @@ -174,7 +176,7 @@ func Test_createFileOutput(t *testing.T) { w, err = createFileOutput(c) assert.Nil(t, err) assert.NotNil(t, w) - assert.IsType(t, &os.File{}, w.w) + assert.IsType(t, &os.File{}, w.IOWriter()) } func Test_createSyslogOutput(t *testing.T) { @@ -208,7 +210,7 @@ func Test_createSyslogOutput(t *testing.T) { w, err = createSyslogOutput(c) assert.Nil(t, err) assert.NotNil(t, w) - assert.IsType(t, &syslog.Writer{}, w.w) + assert.IsType(t, &syslog.Writer{}, w.IOWriter()) } func Test_createStdOutOutput(t *testing.T) { @@ -225,7 +227,7 @@ func Test_createStdOutOutput(t *testing.T) { w, err = createStdOutOutput(c) assert.Nil(t, err) assert.NotNil(t, w) - assert.IsType(t, &os.File{}, w.w) + assert.IsType(t, &os.File{}, w.IOWriter()) } func Test_createOutput(t *testing.T) { @@ -302,7 +304,7 @@ func Test_createOutput(t *testing.T) { w, err = createSyslogOutput(c) assert.Nil(t, err) assert.NotNil(t, w) - assert.IsType(t, &syslog.Writer{}, w.w) + assert.IsType(t, &syslog.Writer{}, w.IOWriter()) // All good file c = viper.New() @@ -315,8 +317,8 @@ func Test_createOutput(t *testing.T) { w, err = createOutput(c) assert.Nil(t, err) assert.NotNil(t, w) - assert.IsType(t, &AuditWriter{}, w) - assert.IsType(t, &os.File{}, w.w) + assert.IsType(t, &audit.JSONAuditWriter{}, w) + assert.IsType(t, &os.File{}, w.IOWriter()) // File rotation os.Rename(path.Join(os.TempDir(), "go-audit.test.log"), path.Join(os.TempDir(), "go-audit.test.log.rotated")) @@ -329,8 +331,8 @@ func Test_createOutput(t *testing.T) { } func Test_createFilters(t *testing.T) { - lb, elb := hookLogger() - defer resetLogger() + lb, elb := test.HookLogger(audit.Std, audit.Stderr) + defer test.ResetLogger(audit.Std, audit.Stderr) // no filters c := viper.New() @@ -425,9 +427,9 @@ func Test_createFilters(t *testing.T) { f, err = createFilters(c) assert.Nil(t, err) assert.NotEmpty(t, f) - assert.Equal(t, "1", f[0].syscall) - assert.Equal(t, uint16(1), f[0].messageType) - assert.Equal(t, "1", f[0].regex.String()) + assert.Equal(t, "1", f[0].Syscall) + assert.Equal(t, uint16(1), f[0].MessageType) + assert.Equal(t, "1", f[0].Regex.String()) assert.Empty(t, elb.String()) assert.Equal(t, "Ignoring syscall `1` containing message type `1` matching string `1`\n", lb.String()) @@ -441,15 +443,15 @@ func Test_createFilters(t *testing.T) { f, err = createFilters(c) assert.Nil(t, err) assert.NotEmpty(t, f) - assert.Equal(t, "1", f[0].syscall) - assert.Equal(t, uint16(1), f[0].messageType) - assert.Equal(t, "1", f[0].regex.String()) + assert.Equal(t, "1", f[0].Syscall) + assert.Equal(t, uint16(1), f[0].MessageType) + assert.Equal(t, "1", f[0].Regex.String()) assert.Empty(t, elb.String()) assert.Equal(t, "Ignoring syscall `1` containing message type `1` matching string `1`\n", lb.String()) } func Benchmark_MultiPacketMessage(b *testing.B) { - marshaller := NewAuditMarshaller(NewAuditWriter(&noopWriter{}, 1), uint16(1300), uint16(1399), false, false, 1, []AuditFilter{}) + marshaller := audit.NewAuditMarshaller(audit.NewAuditWriter(&noopWriter{}, 1), uint16(1300), uint16(1399), false, false, 1, []audit.AuditFilter{}) data := make([][]byte, 6) @@ -476,11 +478,11 @@ func Benchmark_MultiPacketMessage(b *testing.B) { nlen := len(data[n]) msg := &syscall.NetlinkMessage{ Header: syscall.NlMsghdr{ - Len: Endianness.Uint32(data[n][0:4]), - Type: Endianness.Uint16(data[n][4:6]), - Flags: Endianness.Uint16(data[n][6:8]), - Seq: Endianness.Uint32(data[n][8:12]), - Pid: Endianness.Uint32(data[n][12:16]), + Len: audit.Endianness.Uint32(data[n][0:4]), + Type: audit.Endianness.Uint16(data[n][4:6]), + Flags: audit.Endianness.Uint16(data[n][6:8]), + Seq: audit.Endianness.Uint32(data[n][8:12]), + Pid: audit.Endianness.Uint32(data[n][12:16]), }, Data: data[n][syscall.SizeofNlMsghdr:nlen], } diff --git a/internal/test/utils.go b/internal/test/utils.go new file mode 100644 index 0000000..a65e928 --- /dev/null +++ b/internal/test/utils.go @@ -0,0 +1,23 @@ +package test + +import ( + "bytes" + "log" + "os" +) + +// Resets global loggers +func ResetLogger(std, stderr *log.Logger) { + std.SetOutput(os.Stdout) + stderr.SetOutput(os.Stderr) +} + +// Hooks the global loggers writers so you can assert their contents +func HookLogger(std, stderr *log.Logger) (lb *bytes.Buffer, elb *bytes.Buffer) { + lb = &bytes.Buffer{} + std.SetOutput(lb) + + elb = &bytes.Buffer{} + stderr.SetOutput(elb) + return +} diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..41ee803 --- /dev/null +++ b/logger.go @@ -0,0 +1,9 @@ +package audit + +import ( + "log" + "os" +) + +var Std = log.New(os.Stdout, "", 0) +var Stderr = log.New(os.Stderr, "", 0) diff --git a/marshaller.go b/marshaller.go index cf30d1f..fe1868c 100644 --- a/marshaller.go +++ b/marshaller.go @@ -1,4 +1,4 @@ -package main +package audit import ( "os" @@ -13,7 +13,7 @@ const ( type AuditMarshaller struct { msgs map[int]*AuditMessageGroup - writer *AuditWriter + writer AuditWriter lastSeq int missed map[int]bool worstLag int @@ -23,17 +23,21 @@ type AuditMarshaller struct { logOutOfOrder bool maxOutOfOrder int attempts int - filters map[string]map[uint16][]*regexp.Regexp // { syscall: { mtype: [regexp, ...] } } + filters map[string]map[uint16][]*regexp.Regexp // { Syscall: { mtype: [regexp, ...] } } } type AuditFilter struct { - messageType uint16 - regex *regexp.Regexp - syscall string + MessageType uint16 + Regex *regexp.Regexp + Syscall string +} + +func NewAuditFilter(messageType uint16, regex *regexp.Regexp, syscall string) AuditFilter { + return AuditFilter{MessageType: messageType, Regex: regex, Syscall: syscall} } // Create a new marshaller -func NewAuditMarshaller(w *AuditWriter, eventMin uint16, eventMax uint16, trackMessages, logOOO bool, maxOOO int, filters []AuditFilter) *AuditMarshaller { +func NewAuditMarshaller(w AuditWriter, eventMin uint16, eventMax uint16, trackMessages, logOOO bool, maxOOO int, filters []AuditFilter) *AuditMarshaller { am := AuditMarshaller{ writer: w, msgs: make(map[int]*AuditMessageGroup, 5), // It is not typical to have more than 2 message groups at any given time @@ -47,15 +51,15 @@ func NewAuditMarshaller(w *AuditWriter, eventMin uint16, eventMax uint16, trackM } for _, filter := range filters { - if _, ok := am.filters[filter.syscall]; !ok { - am.filters[filter.syscall] = make(map[uint16][]*regexp.Regexp) + if _, ok := am.filters[filter.Syscall]; !ok { + am.filters[filter.Syscall] = make(map[uint16][]*regexp.Regexp) } - if _, ok := am.filters[filter.syscall][filter.messageType]; !ok { - am.filters[filter.syscall][filter.messageType] = []*regexp.Regexp{} + if _, ok := am.filters[filter.Syscall][filter.MessageType]; !ok { + am.filters[filter.Syscall][filter.MessageType] = []*regexp.Regexp{} } - am.filters[filter.syscall][filter.messageType] = append(am.filters[filter.syscall][filter.messageType], filter.regex) + am.filters[filter.Syscall][filter.MessageType] = append(am.filters[filter.Syscall][filter.MessageType], filter.Regex) } return &am @@ -123,7 +127,7 @@ func (a *AuditMarshaller) completeMessage(seq int) { } if err := a.writer.Write(msg); err != nil { - el.Println("Failed to write message. Error:", err) + Stderr.Println("Failed to write message. Error:", err) os.Exit(1) } @@ -166,11 +170,11 @@ func (a *AuditMarshaller) detectMissing(seq int) { } if a.logOutOfOrder { - el.Println("Got sequence", missedSeq, "after", lag, "messages. Worst lag so far", a.worstLag, "messages") + Stderr.Println("Got sequence", missedSeq, "after", lag, "messages. Worst lag so far", a.worstLag, "messages") } delete(a.missed, missedSeq) } else if seq-missedSeq > a.maxOutOfOrder { - el.Printf("Likely missed sequence %d, current %d, worst message delay %d\n", missedSeq, seq, a.worstLag) + Stderr.Printf("Likely missed sequence %d, current %d, worst message delay %d\n", missedSeq, seq, a.worstLag) delete(a.missed, missedSeq) } } diff --git a/marshaller_test.go b/marshaller_test.go index 45714fb..9dcd18b 100644 --- a/marshaller_test.go +++ b/marshaller_test.go @@ -1,4 +1,4 @@ -package main +package audit import ( "bytes" @@ -126,8 +126,8 @@ func TestAuditMarshaller_completeMessage(t *testing.T) { // lb, elb := hookLogger() // m := NewAuditMarshaller(NewAuditWriter(&FailWriter{}, 1), uint16(1300), uint16(1399), false, false, 0, []AuditFilter{}) - // m.Consume(&syscall.NetlinkMessage{ - // Header: syscall.NlMsghdr{ + // m.Consume(&Syscall.NetlinkMessage{ + // Header: Syscall.NlMsghdr{ // Len: uint32(44), // Type: uint16(1300), // Flags: uint16(0), diff --git a/parser.go b/parser.go index 84ba6fe..298fc7e 100644 --- a/parser.go +++ b/parser.go @@ -1,4 +1,4 @@ -package main +package audit import ( "bytes" @@ -145,17 +145,17 @@ func (amg *AuditMessageGroup) findSyscall(am *AuditMessage) { start := 0 end := 0 - if start = strings.Index(data, "syscall="); start < 0 { + if start = strings.Index(data, "Syscall="); start < 0 { return } // Progress the start point beyond the = sign start += 8 if end = strings.IndexByte(data[start:], spaceChar); end < 0 { - // There was no ending space, maybe the syscall id is at the end of the line + // There was no ending space, maybe the Syscall id is at the end of the line end = len(data) - start - // If the end of the line is greater than 5 characters away (overflows a 16 bit uint) then it can't be a syscall id + // If the end of the line is greater than 5 characters away (overflows a 16 bit uint) then it can't be a Syscall id if end > 5 { return } diff --git a/parser_test.go b/parser_test.go index 0584675..c4ca791 100644 --- a/parser_test.go +++ b/parser_test.go @@ -1,4 +1,4 @@ -package main +package audit import ( "github.com/stretchr/testify/assert" diff --git a/writer.go b/writer.go index e3b3a69..e8d5d80 100644 --- a/writer.go +++ b/writer.go @@ -1,4 +1,4 @@ -package main +package audit import ( "encoding/json" @@ -6,21 +6,30 @@ import ( "time" ) -type AuditWriter struct { - e *json.Encoder - w io.Writer - attempts int -} +type ( + AuditWriter interface { + Write(msg *AuditMessageGroup) error + } + + JSONAuditWriter struct { + e *json.Encoder + w io.Writer + attempts int + } +) -func NewAuditWriter(w io.Writer, attempts int) *AuditWriter { - return &AuditWriter{ +func NewAuditWriter(w io.Writer, attempts int) *JSONAuditWriter { + return &JSONAuditWriter{ e: json.NewEncoder(w), w: w, attempts: attempts, } } -func (a *AuditWriter) Write(msg *AuditMessageGroup) (err error) { +func (a *JSONAuditWriter) IOWriter() io.Writer { return a.w } +func (a *JSONAuditWriter) SetIOWriter(w io.Writer) { a.e, a.w = json.NewEncoder(w), w } + +func (a *JSONAuditWriter) Write(msg *AuditMessageGroup) (err error) { for i := 0; i < a.attempts; i++ { err = a.e.Encode(msg) if err == nil { @@ -30,7 +39,7 @@ func (a *AuditWriter) Write(msg *AuditMessageGroup) (err error) { if i != a.attempts { // We have to reset the encoder because write errors are kept internally and can not be retried a.e = json.NewEncoder(a.w) - el.Println("Failed to write message, retrying in 1 second. Error:", err) + Stderr.Println("Failed to write message, retrying in 1 second. Error:", err) time.Sleep(time.Second * 1) } }