Skip to content

Commit fb13921

Browse files
committed
[BREAKING CHANGE] API adjustment to add context.Context
1 parent d80ddb2 commit fb13921

File tree

3 files changed

+35
-17
lines changed

3 files changed

+35
-17
lines changed

claude_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package simultaneous_test
33
// This generated with Claude 3.7 Sonnet
44

55
import (
6+
"context"
67
"testing"
78
"time"
89

@@ -16,17 +17,17 @@ func TestTimeoutRejection(t *testing.T) {
1617
limit := simultaneous.New[any](1)
1718

1819
// Take the only available slot
19-
done := limit.Forever()
20+
done := limit.Forever(context.Background())
2021
defer done.Done()
2122

2223
// This should time out immediately
23-
_, err := limit.Timeout(0)
24+
_, err := limit.Timeout(context.Background(), 0)
2425
assert.Error(t, err)
2526
t.Log("Timeout(0) correctly rejected when limit is full")
2627

2728
// This should time out after a short wait
2829
start := time.Now()
29-
_, err = limit.Timeout(50 * time.Millisecond)
30+
_, err = limit.Timeout(context.Background(), 50*time.Millisecond)
3031
duration := time.Since(start)
3132

3233
assert.Error(t, err)

limit.go

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ need a limit.
55
package simultaneous
66

77
import (
8+
"context"
89
"time"
910

1011
"github.com/memsql/errors"
@@ -28,8 +29,8 @@ type Enforced[T any] interface {
2829
// contract.
2930
type Limit[T any] struct {
3031
queue chan struct{}
31-
stuckCallback func()
32-
unstuckCallback func()
32+
stuckCallback func(context.Context)
33+
unstuckCallback func(context.Context)
3334
stuckTimeout time.Duration
3435
}
3536

@@ -49,25 +50,28 @@ func Unlimited[T any]() Enforced[T] {
4950
}
5051

5152
// Forever waits until there is space in the Limit for another
52-
// simultaneous runner. It will wait forever. The Done() method
53+
// simultaneous runner. It will wait until the context is
54+
// cancelled. The Done() method
5355
// must be called to release the space.
5456
//
5557
// defer limit.Forever().Done()
56-
func (l *Limit[T]) Forever() Limited[T] {
58+
func (l *Limit[T]) Forever(ctx context.Context) Limited[T] {
5759
if l.stuckTimeout == 0 {
5860
l.queue <- struct{}{}
5961
} else {
6062
timer := time.NewTimer(l.stuckTimeout)
6163
select {
6264
case l.queue <- struct{}{}:
6365
timer.Stop()
66+
case <-ctx.Done():
67+
timer.Stop()
6468
case <-timer.C:
6569
if l.stuckCallback != nil {
66-
l.stuckCallback()
70+
l.stuckCallback(ctx)
6771
}
6872
l.queue <- struct{}{}
6973
if l.unstuckCallback != nil {
70-
l.unstuckCallback()
74+
l.unstuckCallback(ctx)
7175
}
7276
}
7377
}
@@ -82,13 +86,15 @@ var ErrTimeout errors.String = "could not get permission to run before timeout"
8286
// simultaneous runner. In the case of a timeout, ErrTimeout is returned
8387
// and the Done method is a no-op. If there is room, the Done method must
8488
// be invoked to make room for another runner.
85-
func (l *Limit[T]) Timeout(timeout time.Duration) (Limited[T], error) {
89+
func (l *Limit[T]) Timeout(ctx context.Context, timeout time.Duration) (Limited[T], error) {
8690
if timeout <= 0 {
8791
select {
8892
case l.queue <- struct{}{}:
8993
return limited[T](func() {
9094
<-l.queue
9195
}), nil
96+
case <-ctx.Done():
97+
return limited[T](nil), ErrTimeout.Errorf("context cancelled before any simultaneous runner (of %d) became available", cap(l.queue))
9298
default:
9399
return limited[T](nil), ErrTimeout.Errorf("timeout (%s) expired before any simultaneous runner (of %d) became available", timeout, cap(l.queue))
94100
}
@@ -100,6 +106,8 @@ func (l *Limit[T]) Timeout(timeout time.Duration) (Limited[T], error) {
100106
return limited[T](func() {
101107
<-l.queue
102108
}), nil
109+
case <-ctx.Done():
110+
return limited[T](nil), ErrTimeout.Errorf("context cancelled before any simultaneous runner (of %d) became available", cap(l.queue))
103111
case <-timer.C:
104112
return limited[T](nil), ErrTimeout.Errorf("timeout (%s) expired before any simultaneous runner (of %d) became available", timeout, cap(l.queue))
105113
}
@@ -108,7 +116,7 @@ func (l *Limit[T]) Timeout(timeout time.Duration) (Limited[T], error) {
108116
// SetForeverMessaging returns a modified Limit that changes the behavior of Forever() so that
109117
// it will call stuckCallback() (if set) after waiting for stuckTimeout duration. If past that duration,
110118
// and it will call unstuckCallback() (if set) when it finally gets a limit.
111-
func (l Limit[T]) SetForeverMessaging(stuckTimeout time.Duration, stuckCallback func(), unstuckCallback func()) *Limit[T] {
119+
func (l Limit[T]) SetForeverMessaging(stuckTimeout time.Duration, stuckCallback func(context.Context), unstuckCallback func(context.Context)) *Limit[T] {
112120
l.stuckTimeout = stuckTimeout
113121
l.stuckCallback = stuckCallback
114122
l.unstuckCallback = unstuckCallback

limit_test.go

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package simultaneous_test
22

33
import (
4+
"context"
45
"sync"
56
"sync/atomic"
67
"testing"
@@ -38,12 +39,12 @@ func testLimit(t *testing.T, withStuck bool) {
3839

3940
if withStuck {
4041
limit = limit.SetForeverMessaging(time.Millisecond,
41-
func() {
42+
func(context.Context) {
4243
if stuckCalled.Add(1) == 1 {
4344
close(someUnstuck)
4445
}
4546
},
46-
func() {
47+
func(context.Context) {
4748
unstuckCalled.Add(1)
4849
},
4950
)
@@ -58,12 +59,12 @@ func testLimit(t *testing.T, withStuck bool) {
5859
go func() {
5960
defer wg.Done()
6061
var done simultaneous.Limited[any]
61-
switch i % 3 {
62+
switch i % 4 {
6263
case 0:
63-
done = limit.Forever()
64+
done = limit.Forever(context.Background())
6465
case 1:
6566
var err error
66-
done, err = limit.Timeout(0)
67+
done, err = limit.Timeout(context.Background(), 0)
6768
if err != nil {
6869
fail.Add(1)
6970
return
@@ -72,7 +73,15 @@ func testLimit(t *testing.T, withStuck bool) {
7273
}
7374
case 2:
7475
var err error
75-
done, err = limit.Timeout(time.Second * 2)
76+
done, err = limit.Timeout(context.Background(), time.Second*2)
77+
if !assert.NoError(t, err) {
78+
return
79+
}
80+
case 3:
81+
var err error
82+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
83+
t.Cleanup(cancel)
84+
done, err = limit.Timeout(ctx, time.Hour)
7685
if !assert.NoError(t, err) {
7786
return
7887
}

0 commit comments

Comments
 (0)