Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,10 @@ func handleLogRotation(config *viper.Viper, writer *AuditWriter) {
el.Fatalln("Error re-opening log file. Exiting.")
}

oldFile := writer.w.(*os.File)
writer.w = newWriter.w
writer.e = newWriter.e
err = writer.rotate(newWriter)

err = oldFile.Close()
if err != nil {
el.Printf("Error closing old log file: %+v\n", err)
el.Println(err)
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ func Test_createFileOutput(t *testing.T) {
// chown error
c = viper.New()
c.Set("output.file.attempts", 1)
c.Set("output.file.path", path.Join(os.TempDir(), "go-audit.test.log"))
testLogFile := path.Join(os.TempDir(), "go-audit.test.log")
c.Set("output.file.path", testLogFile)
c.Set("output.file.mode", 0644)
c.Set("output.file.user", "root")
c.Set("output.file.group", "root")
Expand Down Expand Up @@ -315,8 +316,10 @@ func Test_createOutput(t *testing.T) {
w, err = createOutput(c)
assert.Nil(t, err)
assert.NotNil(t, w)
w.mutex.Lock()
assert.IsType(t, &AuditWriter{}, w)
assert.IsType(t, &os.File{}, w.w)
w.mutex.Unlock()

// File rotation
os.Rename(path.Join(os.TempDir(), "go-audit.test.log"), path.Join(os.TempDir(), "go-audit.test.log.rotated"))
Expand Down
6 changes: 4 additions & 2 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package main
import (
"bytes"
"encoding/binary"
"github.com/stretchr/testify/assert"
"os"
"sync/atomic"
"syscall"
"testing"

"github.com/stretchr/testify/assert"
)

func TestNetlinkClient_KeepConnection(t *testing.T) {
Expand Down Expand Up @@ -99,7 +101,7 @@ func TestNewNetlinkClient(t *testing.T) {
} else {
assert.True(t, (n.fd > 0), "No file descriptor")
assert.True(t, (n.address != nil), "Address was nil")
assert.Equal(t, uint32(0), n.seq, "Seq should start at 0")
assert.Equal(t, uint32(0), atomic.LoadUint32(&n.seq), "Seq should start at 0")
assert.True(t, MAX_AUDIT_MESSAGE_LENGTH >= len(n.buf), "Client buffer is too small")

assert.Equal(t, "Socket receive buffer size: ", lb.String()[:28], "Expected some nice log lines")
Expand Down
4 changes: 2 additions & 2 deletions marshaller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestAuditMarshaller_Consume(t *testing.T) {

assert.Equal(
t,
"{\"sequence\":1,\"timestamp\":\"10000001\",\"messages\":[{\"type\":1300,\"data\":\"hi there\"},{\"type\":1301,\"data\":\"hi there\"}],\"uid_map\":{}}\n",
"{\"sequence\":1,\"timestamp\":\"10000001\",\"messages\":[{\"type\":1300,\"data\":\"hi there\"},{\"type\":1301,\"data\":\"hi there\"}],\"uid_map\":{},\"syscall\":\"\"}\n",
w.String(),
)
assert.Equal(t, 0, len(m.msgs))
Expand Down Expand Up @@ -113,7 +113,7 @@ func TestAuditMarshaller_Consume(t *testing.T) {
m.Consume(new1320("0"))
}

assert.Equal(t, "{\"sequence\":4,\"timestamp\":\"10000001\",\"messages\":[{\"type\":1300,\"data\":\"hi there\"}],\"uid_map\":{}}\n", w.String())
assert.Equal(t, "{\"sequence\":4,\"timestamp\":\"10000001\",\"messages\":[{\"type\":1300,\"data\":\"hi there\"}],\"uid_map\":{},\"syscall\":\"\"}\n", w.String())
expected := start.Add(time.Second * 2)
assert.True(t, expected.Equal(time.Now()) || expected.Before(time.Now()), "Should have taken at least 2 seconds to flush")
assert.Equal(t, 0, len(m.msgs))
Expand Down
6 changes: 3 additions & 3 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type AuditMessageGroup struct {
CompleteAfter time.Time `json:"-"`
Msgs []*AuditMessage `json:"messages"`
UidMap map[string]string `json:"uid_map"`
Syscall string `json:"-"`
Syscall string `json:"syscall"`
}

// Creates a new message group from the details parsed from the message
Expand Down Expand Up @@ -89,7 +89,7 @@ func (amg *AuditMessageGroup) AddMessage(am *AuditMessage) {
amg.Msgs = append(amg.Msgs, am)
//TODO: need to find more message types that won't contain uids, also make these constants
switch am.Type {
case 1309, 1307, 1306:
case 1309, 1307, 1306, 1305:
// Don't map uids here
case 1300:
amg.findSyscall(am)
Expand Down Expand Up @@ -126,7 +126,7 @@ func (amg *AuditMessageGroup) mapUids(am *AuditMessage) {

// Don't bother re-adding if the existing group already has the mapping
if _, ok := amg.UidMap[uid]; !ok {
amg.UidMap[uid] = getUsername(data[start : start+end])
amg.UidMap[uid] = getUsername(uid)
}

// Find the next uid= if we have space for one
Expand Down
22 changes: 22 additions & 0 deletions writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@ package main

import (
"encoding/json"
"fmt"
"io"
"os"
"sync"
"time"
)

type AuditWriter struct {
e *json.Encoder
w io.Writer
attempts int
mutex sync.RWMutex
}

func NewAuditWriter(w io.Writer, attempts int) *AuditWriter {
Expand All @@ -21,6 +25,7 @@ func NewAuditWriter(w io.Writer, attempts int) *AuditWriter {
}

func (a *AuditWriter) Write(msg *AuditMessageGroup) (err error) {
a.mutex.RLock()
for i := 0; i < a.attempts; i++ {
err = a.e.Encode(msg)
if err == nil {
Expand All @@ -34,6 +39,23 @@ func (a *AuditWriter) Write(msg *AuditMessageGroup) (err error) {
time.Sleep(time.Second * 1)
}
}
a.mutex.RUnlock()

return err
}

func (self *AuditWriter) rotate(ow *AuditWriter) error {
oldFile := self.w.(*os.File)

self.mutex.Lock()
self.w = ow.w
self.e = ow.e
self.mutex.Unlock()

err := oldFile.Close()
if err != nil {
return fmt.Errorf("Error re-opening log file. Exiting.")
}

return nil
}