Skip to content

Commit b976526

Browse files
meatballhattempusfrangitmarkphelps
authored
Support cancel / delete of in-flight stream messages (#216)
* Add cancelation metadata to Redis Add some cancelation metadata to redis that allows API to introspect the stream and message ID from the prediction ID. This allows api to perform an `XDEL` on the redis message when a cancelation is received. The only time an `XDEL` should happen is if the prediction has not started. API will be responsible to check `XPENDING` and/or the prediciton cache. Additionally API should `DEL` the meta key when it performs a cancelation or when the prediction completes. key format: `meta:cancelation:<prediction_id>` and it contains json in the form of: `{"stream_id": "<stream key>", "msg_id": "<redis message id>"}` * Move message tracking to separate write script and add client `Del` command for the other side. * Shore up tests for queue client Del and switch meta cancelation key prefix to start with `_` instead of `:` given the latter tends to be expected as a separator, imho. * Renaming to "track field" and using the sha1 of the field value for better alignment with current names and a smidge of safety. * Combine two string ops for efficiency Co-authored-by: Mark Phelps <[email protected]> * A few more refinements based on PR feedback --------- Co-authored-by: Morgan Fainberg <[email protected]> Co-authored-by: Mark Phelps <[email protected]>
1 parent 845eeb3 commit b976526

File tree

4 files changed

+335
-30
lines changed

4 files changed

+335
-30
lines changed

queue/client.go

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package queue
22

33
import (
44
"context"
5+
"crypto/sha1"
6+
"encoding/json"
7+
"errors"
58
"fmt"
69
"regexp"
710
"strings"
@@ -13,15 +16,18 @@ import (
1316
)
1417

1518
var (
16-
ErrInvalidReadArgs = fmt.Errorf("queue: invalid read arguments")
17-
ErrInvalidWriteArgs = fmt.Errorf("queue: invalid write arguments")
19+
ErrInvalidReadArgs = errors.New("queue: invalid read arguments")
20+
ErrInvalidWriteArgs = errors.New("queue: invalid write arguments")
21+
ErrNoMatchingMessageInStream = errors.New("queue: no matching message in stream")
1822

1923
streamSuffixPattern = regexp.MustCompile(`\A:s(\d+)\z`)
2024
)
2125

2226
type Client struct {
2327
rdb redis.Cmdable
2428
ttl time.Duration // ttl for all keys in queue
29+
30+
trackField string
2531
}
2632

2733
type Stats struct {
@@ -32,10 +38,11 @@ type Stats struct {
3238
}
3339

3440
func NewClient(rdb redis.Cmdable, ttl time.Duration) *Client {
35-
return &Client{
36-
rdb: rdb,
37-
ttl: ttl,
38-
}
41+
return &Client{rdb: rdb, ttl: ttl}
42+
}
43+
44+
func NewTrackingClient(rdb redis.Cmdable, ttl time.Duration, field string) *Client {
45+
return &Client{rdb: rdb, ttl: ttl, trackField: field}
3946
}
4047

4148
// Prepare stores the write and read scripts in the Redis script cache so that
@@ -252,16 +259,64 @@ func (c *Client) write(ctx context.Context, args *WriteArgs) (string, error) {
252259
cmdArgs = append(cmdArgs, int(c.ttl.Seconds()))
253260
cmdArgs = append(cmdArgs, args.Streams)
254261
cmdArgs = append(cmdArgs, len(shard))
262+
263+
if c.trackField != "" {
264+
cmdArgs = append(cmdArgs, c.trackField)
265+
}
266+
255267
for _, s := range shard {
256268
cmdArgs = append(cmdArgs, s)
257269
}
258270
for k, v := range args.Values {
259271
cmdArgs = append(cmdArgs, k, v)
260272
}
261273

274+
if c.trackField != "" {
275+
return writeTrackingScript.Run(ctx, c.rdb, cmdKeys, cmdArgs...).Text()
276+
}
277+
262278
return writeScript.Run(ctx, c.rdb, cmdKeys, cmdArgs...).Text()
263279
}
264280

281+
type metaCancelation struct {
282+
StreamID string `json:"stream_id"`
283+
MsgID string `json:"msg_id"`
284+
}
285+
286+
// Del supports removal of a message when the given `fieldValue` matches a "meta
287+
// cancelation" key as written when using a client with tracking support.
288+
func (c *Client) Del(ctx context.Context, fieldValue string) error {
289+
metaCancelationKey := fmt.Sprintf("_meta:cancelation:%x", sha1.Sum([]byte(fieldValue)))
290+
291+
msgBytes, err := c.rdb.Get(ctx, metaCancelationKey).Bytes()
292+
if err != nil {
293+
return err
294+
}
295+
296+
msg := &metaCancelation{}
297+
if err := json.Unmarshal(msgBytes, msg); err != nil {
298+
return err
299+
}
300+
301+
n, err := c.rdb.XDel(ctx, msg.StreamID, msg.MsgID).Result()
302+
if err != nil {
303+
return err
304+
}
305+
306+
if n == 0 {
307+
return fmt.Errorf(
308+
"key=%q field-value=%q stream=%q message-id=%q: %w",
309+
metaCancelationKey,
310+
fieldValue,
311+
msg.StreamID,
312+
msg.MsgID,
313+
ErrNoMatchingMessageInStream,
314+
)
315+
}
316+
317+
return nil
318+
}
319+
265320
func parse(v any) (*Message, error) {
266321
result, err := parseSliceWithLength(v, 1)
267322
if err != nil {

queue/client_test.go

Lines changed: 131 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package queue_test
22

33
import (
4+
"context"
45
crand "crypto/rand"
6+
"crypto/sha1"
57
"errors"
68
"fmt"
79
"math/rand"
@@ -19,16 +21,34 @@ import (
1921

2022
"github.com/replicate/go/queue"
2123
"github.com/replicate/go/test"
24+
"github.com/replicate/go/uuid"
2225
)
2326

2427
func TestClientIntegration(t *testing.T) {
25-
ctx := test.Context(t)
26-
rdb := test.Redis(ctx, t)
28+
t.Run("vanilla", func(t *testing.T) {
29+
ctx := test.Context(t)
30+
rdb := test.Redis(ctx, t)
2731

28-
ttl := 24 * time.Hour
29-
client := queue.NewClient(rdb, ttl)
30-
require.NoError(t, client.Prepare(ctx))
32+
ttl := 24 * time.Hour
33+
client := queue.NewClient(rdb, ttl)
34+
require.NoError(t, client.Prepare(ctx))
35+
36+
runClientIntegrationTest(ctx, t, client)
37+
})
38+
39+
t.Run("with-tracking", func(t *testing.T) {
40+
ctx := test.Context(t)
41+
rdb := test.Redis(ctx, t)
3142

43+
ttl := 24 * time.Hour
44+
client := queue.NewTrackingClient(rdb, ttl, "id")
45+
require.NoError(t, client.Prepare(ctx))
46+
47+
runClientIntegrationTest(ctx, t, client)
48+
})
49+
}
50+
51+
func runClientIntegrationTest(ctx context.Context, t *testing.T, client *queue.Client) {
3252
id := 0
3353

3454
for range 10 {
@@ -250,36 +270,67 @@ func TestClientReadIntegration(t *testing.T) {
250270
}
251271

252272
func TestClientWriteIntegration(t *testing.T) {
253-
ctx := test.Context(t)
254-
rdb := test.Redis(ctx, t)
273+
t.Run("vanilla", func(t *testing.T) {
274+
ctx := test.Context(t)
275+
rdb := test.Redis(ctx, t)
255276

256-
ttl := 24 * time.Hour
257-
client := queue.NewClient(rdb, ttl)
258-
require.NoError(t, client.Prepare(ctx))
277+
ttl := 24 * time.Hour
278+
client := queue.NewClient(rdb, ttl)
279+
require.NoError(t, client.Prepare(ctx))
280+
281+
runClientWriteIntegrationTest(ctx, t, rdb, client, false)
282+
})
283+
284+
t.Run("with-tracking", func(t *testing.T) {
285+
ctx := test.Context(t)
286+
rdb := test.Redis(ctx, t)
287+
288+
ttl := 24 * time.Hour
289+
client := queue.NewTrackingClient(rdb, ttl, "tracketytrack")
290+
require.NoError(t, client.Prepare(ctx))
291+
292+
runClientWriteIntegrationTest(ctx, t, rdb, client, true)
293+
})
294+
}
295+
296+
func runClientWriteIntegrationTest(ctx context.Context, t *testing.T, rdb *redis.Client, client *queue.Client, withTracking bool) {
297+
trackIDs := []string{}
259298

260299
for i := range 10 {
261-
_, err := client.Write(ctx, &queue.WriteArgs{
300+
trackID, err := uuid.NewV7()
301+
require.NoError(t, err)
302+
303+
trackIDs = append(trackIDs, trackID.String())
304+
305+
_, err = client.Write(ctx, &queue.WriteArgs{
262306
Name: "myqueue",
263307
Streams: 2,
264308
StreamsPerShard: 1,
265309
ShardKey: []byte("panda"),
266310
Values: map[string]any{
267-
"name": "panda",
268-
"idx": i,
311+
"idx": i,
312+
"name": "panda",
313+
"tracketytrack": trackID.String(),
269314
},
270315
})
271316
require.NoError(t, err)
272317
}
273318

274319
for i := range 5 {
275-
_, err := client.Write(ctx, &queue.WriteArgs{
320+
trackID, err := uuid.NewV7()
321+
require.NoError(t, err)
322+
323+
trackIDs = append(trackIDs, trackID.String())
324+
325+
_, err = client.Write(ctx, &queue.WriteArgs{
276326
Name: "myqueue",
277327
Streams: 2,
278328
StreamsPerShard: 1,
279329
ShardKey: []byte("giraffe"),
280330
Values: map[string]any{
281-
"name": "giraffe",
282-
"idx": i,
331+
"idx": i,
332+
"name": "giraffe",
333+
"tracketytrack": trackID.String(),
283334
},
284335
})
285336
require.NoError(t, err)
@@ -306,10 +357,10 @@ func TestClientWriteIntegration(t *testing.T) {
306357
require.NoError(t, err)
307358
assert.Len(t, values, 5)
308359
for i, v := range values {
309-
assert.Equal(t, map[string]interface{}{
310-
"name": "giraffe",
311-
"idx": strconv.Itoa(i),
312-
}, v.Values)
360+
assert.Contains(t, v.Values, "name")
361+
assert.Contains(t, v.Values, "idx")
362+
assert.Equal(t, v.Values["name"], "giraffe")
363+
assert.Equal(t, v.Values["idx"], strconv.Itoa(i))
313364
}
314365
}
315366

@@ -323,10 +374,10 @@ func TestClientWriteIntegration(t *testing.T) {
323374
require.NoError(t, err)
324375
assert.Len(t, values, 10)
325376
for i, v := range values {
326-
assert.Equal(t, map[string]interface{}{
327-
"name": "panda",
328-
"idx": strconv.Itoa(i),
329-
}, v.Values)
377+
assert.Contains(t, v.Values, "name")
378+
assert.Contains(t, v.Values, "idx")
379+
assert.Equal(t, v.Values["name"], "panda")
380+
assert.Equal(t, v.Values["idx"], strconv.Itoa(i))
330381
}
331382
}
332383

@@ -343,6 +394,62 @@ func TestClientWriteIntegration(t *testing.T) {
343394
require.NoError(t, err)
344395
assert.Greater(t, ttl, 23*time.Hour)
345396
}
397+
398+
if !withTracking {
399+
return
400+
}
401+
402+
for _, trackID := range trackIDs {
403+
require.NoError(t, client.Del(ctx, trackID))
404+
}
405+
}
406+
407+
func TestClientDelIntegration(t *testing.T) {
408+
ctx := test.Context(t)
409+
rdb := test.Redis(ctx, t)
410+
411+
ttl := 24 * time.Hour
412+
client := queue.NewTrackingClient(rdb, ttl, "tracketytrack")
413+
require.NoError(t, client.Prepare(ctx))
414+
415+
trackIDs := []string{}
416+
417+
for i := range 3 {
418+
trackID, err := uuid.NewV7()
419+
require.NoError(t, err)
420+
421+
trackIDs = append(trackIDs, trackID.String())
422+
423+
_, err = client.Write(ctx, &queue.WriteArgs{
424+
Name: "myqueue",
425+
Streams: 2,
426+
StreamsPerShard: 1,
427+
ShardKey: []byte("capybara"),
428+
Values: map[string]any{
429+
"idx": i,
430+
"name": "capybara",
431+
"tracketytrack": trackID.String(),
432+
},
433+
})
434+
require.NoError(t, err)
435+
}
436+
437+
require.NoError(t, client.Del(ctx, trackIDs[0]))
438+
require.Error(t, client.Del(ctx, trackIDs[0]))
439+
require.Error(t, client.Del(ctx, trackIDs[0]+"oops"))
440+
require.Error(t, client.Del(ctx, "bogustown"))
441+
442+
metaCancelationKey := "_meta:cancelation:" + fmt.Sprintf("%x", sha1.Sum([]byte(trackIDs[1])))
443+
444+
metaCancel, err := rdb.Get(ctx, metaCancelationKey).Result()
445+
require.NoError(t, err)
446+
447+
rdb.SetEx(ctx, metaCancelationKey, "{{[,"+metaCancel, 5*time.Second)
448+
449+
require.Error(t, client.Del(ctx, trackIDs[1]))
450+
451+
require.NoError(t, client.Del(ctx, trackIDs[2]))
452+
require.ErrorIs(t, client.Del(ctx, trackIDs[2]), queue.ErrNoMatchingMessageInStream)
346453
}
347454

348455
// TestPickupLatencyIntegration runs a test with a mostly-empty queue -- by

queue/queue.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ var (
5959
//go:embed write.lua
6060
writeCmd string
6161
writeScript = redis.NewScript(writeCmd)
62+
63+
//go:embed writetracking.lua
64+
writeTrackingCmd string
65+
writeTrackingScript = redis.NewScript(writeTrackingCmd)
6266
)
6367

6468
func prepare(ctx context.Context, rdb redis.Cmdable) error {
@@ -77,5 +81,8 @@ func prepare(ctx context.Context, rdb redis.Cmdable) error {
7781
if err := writeScript.Load(ctx, rdb).Err(); err != nil {
7882
return err
7983
}
84+
if err := writeTrackingScript.Load(ctx, rdb).Err(); err != nil {
85+
return err
86+
}
8087
return nil
8188
}

0 commit comments

Comments
 (0)