Skip to content

Commit 26378aa

Browse files
author
Tom Moulard
committed
review: add context handling
1 parent 9e4cd89 commit 26378aa

File tree

16 files changed

+65
-34
lines changed

16 files changed

+65
-34
lines changed

.golangci.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ issues:
8888
- text: 'use of `fmt.Println` forbidden' # FIXME: add revert this change ASAP
8989
linters:
9090
- forbidigo
91+
- text: 'Magic number: 2, in <condition> detected'
92+
file: 'pkg/notifications/webhook.go'
93+
linters:
94+
- mnd
9195

9296
output:
9397
show-stats: true

fail2ban.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ func ImportIP(list List) ([]string, error) {
7979

8080
// New instantiates and returns the required components used to handle a HTTP
8181
// request.
82-
func New(_ context.Context, next http.Handler, config *Config, _ string) (http.Handler, error) {
82+
func New(ctx context.Context, next http.Handler, config *Config, _ string) (http.Handler, error) {
8383
if !config.Rules.Enabled {
8484
log.Println("Plugin: FailToBan is disabled")
8585

@@ -140,7 +140,7 @@ func New(_ context.Context, next http.Handler, config *Config, _ string) (http.H
140140

141141
notifSrvc := notifications.NewService(config.Notifications)
142142
if notifSrvc != nil {
143-
go notifSrvc.Run()
143+
go notifSrvc.Run(context.WithoutCancel(ctx))
144144
}
145145

146146
log.Println("Plugin: FailToBan is up and running")

pkg/fail2ban/fail2ban.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ func (u *Fail2Ban) ShouldAllow(remoteIP string) bool {
7777
}
7878

7979
fmt.Println(remoteIP + " is no longer banned")
80-
u.notify(notifications.UnbanEvent(remoteIP, fmt.Sprintf("%q is no longer banned", remoteIP)))
8180

8281
return true
8382
}
@@ -92,7 +91,6 @@ func (u *Fail2Ban) ShouldAllow(remoteIP string) bool {
9291
msg := fmt.Sprintf("%q is banned for %d>=%d request",
9392
remoteIP, ip.Count+1, u.rules.MaxRetry)
9493
fmt.Printf("%s", msg)
95-
u.notify(notifications.BanEvent(remoteIP, msg, u.rules.Bantime))
9694

9795
return false
9896
}

pkg/notifications/discord.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package notifications
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/json"
67
"fmt"
78
"net/http"
@@ -47,8 +48,7 @@ func NewDiscordNotifier(cfg DiscordConfig, httpCli *http.Client) *DiscordNotifie
4748
}
4849
}
4950

50-
//nolint:noctx
51-
func (d *DiscordNotifier) Send(event Event) error {
51+
func (d *DiscordNotifier) Send(ctx context.Context, event Event) error {
5252
var color int
5353

5454
switch event.Type {
@@ -90,7 +90,7 @@ func (d *DiscordNotifier) Send(event Event) error {
9090
return fmt.Errorf("failed to marshal discord payload: %w", err)
9191
}
9292

93-
req, err := http.NewRequest(http.MethodPost, d.webhookURL, bytes.NewReader(jsonPayload))
93+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, d.webhookURL, bytes.NewReader(jsonPayload))
9494
if err != nil {
9595
return fmt.Errorf("failed to create discord request: %w", err)
9696
}

pkg/notifications/discord_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ func TestDiscordNotifier(t *testing.T) {
6767
test.config.WebhookURL = server.URL
6868
n := NewDiscordNotifier(test.config, server.Client())
6969

70-
err := n.Send(Event{
70+
ctx := t.Context()
71+
72+
err := n.Send(ctx, Event{
7173
Type: EventTypeBan,
7274
IP: "192.0.2.1",
7375
Message: "test",

pkg/notifications/email.go

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
package notifications
22

33
import (
4+
"context"
45
"crypto/tls"
56
"fmt"
67
"io"
8+
"net"
79
"net/smtp"
10+
"strconv"
811
"sync"
12+
"time"
13+
)
14+
15+
const (
16+
// dialerTimeout defines the timeout for establishing SMTP connections.
17+
dialerTimeout = 5 * time.Second
918
)
1019

1120
type mailer interface {
@@ -45,7 +54,7 @@ func NewEmailNotifier(cfg EmailConfig, templates *TemplateHandler) *EmailNotifie
4554
return n
4655
}
4756

48-
func (e *EmailNotifier) ensureConnected() error {
57+
func (e *EmailNotifier) ensureConnected(ctx context.Context) error {
4958
e.clientMux.Lock()
5059
defer e.clientMux.Unlock()
5160

@@ -58,7 +67,7 @@ func (e *EmailNotifier) ensureConnected() error {
5867
_ = e.client.Close()
5968
}
6069

61-
client, err := createSMTPClient(e.server, e.port, e.username, e.password)
70+
client, err := createSMTPClient(ctx, e.server, e.port, e.username, e.password)
6271
if err != nil {
6372
return fmt.Errorf("failed to reconnect SMTP client: %w", err)
6473
}
@@ -68,9 +77,9 @@ func (e *EmailNotifier) ensureConnected() error {
6877
return nil
6978
}
7079

71-
func (e *EmailNotifier) Send(event Event) error {
80+
func (e *EmailNotifier) Send(ctx context.Context, event Event) error {
7281
// Ensure we have a valid connection
73-
if err := e.ensureConnected(); err != nil {
82+
if err := e.ensureConnected(ctx); err != nil {
7483
return fmt.Errorf("failed to ensure SMTP connection: %w", err)
7584
}
7685

@@ -114,12 +123,21 @@ func (e *EmailNotifier) Send(event Event) error {
114123
return nil
115124
}
116125

117-
func createSMTPClient(host string, port int, username, password string) (*smtp.Client, error) {
118-
addr := fmt.Sprintf("%s:%d", host, port)
126+
func createSMTPClient(ctx context.Context, host string, port int, username, password string) (*smtp.Client, error) {
127+
addr := net.JoinHostPort(host, strconv.Itoa(port))
128+
129+
dialer := &net.Dialer{
130+
Timeout: dialerTimeout,
131+
}
132+
133+
conn, err := dialer.DialContext(ctx, "tcp", addr)
134+
if err != nil {
135+
return nil, fmt.Errorf("dial SMTP server: %w", err)
136+
}
119137

120-
client, err := smtp.Dial(addr)
138+
client, err := smtp.NewClient(conn, host)
121139
if err != nil {
122-
return nil, fmt.Errorf("failed to dial SMTP server: %w", err)
140+
return nil, fmt.Errorf("create SMTP client: %w", err)
123141
}
124142

125143
tlsConfig := &tls.Config{

pkg/notifications/email_test.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,14 @@ func TestEmailNotifier(t *testing.T) {
9898
},
9999
validate: func(t *testing.T, n *EmailNotifier) {
100100
t.Helper()
101+
102+
ctx := t.Context()
103+
101104
// Send multiple emails concurrently
102105
errCh := make(chan error, 3)
103106
for range 3 {
104107
go func() {
105-
errCh <- n.Send(Event{
108+
errCh <- n.Send(ctx, Event{
106109
Type: EventTypeBan,
107110
IP: "192.0.2.1",
108111
Message: "test ban",
@@ -157,7 +160,8 @@ func TestEmailNotifier(t *testing.T) {
157160
}
158161

159162
// Send notification
160-
err := n.Send(test.event)
163+
ctx := t.Context()
164+
err := n.Send(ctx, test.event)
161165

162166
// Validate results
163167
if test.expectError {

pkg/notifications/service.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
package notifications
33

44
import (
5+
"context"
56
"log"
67
"net/http"
78
"time"
89
)
910

1011
type notifier interface {
11-
Send(event Event) error
12+
Send(ctx context.Context, event Event) error
1213
}
1314
type Service struct {
1415
allowedTypes []string
@@ -30,10 +31,10 @@ func (s *Service) Notify(event Event) {
3031
}
3132
}
3233

33-
func (s *Service) Run() {
34+
func (s *Service) Run(ctx context.Context) {
3435
for event := range s.ch {
3536
for _, n := range s.notifiers {
36-
if err := n.Send(event); err != nil {
37+
if err := n.Send(ctx, event); err != nil {
3738
log.Printf("failed to send notification: %v", err)
3839
}
3940
}

pkg/notifications/service_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package notifications
22

33
import (
4+
"context"
45
"testing"
56
"time"
67

@@ -174,7 +175,7 @@ func TestServiceNotify(t *testing.T) {
174175
notifiers: []notifier{n},
175176
ch: make(chan Event, 1),
176177
}
177-
go svc.Run()
178+
go svc.Run(t.Context())
178179
svc.Notify(test.event)
179180

180181
if test.shouldNotify {
@@ -202,6 +203,6 @@ type mockNotifier struct {
202203
fn func(event Event) error
203204
}
204205

205-
func (m *mockNotifier) Send(event Event) error {
206+
func (m *mockNotifier) Send(_ context.Context, event Event) error {
206207
return m.fn(event)
207208
}

pkg/notifications/telegram.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package notifications
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/json"
67
"fmt"
78
"net/http"
@@ -31,8 +32,7 @@ func NewTelegramNotifier(cfg TelegramConfig, templates *TemplateHandler, httpCli
3132
return &tn
3233
}
3334

34-
//nolint:noctx
35-
func (t *TelegramNotifier) Send(event Event) error {
35+
func (t *TelegramNotifier) Send(ctx context.Context, event Event) error {
3636
msg, err := t.templates.RenderTemplate(event)
3737
if err != nil {
3838
return fmt.Errorf("failed to render telegram template: %w", err)
@@ -50,7 +50,7 @@ func (t *TelegramNotifier) Send(event Event) error {
5050
return fmt.Errorf("failed to marshal telegram payload: %w", err)
5151
}
5252

53-
req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonData))
53+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData))
5454
if err != nil {
5555
return fmt.Errorf("failed to create telegram request: %w", err)
5656
}

0 commit comments

Comments
 (0)