diff --git a/audit.go b/audit.go index 649cb4d..0b302c5 100644 --- a/audit.go +++ b/audit.go @@ -208,13 +208,10 @@ func handleLogRotation(config *viper.Viper, writer *AuditWriter) { el.Fatalln("Error re-opening log file. Exiting.") } - oldFile := writer.w.(*os.File) - writer.w = newWriter.w - writer.e = newWriter.e + err = writer.rotate(newWriter) - err = oldFile.Close() if err != nil { - el.Printf("Error closing old log file: %+v\n", err) + el.Println(err) } } } diff --git a/audit_test.go b/audit_test.go index c6000e5..58cbe44 100644 --- a/audit_test.go +++ b/audit_test.go @@ -156,7 +156,8 @@ func Test_createFileOutput(t *testing.T) { // chown error c = viper.New() c.Set("output.file.attempts", 1) - c.Set("output.file.path", path.Join(os.TempDir(), "go-audit.test.log")) + testLogFile := path.Join(os.TempDir(), "go-audit.test.log") + c.Set("output.file.path", testLogFile) c.Set("output.file.mode", 0644) c.Set("output.file.user", "root") c.Set("output.file.group", "root") @@ -315,8 +316,10 @@ func Test_createOutput(t *testing.T) { w, err = createOutput(c) assert.Nil(t, err) assert.NotNil(t, w) + w.mutex.Lock() assert.IsType(t, &AuditWriter{}, w) assert.IsType(t, &os.File{}, w.w) + w.mutex.Unlock() // File rotation os.Rename(path.Join(os.TempDir(), "go-audit.test.log"), path.Join(os.TempDir(), "go-audit.test.log.rotated")) diff --git a/client_test.go b/client_test.go index 440fd23..26c42d2 100644 --- a/client_test.go +++ b/client_test.go @@ -3,10 +3,12 @@ package main import ( "bytes" "encoding/binary" - "github.com/stretchr/testify/assert" "os" + "sync/atomic" "syscall" "testing" + + "github.com/stretchr/testify/assert" ) func TestNetlinkClient_KeepConnection(t *testing.T) { @@ -99,7 +101,7 @@ func TestNewNetlinkClient(t *testing.T) { } else { assert.True(t, (n.fd > 0), "No file descriptor") assert.True(t, (n.address != nil), "Address was nil") - assert.Equal(t, uint32(0), n.seq, "Seq should start at 0") + assert.Equal(t, uint32(0), atomic.LoadUint32(&n.seq), "Seq should start at 0") assert.True(t, MAX_AUDIT_MESSAGE_LENGTH >= len(n.buf), "Client buffer is too small") assert.Equal(t, "Socket receive buffer size: ", lb.String()[:28], "Expected some nice log lines") diff --git a/marshaller_test.go b/marshaller_test.go index 45714fb..869562c 100644 --- a/marshaller_test.go +++ b/marshaller_test.go @@ -45,7 +45,7 @@ func TestAuditMarshaller_Consume(t *testing.T) { assert.Equal( t, - "{\"sequence\":1,\"timestamp\":\"10000001\",\"messages\":[{\"type\":1300,\"data\":\"hi there\"},{\"type\":1301,\"data\":\"hi there\"}],\"uid_map\":{}}\n", + "{\"sequence\":1,\"timestamp\":\"10000001\",\"messages\":[{\"type\":1300,\"data\":\"hi there\"},{\"type\":1301,\"data\":\"hi there\"}],\"uid_map\":{},\"syscall\":\"\"}\n", w.String(), ) assert.Equal(t, 0, len(m.msgs)) @@ -113,7 +113,7 @@ func TestAuditMarshaller_Consume(t *testing.T) { m.Consume(new1320("0")) } - assert.Equal(t, "{\"sequence\":4,\"timestamp\":\"10000001\",\"messages\":[{\"type\":1300,\"data\":\"hi there\"}],\"uid_map\":{}}\n", w.String()) + assert.Equal(t, "{\"sequence\":4,\"timestamp\":\"10000001\",\"messages\":[{\"type\":1300,\"data\":\"hi there\"}],\"uid_map\":{},\"syscall\":\"\"}\n", w.String()) expected := start.Add(time.Second * 2) assert.True(t, expected.Equal(time.Now()) || expected.Before(time.Now()), "Should have taken at least 2 seconds to flush") assert.Equal(t, 0, len(m.msgs)) diff --git a/parser.go b/parser.go index 84ba6fe..d943784 100644 --- a/parser.go +++ b/parser.go @@ -33,7 +33,7 @@ type AuditMessageGroup struct { CompleteAfter time.Time `json:"-"` Msgs []*AuditMessage `json:"messages"` UidMap map[string]string `json:"uid_map"` - Syscall string `json:"-"` + Syscall string `json:"syscall"` } // Creates a new message group from the details parsed from the message @@ -89,7 +89,7 @@ func (amg *AuditMessageGroup) AddMessage(am *AuditMessage) { amg.Msgs = append(amg.Msgs, am) //TODO: need to find more message types that won't contain uids, also make these constants switch am.Type { - case 1309, 1307, 1306: + case 1309, 1307, 1306, 1305: // Don't map uids here case 1300: amg.findSyscall(am) @@ -126,7 +126,7 @@ func (amg *AuditMessageGroup) mapUids(am *AuditMessage) { // Don't bother re-adding if the existing group already has the mapping if _, ok := amg.UidMap[uid]; !ok { - amg.UidMap[uid] = getUsername(data[start : start+end]) + amg.UidMap[uid] = getUsername(uid) } // Find the next uid= if we have space for one diff --git a/writer.go b/writer.go index e3b3a69..f4bcb8d 100644 --- a/writer.go +++ b/writer.go @@ -2,7 +2,10 @@ package main import ( "encoding/json" + "fmt" "io" + "os" + "sync" "time" ) @@ -10,6 +13,7 @@ type AuditWriter struct { e *json.Encoder w io.Writer attempts int + mutex sync.RWMutex } func NewAuditWriter(w io.Writer, attempts int) *AuditWriter { @@ -21,6 +25,7 @@ func NewAuditWriter(w io.Writer, attempts int) *AuditWriter { } func (a *AuditWriter) Write(msg *AuditMessageGroup) (err error) { + a.mutex.RLock() for i := 0; i < a.attempts; i++ { err = a.e.Encode(msg) if err == nil { @@ -34,6 +39,23 @@ func (a *AuditWriter) Write(msg *AuditMessageGroup) (err error) { time.Sleep(time.Second * 1) } } + a.mutex.RUnlock() return err } + +func (self *AuditWriter) rotate(ow *AuditWriter) error { + oldFile := self.w.(*os.File) + + self.mutex.Lock() + self.w = ow.w + self.e = ow.e + self.mutex.Unlock() + + err := oldFile.Close() + if err != nil { + return fmt.Errorf("Error re-opening log file. Exiting.") + } + + return nil +}