@@ -5,6 +5,7 @@ need a limit.
55package simultaneous
66
77import (
8+ "context"
89 "time"
910
1011 "github.com/memsql/errors"
@@ -28,8 +29,8 @@ type Enforced[T any] interface {
2829// contract.
2930type Limit [T any ] struct {
3031 queue chan struct {}
31- stuckCallback func ()
32- unstuckCallback func ()
32+ stuckCallback func (context. Context )
33+ unstuckCallback func (context. Context )
3334 stuckTimeout time.Duration
3435}
3536
@@ -49,25 +50,28 @@ func Unlimited[T any]() Enforced[T] {
4950}
5051
5152// Forever waits until there is space in the Limit for another
52- // simultaneous runner. It will wait forever. The Done() method
53+ // simultaneous runner. It will wait until the context is
54+ // cancelled. The Done() method
5355// must be called to release the space.
5456//
5557// defer limit.Forever().Done()
56- func (l * Limit [T ]) Forever () Limited [T ] {
58+ func (l * Limit [T ]) Forever (ctx context. Context ) Limited [T ] {
5759 if l .stuckTimeout == 0 {
5860 l .queue <- struct {}{}
5961 } else {
6062 timer := time .NewTimer (l .stuckTimeout )
6163 select {
6264 case l .queue <- struct {}{}:
6365 timer .Stop ()
66+ case <- ctx .Done ():
67+ timer .Stop ()
6468 case <- timer .C :
6569 if l .stuckCallback != nil {
66- l .stuckCallback ()
70+ l .stuckCallback (ctx )
6771 }
6872 l .queue <- struct {}{}
6973 if l .unstuckCallback != nil {
70- l .unstuckCallback ()
74+ l .unstuckCallback (ctx )
7175 }
7276 }
7377 }
@@ -82,13 +86,15 @@ var ErrTimeout errors.String = "could not get permission to run before timeout"
8286// simultaneous runner. In the case of a timeout, ErrTimeout is returned
8387// and the Done method is a no-op. If there is room, the Done method must
8488// be invoked to make room for another runner.
85- func (l * Limit [T ]) Timeout (timeout time.Duration ) (Limited [T ], error ) {
89+ func (l * Limit [T ]) Timeout (ctx context. Context , timeout time.Duration ) (Limited [T ], error ) {
8690 if timeout <= 0 {
8791 select {
8892 case l .queue <- struct {}{}:
8993 return limited [T ](func () {
9094 <- l .queue
9195 }), nil
96+ case <- ctx .Done ():
97+ return limited [T ](nil ), ErrTimeout .Errorf ("context cancelled before any simultaneous runner (of %d) became available" , cap (l .queue ))
9298 default :
9399 return limited [T ](nil ), ErrTimeout .Errorf ("timeout (%s) expired before any simultaneous runner (of %d) became available" , timeout , cap (l .queue ))
94100 }
@@ -100,6 +106,8 @@ func (l *Limit[T]) Timeout(timeout time.Duration) (Limited[T], error) {
100106 return limited [T ](func () {
101107 <- l .queue
102108 }), nil
109+ case <- ctx .Done ():
110+ return limited [T ](nil ), ErrTimeout .Errorf ("context cancelled before any simultaneous runner (of %d) became available" , cap (l .queue ))
103111 case <- timer .C :
104112 return limited [T ](nil ), ErrTimeout .Errorf ("timeout (%s) expired before any simultaneous runner (of %d) became available" , timeout , cap (l .queue ))
105113 }
@@ -108,7 +116,7 @@ func (l *Limit[T]) Timeout(timeout time.Duration) (Limited[T], error) {
108116// SetForeverMessaging returns a modified Limit that changes the behavior of Forever() so that
109117// it will call stuckCallback() (if set) after waiting for stuckTimeout duration. If past that duration,
110118// and it will call unstuckCallback() (if set) when it finally gets a limit.
111- func (l Limit [T ]) SetForeverMessaging (stuckTimeout time.Duration , stuckCallback func (), unstuckCallback func ()) * Limit [T ] {
119+ func (l Limit [T ]) SetForeverMessaging (stuckTimeout time.Duration , stuckCallback func (context. Context ), unstuckCallback func (context. Context )) * Limit [T ] {
112120 l .stuckTimeout = stuckTimeout
113121 l .stuckCallback = stuckCallback
114122 l .unstuckCallback = unstuckCallback
0 commit comments