Skip to content

Commit c2a1826

Browse files
authored
feat!: use context parameter in Future.Get (#35)
1 parent 6d948a5 commit c2a1826

File tree

4 files changed

+33
-23
lines changed

4 files changed

+33
-23
lines changed

future.go

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
package async
22

33
import (
4-
"fmt"
4+
"context"
55
"sync"
6-
"time"
76
)
87

98
// Future represents a value which may or may not currently be available,
@@ -23,9 +22,9 @@ type Future[T any] interface {
2322
// or an error.
2423
Join() (T, error)
2524

26-
// Get blocks for at most the given time duration for this Future to
27-
// complete and returns either a result or an error.
28-
Get(time.Duration) (T, error)
25+
// Get blocks until the Future is completed or context is canceled and
26+
// returns either a result or an error.
27+
Get(context.Context) (T, error)
2928

3029
// Recover handles any error that this Future might contain using a
3130
// resolver function.
@@ -68,16 +67,14 @@ func (fut *futureImpl[T]) accept() {
6867
}
6968

7069
// acceptTimeout blocks once, until the Future result is available or until
71-
// the timeout expires.
72-
func (fut *futureImpl[T]) acceptTimeout(timeout time.Duration) {
70+
// the context is canceled.
71+
func (fut *futureImpl[T]) acceptContext(ctx context.Context) {
7372
fut.acceptOnce.Do(func() {
74-
timer := time.NewTimer(timeout)
75-
defer timer.Stop()
7673
select {
7774
case result := <-fut.done:
7875
fut.setResult(result)
79-
case <-timer.C:
80-
fut.setResult(fmt.Errorf("Future timeout after %s", timeout))
76+
case <-ctx.Done():
77+
fut.setResult(ctx.Err())
8178
}
8279
})
8380
}
@@ -137,10 +134,10 @@ func (fut *futureImpl[T]) Join() (T, error) {
137134
return fut.value, fut.err
138135
}
139136

140-
// Get blocks for at most the given time duration for this Future to
141-
// complete and returns either a result or an error.
142-
func (fut *futureImpl[T]) Get(timeout time.Duration) (T, error) {
143-
fut.acceptTimeout(timeout)
137+
// Get blocks until the Future is completed or context is canceled and
138+
// returns either a result or an error.
139+
func (fut *futureImpl[T]) Get(ctx context.Context) (T, error) {
140+
fut.acceptContext(ctx)
144141
return fut.value, fut.err
145142
}
146143

future_test.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package async
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"runtime"
@@ -79,7 +80,7 @@ func TestFuture_Transform(t *testing.T) {
7980
return util.Ptr(5), nil
8081
})
8182

82-
res, _ := future.Get(time.Second * 5)
83+
res, _ := future.Get(context.Background())
8384
assert.Equal(t, 3, *res)
8485

8586
res, _ = future.Join()
@@ -136,11 +137,15 @@ func TestFuture_Timeout(t *testing.T) {
136137
}()
137138
future := p.Future()
138139

139-
_, err := future.Get(10 * time.Millisecond)
140-
assert.ErrorContains(t, err, "timeout")
140+
ctx, cancel := context.WithTimeout(context.Background(),
141+
10*time.Millisecond)
142+
defer cancel()
143+
144+
_, err := future.Get(ctx)
145+
assert.ErrorIs(t, err, context.DeadlineExceeded)
141146

142147
_, err = future.Join()
143-
assert.ErrorContains(t, err, "timeout")
148+
assert.ErrorIs(t, err, context.DeadlineExceeded)
144149
}
145150

146151
func TestFuture_GoroutineLeak(t *testing.T) {
@@ -161,7 +166,7 @@ func TestFuture_GoroutineLeak(t *testing.T) {
161166
go func() {
162167
defer wg.Done()
163168
fut := promise.Future()
164-
_, _ = fut.Get(10 * time.Millisecond)
169+
_, _ = fut.Get(context.Background())
165170
time.Sleep(100 * time.Millisecond)
166171
_, _ = fut.Join()
167172
}()

future_utils.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,10 @@ func FutureFirstCompletedOf[T any](futures ...Future[T]) Future[T] {
4444
func FutureTimer[T any](d time.Duration) Future[T] {
4545
next := newFuture[T]()
4646
go func() {
47-
timer := time.NewTimer(d)
48-
<-timer.C
47+
<-time.After(d)
4948
var zero T
5049
next.(*futureImpl[T]).
51-
complete(zero, fmt.Errorf("FutureTimer %s timeout", d))
50+
complete(zero, fmt.Errorf("future timeout after %s", d))
5251
}()
5352
return next
5453
}

internal/assert/assertions.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package assert
22

33
import (
4+
"errors"
45
"fmt"
56
"reflect"
67
"strings"
@@ -78,6 +79,14 @@ func ErrorContains(t *testing.T, err error, str string) {
7879
}
7980
}
8081

82+
// ErrorIs checks whether any error in err's tree matches target.
83+
func ErrorIs(t *testing.T, err error, target error) {
84+
if !errors.Is(err, target) {
85+
t.Helper()
86+
t.Fatalf("Error type mismatch: %v != %v", err, target)
87+
}
88+
}
89+
8190
// Panics checks whether the given function panics.
8291
func Panics(t *testing.T, f func()) {
8392
defer func() {

0 commit comments

Comments
 (0)