Skip to content

Commit 811b405

Browse files
author
Maxime Vidori
committed
Move executable to cmd directory
This patch also allow this repository to be imported in another project. Only the execution part and config loading is moved to the cmd directory allowing the use of the core features in other projects.
1 parent c160a22 commit 811b405

File tree

12 files changed

+143
-121
lines changed

12 files changed

+143
-121
lines changed

Makefile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
bin:
22
govendor sync
3-
go build
3+
go build cmd
44

55
test:
66
govendor sync
7-
go test -v
7+
go test -v ./...
88

99
test-cov-html:
1010
go test -coverprofile=coverage.out
1111
go tool cover -html=coverage.out
1212

1313
bench:
14-
go test -bench=.
14+
go test -bench=. ./...
1515

1616
bench-cpu:
1717
go test -bench=. -benchtime=5s -cpuprofile=cpu.pprof

client.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
package main
1+
package audit
22

33
import (
44
"bytes"
55
"encoding/binary"
66
"errors"
7+
"fmt"
78
"sync/atomic"
89
"syscall"
910
"time"
10-
"fmt"
1111
)
1212

1313
// Endianness is an alias for what we assume is the current machine endianness
@@ -63,13 +63,13 @@ func NewNetlinkClient(recvSize int) (*NetlinkClient, error) {
6363
// Set the buffer size if we were asked
6464
if recvSize > 0 {
6565
if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, recvSize); err != nil {
66-
el.Println("Failed to set receive buffer size")
66+
Stderr.Println("Failed to set receive buffer size")
6767
}
6868
}
6969

7070
// Print the current receive buffer size
7171
if v, err := syscall.GetsockoptInt(n.fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF); err == nil {
72-
l.Println("Socket receive buffer size:", v)
72+
Std.Println("Socket receive buffer size:", v)
7373
}
7474

