diff --git a/cache/cache.go b/cache/cache.go index 62f07a6..9e87c5e 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -369,7 +369,9 @@ type _nullLock struct{} var nullLock lock.Lock = &_nullLock{} -func (*_nullLock) Release(context.Context) error { return nil } +func (*_nullLock) Release(context.Context) error { return nil } +func (*_nullLock) TTL(context.Context) (time.Duration, error) { return 0, lock.ErrLockNotHeld } +func (*_nullLock) Refresh(context.Context, time.Duration) error { return lock.ErrLockNotHeld } func (c *Cache[T]) acquireIfMultipleRedises(ctx context.Context, key string, ttl time.Duration) (lock.Lock, error) { if len(c.clients) == 1 { diff --git a/kv/kv.go b/kv/kv.go index 6bd9c92..0c829f6 100644 --- a/kv/kv.go +++ b/kv/kv.go @@ -247,3 +247,13 @@ func New(ctx context.Context, name, urlString string, clientOpts ...ClientOption return client, nil } + +// Exists checks if a key exists in Redis. Returns true if the key exists, false otherwise. +// May return an error if it cannot communicate with Redis. +func Exists(ctx context.Context, client redis.UniversalClient, key string) (bool, error) { + result, err := client.Exists(ctx, key).Result() + if err != nil { + return false, err + } + return result == 1, nil +} diff --git a/kv/kv_test.go b/kv/kv_test.go index 5b0be2f..3daaa2e 100644 --- a/kv/kv_test.go +++ b/kv/kv_test.go @@ -234,3 +234,35 @@ func TestNewWithEmptyAddr(t *testing.T) { assert.Nil(t, client) assert.Contains(t, err.Error(), "failed to ping") } + +func TestExists(t *testing.T) { + if os.Getenv("INTEGRATION") != "1" { + t.Skip("skipping integration test") + } + + ctx := test.Context(t) + client := test.Redis(ctx, t) + + // Test key that doesn't exist + exists, err := kv.Exists(ctx, client, "nonexistent-key") + require.NoError(t, err) + assert.False(t, exists) + + // Create a key + err = client.Set(ctx, "test-key", "test-value", 0).Err() + require.NoError(t, err) + + // Test key that exists + exists, err = kv.Exists(ctx, client, "test-key") + require.NoError(t, err) + assert.True(t, exists) + + // Delete the key + err = client.Del(ctx, "test-key").Err() + require.NoError(t, err) + + // Test key that no longer exists + exists, err = kv.Exists(ctx, client, "test-key") + require.NoError(t, err) + assert.False(t, exists) +} diff --git a/lock/lock.go b/lock/lock.go index d63b2e3..708fb25 100644 --- a/lock/lock.go +++ b/lock/lock.go @@ -37,6 +37,8 @@ type Locker struct { type Lock interface { Release(context.Context) error + TTL(context.Context) (time.Duration, error) + Refresh(context.Context, time.Duration) error } type lock struct { @@ -124,6 +126,57 @@ func (l *lock) Release(ctx context.Context) error { return l.release(ctx, len(l.clients)) } +// TTL returns the remaining time to live for the lock. Returns an error if +// the lock has expired or is held by another party. +func (l *lock) TTL(ctx context.Context) (time.Duration, error) { + if len(l.clients) == 0 { + return 0, ErrLockNotHeld + } + + // Check the first client for TTL - all clients should have the same TTL + ttl, err := l.clients[0].TTL(ctx, l.key).Result() + if err != nil { + return 0, err + } + + // If TTL is -2, the key doesn't exist; if -1, the key has no expiration + if ttl < 0 { + return 0, ErrLockNotHeld + } + + return ttl, nil +} + +// Refresh extends the lock's TTL across all clients. Returns an error if the +// lock has expired or is held by another party. +func (l *lock) Refresh(ctx context.Context, ttl time.Duration) error { + errs := []error{} + + for _, client := range l.clients { + // Use EXPIRE command with the token verification to ensure we still hold the lock + result, err := client.Eval(ctx, + `if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("expire", KEYS[1], ARGV[2]) + else + return 0 + end`, + []string{l.key}, + l.token, + int(ttl.Seconds())).Result() + + if err != nil { + errs = append(errs, err) + continue + } + + if i, ok := result.(int64); !ok || i != 1 { + errs = append(errs, ErrLockNotHeld) + } + } + + return errors.Join(errs...) +} + func (l *lock) release(ctx context.Context, n int) error { errs := []error{} diff --git a/lock/lock_test.go b/lock/lock_test.go index c6c4f37..c7ecdc4 100644 --- a/lock/lock_test.go +++ b/lock/lock_test.go @@ -411,3 +411,87 @@ func TestLockTryAcquireIntegration(t *testing.T) { // Check that only one goroutine got the lock require.Equal(t, 1, len(results)) } + +func TestLockTTL(t *testing.T) { + ctx := context.Background() + client, mock := redismock.NewClientMock() + + l := &lock{ + clients: []redis.Cmdable{client}, + key: "test-key", + token: "test-token", + } + + // Test successful TTL + mock.ExpectTTL("test-key").SetVal(30 * time.Second) + + ttl, err := l.TTL(ctx) + assert.NoError(t, err) + assert.Equal(t, 30*time.Second, ttl) + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestLockTTLNotHeld(t *testing.T) { + ctx := context.Background() + client, mock := redismock.NewClientMock() + + l := &lock{ + clients: []redis.Cmdable{client}, + key: "test-key", + token: "test-token", + } + + // Test when key doesn't exist (TTL = -2) + mock.ExpectTTL("test-key").SetVal(-2 * time.Second) + + ttl, err := l.TTL(ctx) + assert.ErrorIs(t, err, ErrLockNotHeld) + assert.Equal(t, time.Duration(0), ttl) + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestLockRefresh(t *testing.T) { + ctx := context.Background() + client, mock := redismock.NewClientMock() + + l := &lock{ + clients: []redis.Cmdable{client}, + key: "test-key", + token: "test-token", + } + + // Test successful refresh + refreshScript := `if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("expire", KEYS[1], ARGV[2]) + else + return 0 + end` + mock.ExpectEval(refreshScript, []string{"test-key"}, "test-token", int(60)).SetVal(int64(1)) + + err := l.Refresh(ctx, 60*time.Second) + assert.NoError(t, err) + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestLockRefreshNotHeld(t *testing.T) { + ctx := context.Background() + client, mock := redismock.NewClientMock() + + l := &lock{ + clients: []redis.Cmdable{client}, + key: "test-key", + token: "test-token", + } + + // Test when lock is not held (script returns 0) + refreshScript := `if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("expire", KEYS[1], ARGV[2]) + else + return 0 + end` + mock.ExpectEval(refreshScript, []string{"test-key"}, "test-token", int(60)).SetVal(int64(0)) + + err := l.Refresh(ctx, 60*time.Second) + assert.ErrorIs(t, err, ErrLockNotHeld) + assert.NoError(t, mock.ExpectationsWereMet()) +}