Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,9 @@ type _nullLock struct{}

var nullLock lock.Lock = &_nullLock{}

func (*_nullLock) Release(context.Context) error { return nil }
func (*_nullLock) Release(context.Context) error { return nil }
func (*_nullLock) TTL(context.Context) (time.Duration, error) { return 0, lock.ErrLockNotHeld }
func (*_nullLock) Refresh(context.Context, time.Duration) error { return lock.ErrLockNotHeld }

func (c *Cache[T]) acquireIfMultipleRedises(ctx context.Context, key string, ttl time.Duration) (lock.Lock, error) {
if len(c.clients) == 1 {
Expand Down
10 changes: 10 additions & 0 deletions kv/kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,13 @@ func New(ctx context.Context, name, urlString string, clientOpts ...ClientOption

return client, nil
}

// Exists checks if a key exists in Redis. Returns true if the key exists, false otherwise.
// May return an error if it cannot communicate with Redis.
func Exists(ctx context.Context, client redis.UniversalClient, key string) (bool, error) {
result, err := client.Exists(ctx, key).Result()
if err != nil {
return false, err
}
return result == 1, nil
}
32 changes: 32 additions & 0 deletions kv/kv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,35 @@ func TestNewWithEmptyAddr(t *testing.T) {
assert.Nil(t, client)
assert.Contains(t, err.Error(), "failed to ping")
}

func TestExists(t *testing.T) {
if os.Getenv("INTEGRATION") != "1" {
t.Skip("skipping integration test")
}

ctx := test.Context(t)
client := test.Redis(ctx, t)

// Test key that doesn't exist
exists, err := kv.Exists(ctx, client, "nonexistent-key")
require.NoError(t, err)
assert.False(t, exists)

// Create a key
err = client.Set(ctx, "test-key", "test-value", 0).Err()
require.NoError(t, err)

// Test key that exists
exists, err = kv.Exists(ctx, client, "test-key")
require.NoError(t, err)
assert.True(t, exists)

// Delete the key
err = client.Del(ctx, "test-key").Err()
require.NoError(t, err)

// Test key that no longer exists
exists, err = kv.Exists(ctx, client, "test-key")
require.NoError(t, err)
assert.False(t, exists)
}
53 changes: 53 additions & 0 deletions lock/lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type Locker struct {

type Lock interface {
Release(context.Context) error
TTL(context.Context) (time.Duration, error)
Refresh(context.Context, time.Duration) error
}

type lock struct {
Expand Down Expand Up @@ -124,6 +126,57 @@ func (l *lock) Release(ctx context.Context) error {
return l.release(ctx, len(l.clients))
}

// TTL returns the remaining time to live for the lock. Returns an error if
// the lock has expired or is held by another party.
func (l *lock) TTL(ctx context.Context) (time.Duration, error) {
if len(l.clients) == 0 {
return 0, ErrLockNotHeld
}

// Check the first client for TTL - all clients should have the same TTL
ttl, err := l.clients[0].TTL(ctx, l.key).Result()
if err != nil {
return 0, err
}

// If TTL is -2, the key doesn't exist; if -1, the key has no expiration
if ttl < 0 {
return 0, ErrLockNotHeld
}

return ttl, nil
}

// Refresh extends the lock's TTL across all clients. Returns an error if the
// lock has expired or is held by another party.
func (l *lock) Refresh(ctx context.Context, ttl time.Duration) error {
errs := []error{}

for _, client := range l.clients {
// Use EXPIRE command with the token verification to ensure we still hold the lock
result, err := client.Eval(ctx,
`if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("expire", KEYS[1], ARGV[2])
else
return 0
end`,
[]string{l.key},
l.token,
int(ttl.Seconds())).Result()

if err != nil {
errs = append(errs, err)
continue
}

if i, ok := result.(int64); !ok || i != 1 {
errs = append(errs, ErrLockNotHeld)
}
}

return errors.Join(errs...)
}

func (l *lock) release(ctx context.Context, n int) error {
errs := []error{}

Expand Down
84 changes: 84 additions & 0 deletions lock/lock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,87 @@ func TestLockTryAcquireIntegration(t *testing.T) {
// Check that only one goroutine got the lock
require.Equal(t, 1, len(results))
}

func TestLockTTL(t *testing.T) {
ctx := context.Background()
client, mock := redismock.NewClientMock()

l := &lock{
clients: []redis.Cmdable{client},
key: "test-key",
token: "test-token",
}

// Test successful TTL
mock.ExpectTTL("test-key").SetVal(30 * time.Second)

ttl, err := l.TTL(ctx)
assert.NoError(t, err)
assert.Equal(t, 30*time.Second, ttl)
assert.NoError(t, mock.ExpectationsWereMet())
}

func TestLockTTLNotHeld(t *testing.T) {
ctx := context.Background()
client, mock := redismock.NewClientMock()

l := &lock{
clients: []redis.Cmdable{client},
key: "test-key",
token: "test-token",
}

// Test when key doesn't exist (TTL = -2)
mock.ExpectTTL("test-key").SetVal(-2 * time.Second)

ttl, err := l.TTL(ctx)
assert.ErrorIs(t, err, ErrLockNotHeld)
assert.Equal(t, time.Duration(0), ttl)
assert.NoError(t, mock.ExpectationsWereMet())
}

func TestLockRefresh(t *testing.T) {
ctx := context.Background()
client, mock := redismock.NewClientMock()

l := &lock{
clients: []redis.Cmdable{client},
key: "test-key",
token: "test-token",
}

// Test successful refresh
refreshScript := `if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("expire", KEYS[1], ARGV[2])
else
return 0
end`
mock.ExpectEval(refreshScript, []string{"test-key"}, "test-token", int(60)).SetVal(int64(1))

err := l.Refresh(ctx, 60*time.Second)
assert.NoError(t, err)
assert.NoError(t, mock.ExpectationsWereMet())
}

func TestLockRefreshNotHeld(t *testing.T) {
ctx := context.Background()
client, mock := redismock.NewClientMock()

l := &lock{
clients: []redis.Cmdable{client},
key: "test-key",
token: "test-token",
}

// Test when lock is not held (script returns 0)
refreshScript := `if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("expire", KEYS[1], ARGV[2])
else
return 0
end`
mock.ExpectEval(refreshScript, []string{"test-key"}, "test-token", int(60)).SetVal(int64(0))

err := l.Refresh(ctx, 60*time.Second)
assert.ErrorIs(t, err, ErrLockNotHeld)
assert.NoError(t, mock.ExpectationsWereMet())
}