7575
go func() {
@@ -151,6 +151,6 @@ func (n *NetlinkClient) KeepConnection() {
151151

152152
err := n.Send(packet, payload)
153153
if err != nil {
154-
el.Println("Error occurred while trying to keep the connection:", err)
154+
Stderr.Println("Error occurred while trying to keep the connection:", err)
155155
}
156156
}

client_test.go

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
package main
1+
package audit
22

33
import (
4-
"bytes"
54
"encoding/binary"
6-
"github.com/stretchr/testify/assert"
75
"os"
86
"syscall"
97
"testing"
8+
9+
"github.com/slackhq/go-audit/internal/test"
10+
"github.com/stretchr/testify/assert"
1011
)
1112

1213
func TestNetlinkClient_KeepConnection(t *testing.T) {
@@ -29,8 +30,8 @@ func TestNetlinkClient_KeepConnection(t *testing.T) {
2930
assert.EqualValues(t, msg.Data[:40], expectedData, "data was wrong")
3031

3132
// Make sure we get errors printed
32-
lb, elb := hookLogger()
33-
defer resetLogger()
33+
lb, elb := test.HookLogger(Std, Stderr)
34+
defer test.ResetLogger(Std, Stderr)
3435
syscall.Close(n.fd)
3536
n.KeepConnection()
3637
assert.Equal(t, "", lb.String(), "Got some log lines we did not expect")
@@ -88,8 +89,8 @@ func TestNetlinkClient_SendReceive(t *testing.T) {
8889
}
8990

9091
func TestNewNetlinkClient(t *testing.T) {
91-
lb, elb := hookLogger()
92-
defer resetLogger()
92+
lb, elb := test.HookLogger(Std, Stderr)
93+
defer test.ResetLogger(Std, Stderr)
9394

9495
n, err := NewNetlinkClient(1024)
9596

@@ -143,19 +144,3 @@ func sendReceive(t *testing.T, n *NetlinkClient, packet *NetlinkPacket, payload
143144

144145
return msg
145146
}
146-
147-
// Resets global loggers
148-
func resetLogger() {
149-
l.SetOutput(os.Stdout)
150-
el.SetOutput(os.Stderr)
151-
}
152-
153-
// Hooks the global loggers writers so you can assert their contents
154-
func hookLogger() (lb *bytes.Buffer, elb *bytes.Buffer) {
155-
lb = &bytes.Buffer{}
156-
l.SetOutput(lb)
157-
158-
elb = &bytes.Buffer{}
159-
el.SetOutput(elb)
160-
return
161-
}

audit.go renamed to cmd/audit.go

Lines changed: 40 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"errors"
55
"flag"
66
"fmt"
7-
"log"
87
"log/syslog"
98
"os"
109
"os/exec"
@@ -15,12 +14,10 @@ import (
1514
"strings"
1615
"syscall"
1716

17+
"github.com/slackhq/go-audit"
1818
"github.com/spf13/viper"
1919
)
2020

21-
var l = log.New(os.Stdout, "", 0)
22-
var el = log.New(os.Stderr, "", 0)
23-
2421
type executor func(string, ...string) error
2522

2623
func lExec(s string, a ...string) error {
@@ -46,8 +43,8 @@ func loadConfig(configFile string) (*viper.Viper, error) {
4643
return nil, err
4744
}
4845

49-
l.SetFlags(config.GetInt("log.flags"))
50-
el.SetFlags(config.GetInt("log.flags"))
46+
audit.Std.SetFlags(config.GetInt("log.flags"))
47+
audit.Stderr.SetFlags(config.GetInt("log.flags"))
5148

5249
return config, nil
5350
}
@@ -58,7 +55,7 @@ func setRules(config *viper.Viper, e executor) error {
5855
return fmt.Errorf("Failed to flush existing audit rules. Error: %s", err)
5956
}
6057

61-
l.Println("Flushed existing audit rules")
58+
audit.Std.Println("Flushed existing audit rules")
6259

6360
// Add ours in
6461
if rules := config.GetStringSlice("rules"); len(rules) != 0 {
@@ -72,7 +69,7 @@ func setRules(config *viper.Viper, e executor) error {
7269
return fmt.Errorf("Failed to add rule #%d. Error: %s", i+1, err)
7370
}
7471

75-
l.Printf("Added audit rule #%d\n", i+1)
72+
audit.Std.Printf("Added audit rule #%d\n", i+1)
7673
}
7774
} else {
7875
return errors.New("No audit rules found")
@@ -81,8 +78,8 @@ func setRules(config *viper.Viper, e executor) error {
8178
return nil
8279
}
8380

84-
func createOutput(config *viper.Viper) (*AuditWriter, error) {
85-
var writer *AuditWriter
81+
func createOutput(config *viper.Viper) (*audit.AuditWriter, error) {
82+
var writer *audit.AuditWriter
8683
var err error
8784
i := 0
8885

@@ -123,7 +120,7 @@ func createOutput(config *viper.Viper) (*AuditWriter, error) {
123120
return writer, nil
124121
}
125122

126-
func createSyslogOutput(config *viper.Viper) (*AuditWriter, error) {
123+
func createSyslogOutput(config *viper.Viper) (*audit.AuditWriter, error) {
127124
attempts := config.GetInt("output.syslog.attempts")
128125
if attempts < 1 {
129126
return nil, fmt.Errorf("Output attempts for syslog must be at least 1, %v provided", attempts)
@@ -140,10 +137,10 @@ func createSyslogOutput(config *viper.Viper) (*AuditWriter, error) {
140137
return nil, fmt.Errorf("Failed to open syslog writer. Error: %v", err)
141138
}
142139

143-
return NewAuditWriter(syslogWriter, attempts), nil
140+
return audit.NewAuditWriter(syslogWriter, attempts), nil
144141
}
145142

146-
func createFileOutput(config *viper.Viper) (*AuditWriter, error) {
143+
func createFileOutput(config *viper.Viper) (*audit.AuditWriter, error) {
147144
attempts := config.GetInt("output.file.attempts")
148145
if attempts < 1 {
149146
return nil, fmt.Errorf("Output attempts for file must be at least 1, %v provided", attempts)
@@ -193,10 +190,10 @@ func createFileOutput(config *viper.Viper) (*AuditWriter, error) {
193190
return nil, fmt.Errorf("Could not chown output file. Error: %s", err)
194191
}
195192

196-
return NewAuditWriter(f, attempts), nil
193+
return audit.NewAuditWriter(f, attempts), nil
197194
}
198195

199-
func handleLogRotation(config *viper.Viper, writer *AuditWriter) {
196+
func handleLogRotation(config *viper.Viper, writer *audit.AuditWriter) {
200197
// Re-open our log file. This is triggered by a USR1 signal and is meant to be used upon log rotation
201198

202199
sigc := make(chan os.Signal, 1)
@@ -205,38 +202,37 @@ func handleLogRotation(config *viper.Viper, writer *AuditWriter) {
205202
for range sigc {
206203
newWriter, err := createFileOutput(config)
207204
if err != nil {
208-
el.Fatalln("Error re-opening log file. Exiting.")
205+
audit.Stderr.Fatalln("Error re-opening log file. Exiting.")
209206
}
210207

211-
oldFile := writer.w.(*os.File)
212-
writer.w = newWriter.w
213-
writer.e = newWriter.e
208+
oldFile := writer.IOWriter().(*os.File)
209+
writer.SetIOWriter(newWriter.IOWriter())
214210

215211
err = oldFile.Close()
216212
if err != nil {
217-
el.Printf("Error closing old log file: %+v\n", err)
213+
audit.Stderr.Printf("Error closing old log file: %+v\n", err)
218214
}
219215
}
220216
}
221217

222-
func createStdOutOutput(config *viper.Viper) (*AuditWriter, error) {
218+
func createStdOutOutput(config *viper.Viper) (*audit.AuditWriter, error) {
223219
attempts := config.GetInt("output.stdout.attempts")
224220
if attempts < 1 {
225221
return nil, fmt.Errorf("Output attempts for stdout must be at least 1, %v provided", attempts)
226222
}
227223

228224
// l logger is no longer stdout
229-
l.SetOutput(os.Stderr)
225+
audit.Std.SetOutput(os.Stderr)
230226

231-
return NewAuditWriter(os.Stdout, attempts), nil
227+
return audit.NewAuditWriter(os.Stdout, attempts), nil
232228
}
233229

234-
func createFilters(config *viper.Viper) ([]AuditFilter, error) {
230+
func createFilters(config *viper.Viper) ([]audit.AuditFilter, error) {
235231
var err error
236232
var ok bool
237233

238234
fs := config.Get("filters")
239-
filters := []AuditFilter{}
235+
filters := []audit.AuditFilter{}
240236

241237
if fs == nil {
242238
return filters, nil
@@ -253,7 +249,7 @@ func createFilters(config *viper.Viper) ([]AuditFilter, error) {
253249
return filters, fmt.Errorf("Could not parse filter %d; '%+v'", i+1, f)
254250
}
255251

256-
af := AuditFilter{}
252+
af := audit.AuditFilter{}
257253
for k, v := range f2 {
258254
switch k {
259255
case "message_type":
@@ -262,10 +258,10 @@ func createFilters(config *viper.Viper) ([]AuditFilter, error) {
262258
if err != nil {
263259
return filters, fmt.Errorf("`message_type` in filter %d could not be parsed; Value: `%+v`; Error: %s", i+1, v, err)
264260
}
265-
af.messageType = uint16(fv)
261+
af.MessageType = uint16(fv)
266262

267263
} else if ev, ok := v.(int); ok {
268-
af.messageType = uint16(ev)
264+
af.MessageType = uint16(ev)
269265

270266
} else {
271267
return filters, fmt.Errorf("`message_type` in filter %d could not be parsed; Value: `%+v`", i+1, v)
@@ -277,31 +273,31 @@ func createFilters(config *viper.Viper) ([]AuditFilter, error) {
277273
return filters, fmt.Errorf("`regex` in filter %d could not be parsed; Value: `%+v`", i+1, v)
278274
}
279275

280-
if af.regex, err = regexp.Compile(re); err != nil {
276+
if af.Regex, err = regexp.Compile(re); err != nil {
281277
return filters, fmt.Errorf("`regex` in filter %d could not be parsed; Value: `%+v`; Error: %s", i+1, v, err)
282278
}
283279

284280
case "syscall":
285-
if af.syscall, ok = v.(string); ok {
281+
if af.Syscall, ok = v.(string); ok {
286282
// All is good
287283
} else if ev, ok := v.(int); ok {
288-
af.syscall = strconv.Itoa(ev)
284+
af.Syscall = strconv.Itoa(ev)
289285
} else {
290286
return filters, fmt.Errorf("`syscall` in filter %d could not be parsed; Value: `%+v`", i+1, v)
291287
}
292288
}
293289
}
294290

295-
if af.regex == nil {
291+
if af.Regex == nil {
296292
return filters, fmt.Errorf("Filter %d is missing the `regex` entry", i+1)
297293
}
298294

299-
if af.messageType == 0 {
295+
if af.MessageType == 0 {
300296
return filters, fmt.Errorf("Filter %d is missing the `message_type` entry", i+1)
301297
}
302298

303299
filters = append(filters, af)
304-
l.Printf("Ignoring syscall `%v` containing message type `%v` matching string `%s`\n", af.syscall, af.messageType, af.regex.String())
300+
audit.Std.Printf("Ignoring syscall `%v` containing message type `%v` matching string `%s`\n", af.Syscall, af.MessageType, af.Regex.String())
305301
}
306302

307303
return filters, nil
@@ -313,37 +309,37 @@ func main() {
313309
flag.Parse()
314310

315311
if *configFile == "" {
316-
el.Println("A config file must be provided")
312+
audit.Stderr.Println("A config file must be provided")
317313
flag.Usage()
318314
os.Exit(1)
319315
}
320316

321317
config, err := loadConfig(*configFile)
322318
if err != nil {
323-
el.Fatal(err)
319+
audit.Stderr.Fatal(err)
324320
}
325321

326322
// output needs to be created before anything that write to stdout
327323
writer, err := createOutput(config)
328324
if err != nil {
329-
el.Fatal(err)
325+
audit.Stderr.Fatal(err)
330326
}
331327

332328
if err := setRules(config, lExec); err != nil {
333-
el.Fatal(err)
329+
audit.Stderr.Fatal(err)
334330
}
335331

336332
filters, err := createFilters(config)
337333
if err != nil {
338-
el.Fatal(err)
334+
audit.Stderr.Fatal(err)
339335
}
340336

341-
nlClient, err := NewNetlinkClient(config.GetInt("socket_buffer.receive"))
337+
nlClient, err := audit.NewNetlinkClient(config.GetInt("socket_buffer.receive"))
342338
if err != nil {
343-
el.Fatal(err)
339+
audit.Stderr.Fatal(err)
344340
}
345341

346-
marshaller := NewAuditMarshaller(
342+
marshaller := audit.NewAuditMarshaller(
347343
writer,
348344
uint16(config.GetInt("events.min")),
349345
uint16(config.GetInt("events.max")),
@@ -353,13 +349,13 @@ func main() {
353349
filters,
354350
)
355351

356-
l.Printf("Started processing events in the range [%d, %d]\n", config.GetInt("events.min"), config.GetInt("events.max"))
352+
audit.Std.Printf("Started processing events in the range [%d, %d]\n", config.GetInt("events.min"), config.GetInt("events.max"))
357353

358354
//Main loop. Get data from netlink and send it to the json lib for processing
359355
for {
360356
msg, err := nlClient.Receive()
361357
if err != nil {
362-
el.Printf("Error during message receive: %+v\n", err)
358+
audit.Stderr.Printf("Error during message receive: %+v\n", err)
363359
continue
364360
}
365361

0 commit comments

Comments
 (0)