Skip to content

Commit 3a7e9b3

Browse files
committed
feat(notifications): don't send notifications for recent user actions
1 parent 4aa435e commit 3a7e9b3

File tree

3 files changed

+59
-16
lines changed

3 files changed

+59
-16
lines changed

cmd/wrtag/main.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"go.senan.xyz/wrtag/cmd/internal/wrtagflag"
2424
"go.senan.xyz/wrtag/cmd/internal/wrtaglog"
2525
"go.senan.xyz/wrtag/fileutil"
26+
"go.senan.xyz/wrtag/notifications"
2627
"go.senan.xyz/wrtag/researchlink"
2728
)
2829

@@ -49,7 +50,7 @@ func main() {
4950
wrtagflag.DefaultClient()
5051
var (
5152
cfg = wrtagflag.Config()
52-
notifications = wrtagflag.Notifications()
53+
notifs = wrtagflag.Notifications()
5354
researchLinkQuerier = wrtagflag.ResearchLinks()
5455
)
5556
wrtagflag.Parse()
@@ -74,6 +75,8 @@ func main() {
7475
)
7576
flag.Parse(args)
7677

78+
ctx := notifications.RecordAction(context.Background())
79+
7780
var importCondition wrtag.ImportCondition
7881
if *yes {
7982
importCondition = wrtag.Always
@@ -91,7 +94,7 @@ func main() {
9194
return
9295
}
9396

94-
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
97+
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
9598
defer cancel()
9699

97100
op, err := wrtagflag.OperationByName(command, *dryRun)
@@ -115,6 +118,8 @@ func main() {
115118
)
116119
flag.Parse(args)
117120

121+
ctx := notifications.RecordAction(context.Background())
122+
118123
// walk the whole root dir by default, or some user provided dirs if provided
119124
var dirs []string
120125
if args := flag.Args(); len(args) > 0 {
@@ -132,7 +137,7 @@ func main() {
132137
}
133138
}
134139

135-
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
140+
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
136141
defer cancel()
137142

138143
start := time.Now()
@@ -148,10 +153,10 @@ func main() {
148153
switch {
149154
case stats.errors.Load() > 0:
150155
slog.Error("sync finished", "took", took, "", &stats)
151-
notifications.Sendf(ctx, notifSyncError, "sync finished in %v %v", took, &stats)
156+
notifs.Sendf(ctx, notifSyncError, "sync finished in %v %v", took, &stats)
152157
default:
153158
slog.Info("sync finished", "took", took, "", &stats)
154-
notifications.Sendf(ctx, notifSyncComplete, "sync finished in %v %v", took, &stats)
159+
notifs.Sendf(ctx, notifSyncComplete, "sync finished in %v %v", took, &stats)
155160
}
156161

157162
default:

cmd/wrtagweb/main.go

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"go.senan.xyz/wrtag"
2929
wrtagflag "go.senan.xyz/wrtag/cmd/internal/wrtagflag"
3030
"go.senan.xyz/wrtag/cmd/internal/wrtaglog"
31+
"go.senan.xyz/wrtag/notifications"
3132
"go.senan.xyz/wrtag/researchlink"
3233

3334
_ "github.com/ncruces/go-sqlite3/driver"
@@ -58,7 +59,7 @@ func main() {
5859
wrtagflag.DefaultClient()
5960
var (
6061
cfg = wrtagflag.Config()
61-
notifications = wrtagflag.Notifications()
62+
notifs = wrtagflag.Notifications()
6263
researchLinkQuerier = wrtagflag.ResearchLinks()
6364
apiKey = flag.String("web-api-key", "", "API key for web interface")
6465
listenAddr = flag.String("web-listen-addr", ":7373", "Listen address for web interface (optional)")
@@ -187,9 +188,9 @@ func main() {
187188

188189
switch job.Status {
189190
case StatusComplete:
190-
go notifications.Send(context.WithoutCancel(ctx), notifComplete, jobNotificationMessage(*publicURL, job))
191+
go notifs.Send(context.WithoutCancel(ctx), notifComplete, jobNotificationMessage(*publicURL, job))
191192
case StatusNeedsInput:
192-
go notifications.Send(context.WithoutCancel(ctx), notifNeedsInput, jobNotificationMessage(*publicURL, job))
193+
go notifs.Send(context.WithoutCancel(ctx), notifNeedsInput, jobNotificationMessage(*publicURL, job))
193194
}
194195

195196
return nil
@@ -266,7 +267,9 @@ func main() {
266267
w.WriteHeader(http.StatusOK)
267268
rc.Flush()
268269

269-
for id := range sse.receive(r.Context(), 0) {
270+
ctx := r.Context()
271+
272+
for id := range sse.receive(ctx, 0) {
270273
fmt.Fprintf(w, "data: %d\n\n", id)
271274
rc.Flush()
272275
}
@@ -276,7 +279,10 @@ func main() {
276279
search := r.URL.Query().Get("search")
277280
filter := JobStatus(r.URL.Query().Get("filter"))
278281
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
279-
jl, err := listJobs(r.Context(), filter, search, page)
282+
283+
ctx := r.Context()
284+
285+
jl, err := listJobs(ctx, filter, search, page)
280286
if err != nil {
281287
respErrf(w, http.StatusInternalServerError, "error listing jobs: %v", err)
282288
return
@@ -301,8 +307,11 @@ func main() {
301307
}
302308
path = filepath.Clean(path)
303309

310+
ctx := r.Context()
311+
ctx = notifications.RecordAction(ctx)
312+
304313
var job Job
305-
if err := sqlb.ScanRow(r.Context(), db, &job, "insert into jobs (source_path, operation, time) values (?, ?, ?) returning *", path, operationStr, time.Now()); err != nil {
314+
if err := sqlb.ScanRow(ctx, db, &job, "insert into jobs (source_path, operation, time) values (?, ?, ?) returning *", path, operationStr, time.Now()); err != nil {
306315
http.Error(w, fmt.Sprintf("error saving job: %v", err), http.StatusInternalServerError)
307316
return
308317
}
@@ -314,8 +323,11 @@ func main() {
314323

315324
mux.HandleFunc("GET /jobs/{id}", func(w http.ResponseWriter, r *http.Request) {
316325
id, _ := strconv.Atoi(r.PathValue("id"))
326+
327+
ctx := r.Context()
328+
317329
var job Job
318-
if err := sqlb.ScanRow(r.Context(), db, &job, "select * from jobs where id=?", id); err != nil {
330+
if err := sqlb.ScanRow(ctx, db, &job, "select * from jobs where id=?", id); err != nil {
319331
respErrf(w, http.StatusInternalServerError, "error getting job")
320332
return
321333
}
@@ -332,8 +344,11 @@ func main() {
332344
useMBID = filepath.Base(useMBID) // accept release URL
333345
}
334346

347+
ctx := r.Context()
348+
ctx = notifications.RecordAction(ctx)
349+
335350
var job Job
336-
if err := sqlb.ScanRow(r.Context(), db, &job, "update jobs set confirm=?, use_mbid=?, status=? where id=? and status<>? returning *", confirm, useMBID, StatusEnqueued, id, StatusInProgress); err != nil {
351+
if err := sqlb.ScanRow(ctx, db, &job, "update jobs set confirm=?, use_mbid=?, status=? where id=? and status<>? returning *", confirm, useMBID, StatusEnqueued, id, StatusInProgress); err != nil {
337352
respErrf(w, http.StatusInternalServerError, "error getting job")
338353
return
339354
}
@@ -345,7 +360,11 @@ func main() {
345360

346361
mux.HandleFunc("DELETE /jobs/{id}", func(w http.ResponseWriter, r *http.Request) {
347362
id, _ := strconv.Atoi(r.PathValue("id"))
348-
if err := sqlb.Exec(r.Context(), db, "delete from jobs where id=? and status<>?", id, StatusInProgress); err != nil {
363+
364+
ctx := r.Context()
365+
ctx = notifications.RecordAction(ctx)
366+
367+
if err := sqlb.Exec(ctx, db, "delete from jobs where id=? and status<>?", id, StatusInProgress); err != nil {
349368
respErrf(w, http.StatusInternalServerError, "error getting job")
350369
return
351370
}
@@ -384,7 +403,9 @@ func main() {
384403
})
385404

386405
mux.HandleFunc("/{$}", func(w http.ResponseWriter, r *http.Request) {
387-
jl, err := listJobs(r.Context(), "", "", 0)
406+
ctx := r.Context()
407+
408+
jl, err := listJobs(ctx, "", "", 0)
388409
if err != nil {
389410
respErrf(w, http.StatusInternalServerError, "error listing jobs: %v", err)
390411
return
@@ -417,7 +438,9 @@ func main() {
417438
}
418439
path = filepath.Clean(path)
419440

420-
if err := sqlb.Exec(r.Context(), db, "insert into jobs (source_path, operation, time) values (?, ?, ?)", path, operationStr, time.Now()); err != nil {
441+
ctx := r.Context()
442+
443+
if err := sqlb.Exec(ctx, db, "insert into jobs (source_path, operation, time) values (?, ?, ?)", path, operationStr, time.Now()); err != nil {
421444
http.Error(w, fmt.Sprintf("error saving job: %v", err), http.StatusInternalServerError)
422445
return
423446
}

notifications/notifications.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"log/slog"
88
"net/url"
9+
"time"
910

1011
"github.com/containrrr/shoutrrr"
1112
shoutrrrtypes "github.com/containrrr/shoutrrr/pkg/types"
@@ -44,6 +45,12 @@ func (n *Notifications) Sendf(ctx context.Context, event string, f string, a ...
4445
// Send a simple string for now, maybe later message could instead be a type which
4546
// implements a notifications.Bodyer or something so that notifiers can send rich notifications.
4647
func (n *Notifications) Send(ctx context.Context, event string, message string) {
48+
if actionTime, ok := ctx.Value(actionKey{}).(time.Time); ok && time.Since(actionTime) < 30*time.Second {
49+
slog.DebugContext(ctx, "suppressing notification for recent manual action",
50+
"event", event, "since_manual", time.Since(actionTime))
51+
return
52+
}
53+
4754
uris := n.mappings[event]
4855
if len(uris) == 0 {
4956
return
@@ -63,3 +70,11 @@ func (n *Notifications) Send(ctx context.Context, event string, message string)
6370
return
6471
}
6572
}
73+
74+
type actionKey struct{}
75+
76+
// RecordAction records the current time of a user action and returns a context which may
77+
// be used to suppres notifications later.
78+
func RecordAction(ctx context.Context) context.Context {
79+
return context.WithValue(ctx, actionKey{}, time.Now())
80+
}

0 commit comments

Comments
 (0)