Skip to content

Commit e1dcd31

Browse files
committed
propagate contexts throughout
Signed-off-by: Bora M. Alper <bora@boramalper.org>
1 parent 4eeb856 commit e1dcd31

File tree

14 files changed

+107
-100
lines changed

14 files changed

+107
-100
lines changed

driver.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ func WithInputs(inputs ...string) Option {
130130
}
131131
}
132132

133-
func (d *Driver) runMapPhase(job *Job, jobNumber int, inputs []string) {
133+
func (d *Driver) runMapPhase(ctx context.Context, job *Job, jobNumber int, inputs []string) {
134134
inputSplits := job.inputSplits(inputs, d.config.SplitSize)
135135
if len(inputSplits) == 0 {
136136
log.Warnf("No input splits")
@@ -145,15 +145,15 @@ func (d *Driver) runMapPhase(job *Job, jobNumber int, inputs []string) {
145145
var wg sync.WaitGroup
146146
sem := semaphore.NewWeighted(int64(d.config.MaxConcurrency))
147147
for binID, bin := range inputBins {
148-
if err := sem.Acquire(context.Background(), 1); err != nil {
148+
if err := sem.Acquire(ctx, 1); err != nil {
149149
log.Fatal("Failed to acquire semaphore: ", err)
150150
}
151151
wg.Add(1)
152152
go func(bID uint, b []inputSplit) {
153153
defer wg.Done()
154154
defer sem.Release(1)
155155
defer bar.Increment()
156-
err := d.executor.RunMapper(job, jobNumber, bID, b)
156+
err := d.executor.RunMapper(ctx, job, jobNumber, bID, b)
157157
if err != nil {
158158
log.Errorf("Error when running mapper %d: %s", bID, err)
159159
}
@@ -163,15 +163,15 @@ func (d *Driver) runMapPhase(job *Job, jobNumber int, inputs []string) {
163163
bar.Finish()
164164
}
165165

166-
func (d *Driver) runReducePhase(job *Job, jobNumber int) {
166+
func (d *Driver) runReducePhase(ctx context.Context, job *Job, jobNumber int) {
167167
var wg sync.WaitGroup
168168
bar := pb.New(int(job.intermediateBins)).Prefix("Reduce").Start()
169169
for binID := uint(0); binID < job.intermediateBins; binID++ {
170170
wg.Add(1)
171171
go func(bID uint) {
172172
defer wg.Done()
173173
defer bar.Increment()
174-
err := d.executor.RunReducer(job, jobNumber, bID)
174+
err := d.executor.RunReducer(ctx, job, jobNumber, bID)
175175
if err != nil {
176176
log.Errorf("Error when running reducer %d: %s", bID, err)
177177
}
@@ -182,7 +182,7 @@ func (d *Driver) runReducePhase(job *Job, jobNumber int) {
182182
}
183183

184184
// run starts the Driver
185-
func (d *Driver) run() {
185+
func (d *Driver) run(ctx context.Context) {
186186
if runningInLambda() {
187187
lambdaDriver = d
188188
lambda.Start(handleRequest)
@@ -225,8 +225,8 @@ func (d *Driver) run() {
225225
job.outputPath = jobWorkingLoc
226226

227227
*job.config = *d.config
228-
d.runMapPhase(job, idx, inputs)
229-
d.runReducePhase(job, idx)
228+
d.runMapPhase(ctx, job, idx, inputs)
229+
d.runReducePhase(ctx, job, idx)
230230

231231
// Set inputs of next job to be outputs of current job
232232
inputs = []string{job.fileSystem.Join(jobWorkingLoc, "output-*")}
@@ -245,7 +245,7 @@ var undeploy = flag.Bool("undeploy", false, "Undeploy the Lambda function and IA
245245
var undeployKnative = flag.Bool("undeployKnative", false, "Undeploy the Knative service without running the driver")
246246

247247
// Main starts the Driver, running the submitted jobs.
248-
func (d *Driver) Main() {
248+
func (d *Driver) Main(ctx context.Context) {
249249
if viper.GetBool("verbose") {
250250
log.SetLevel(log.DebugLevel)
251251
}
@@ -272,7 +272,7 @@ func (d *Driver) Main() {
272272
}
273273

274274
start := time.Now()
275-
d.run()
275+
d.run(ctx)
276276
end := time.Now()
277277
fmt.Printf("Job Execution Time: %s\n", end.Sub(start))
278278

driver_test.go

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package corral
22

33
import (
4+
"context"
45
"fmt"
56
"io/ioutil"
67
"os"
@@ -30,20 +31,20 @@ func TestNewDriver(t *testing.T) {
3031

3132
type testWCJob struct{}
3233

33-
func (testWCJob) Map(key, value string, emitter Emitter) {
34+
func (testWCJob) Map(ctx context.Context, key, value string, emitter Emitter) {
3435
for _, word := range strings.Fields(value) {
35-
if err := emitter.Emit(word, "1"); err != nil {
36+
if err := emitter.Emit(ctx, word, "1"); err != nil {
3637
panic(err)
3738
}
3839
}
3940
}
4041

41-
func (testWCJob) Reduce(key string, values ValueIterator, emitter Emitter) {
42+
func (testWCJob) Reduce(ctx context.Context, key string, values ValueIterator, emitter Emitter) {
4243
count := 0
4344
for range values.Iter() {
4445
count++
4546
}
46-
if err := emitter.Emit(key, fmt.Sprintf("%d", count)); err != nil {
47+
if err := emitter.Emit(ctx, key, fmt.Sprintf("%d", count)); err != nil {
4748
panic(err)
4849
}
4950
}
@@ -52,18 +53,18 @@ type testFilterJob struct {
5253
prefix string
5354
}
5455

55-
func (j *testFilterJob) Map(key, value string, emitter Emitter) {
56+
func (j *testFilterJob) Map(ctx context.Context, key, value string, emitter Emitter) {
5657
if strings.HasPrefix(key, j.prefix) {
57-
if err := emitter.Emit(key, value); err != nil {
58+
if err := emitter.Emit(ctx, key, value); err != nil {
5859
panic(err)
5960
}
6061
}
6162
}
6263

63-
func (j *testFilterJob) Reduce(key string, values ValueIterator, emitter Emitter) {
64+
func (j *testFilterJob) Reduce(ctx context.Context, key string, values ValueIterator, emitter Emitter) {
6465
// Identity reducer
6566
for value := range values.Iter() {
66-
if err := emitter.Emit(key, value); err != nil {
67+
if err := emitter.Emit(ctx, key, value); err != nil {
6768
panic(err)
6869
}
6970
}
@@ -101,7 +102,7 @@ func TestLocalMapReduce(t *testing.T) {
101102
WithWorkingLocation(tmpdir),
102103
)
103104

104-
driver.Main()
105+
driver.Main(context.Background())
105106

106107
output, err := ioutil.ReadFile(filepath.Join(tmpdir, "output-part-0"))
107108
assert.Nil(t, err)
@@ -142,7 +143,7 @@ func TestLocalMultiJob(t *testing.T) {
142143
WithWorkingLocation(tmpdir),
143144
)
144145

145-
driver.Main()
146+
driver.Main(context.Background())
146147

147148
output, err := ioutil.ReadFile(filepath.Join(tmpdir, "job1", "output-part-0"))
148149
assert.Nil(t, err)
@@ -167,32 +168,32 @@ func TestLocalNoCrashOnNoResolvedInputFiles(t *testing.T) {
167168
WithWorkingLocation("some_file"),
168169
)
169170

170-
driver.Main()
171+
driver.Main(context.Background())
171172
}
172173

173174
type statefulJob struct {
174175
filterWords *[]string
175176
}
176177

177-
func (s statefulJob) Map(key, value string, emitter Emitter) {
178+
func (s statefulJob) Map(ctx context.Context, key, value string, emitter Emitter) {
178179
for _, word := range strings.Fields(value) {
179180
for _, filterWord := range *s.filterWords {
180181
if filterWord != word {
181182
continue
182183
}
183-
if err := emitter.Emit(word, "1"); err != nil {
184+
if err := emitter.Emit(ctx, word, "1"); err != nil {
184185
panic(err)
185186
}
186187
}
187188
}
188189
}
189190

190-
func (statefulJob) Reduce(key string, values ValueIterator, emitter Emitter) {
191+
func (statefulJob) Reduce(ctx context.Context, key string, values ValueIterator, emitter Emitter) {
191192
count := 0
192193
for range values.Iter() {
193194
count++
194195
}
195-
if err := emitter.Emit(key, fmt.Sprintf("%d", count)); err != nil {
196+
if err := emitter.Emit(ctx, key, fmt.Sprintf("%d", count)); err != nil {
196197
panic(err)
197198
}
198199
}
@@ -213,7 +214,7 @@ func TestLocalStructFieldMapReduce(t *testing.T) {
213214
WithWorkingLocation(tmpdir),
214215
)
215216

216-
driver.Main()
217+
driver.Main(context.Background())
217218

218219
output, err := ioutil.ReadFile(filepath.Join(tmpdir, "output-part-0"))
219220
assert.Nil(t, err)

emitter.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package corral
22

33
import (
4+
context "context"
45
"encoding/json"
56
"errors"
67
"fmt"
@@ -15,7 +16,7 @@ import (
1516

1617
// Emitter enables mappers and reducers to yield key-value pairs.
1718
type Emitter interface {
18-
Emit(key, value string) error
19+
Emit(ctx context.Context, key, value string) error
1920
close() error
2021
bytesWritten() int64
2122
}
@@ -36,7 +37,7 @@ func newReducerEmitter(writer io.WriteCloser) *reducerEmitter {
3637
}
3738

3839
// Emit yields a key-value pair to the framework.
39-
func (e *reducerEmitter) Emit(key, value string) error {
40+
func (e *reducerEmitter) Emit(ctx context.Context, key, value string) error {
4041
e.mut.Lock()
4142
defer e.mut.Unlock()
4243

@@ -87,7 +88,7 @@ func hashPartition(key string, numBins uint) uint {
8788
}
8889

8990
// Emit yields a key-value pair to the framework.
90-
func (me *mapperEmitter) Emit(key, value string) error {
91+
func (me *mapperEmitter) Emit(ctx context.Context, key, value string) error {
9192
bin := me.partitionFunc(key, me.numBins)
9293

9394
// Open writer for the bin, if necessary

emitter_test.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package corral
22

33
import (
44
"bytes"
5+
"context"
56
"fmt"
67
"io"
78
"io/ioutil"
@@ -31,7 +32,7 @@ func TestReducerEmitter(t *testing.T) {
3132
writer := &testWriteCloser{new(bytes.Buffer)}
3233
emitter := newReducerEmitter(writer)
3334

34-
err := emitter.Emit("key", "value")
35+
err := emitter.Emit(context.Background(), "key", "value")
3536
assert.Nil(t, err)
3637

3738
written, err := ioutil.ReadAll(writer)
@@ -51,7 +52,7 @@ func TestReducerEmitterThreadSafety(t *testing.T) {
5152
wg.Add(1)
5253
go func(key int) {
5354
defer wg.Done()
54-
err := emitter.Emit(fmt.Sprint(key), "value")
55+
err := emitter.Emit(context.Background(), fmt.Sprint(key), "value")
5556
assert.Nil(t, err)
5657
}(i)
5758
}
@@ -108,13 +109,13 @@ func TestMapperEmitter(t *testing.T) {
108109
var fs corfs.FileSystem = mFs
109110
emitter := newMapperEmitter(3, 0, "out", fs)
110111

111-
err := emitter.Emit("key1", "val1")
112+
err := emitter.Emit(context.Background(), "key1", "val1")
112113
assert.Nil(t, err)
113114

114-
err = emitter.Emit("key123", "val2")
115+
err = emitter.Emit(context.Background(), "key123", "val2")
115116
assert.Nil(t, err)
116117

117-
err = emitter.Emit("key359", "val3")
118+
err = emitter.Emit(context.Background(), "key359", "val3")
118119
assert.Nil(t, err)
119120

120121
assert.Len(t, mFs.writers, 3)
@@ -137,13 +138,13 @@ func TestMapperEmitterCustomPartition(t *testing.T) {
137138
return numBuckets - 1
138139
}
139140

140-
err := emitter.Emit("a", "val1")
141+
err := emitter.Emit(context.Background(), "a", "val1")
141142
assert.Nil(t, err)
142143

143-
err = emitter.Emit("a", "val2")
144+
err = emitter.Emit(context.Background(), "a", "val2")
144145
assert.Nil(t, err)
145146

146-
err = emitter.Emit("b", "val3")
147+
err = emitter.Emit(context.Background(), "b", "val3")
147148
assert.Nil(t, err)
148149

149150
assert.Len(t, mFs.writers, 2)

examples/amplab1/amplab1.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package main
22

33
import (
4+
"context"
45
"fmt"
56
"strconv"
67
"strings"
@@ -12,7 +13,7 @@ const pageRankCutoff = 50
1213

1314
type amplab1 struct{}
1415

15-
func (a amplab1) Map(key, value string, emitter corral.Emitter) {
16+
func (a amplab1) Map(ctx context.Context, key, value string, emitter corral.Emitter) {
1617
fields := strings.Split(value, ",")
1718
if len(fields) != 3 {
1819
fmt.Printf("Invalid record: '%s'\n", value)
@@ -22,19 +23,19 @@ func (a amplab1) Map(key, value string, emitter corral.Emitter) {
2223
pageURL := fields[0]
2324
pageRank, err := strconv.Atoi(fields[1])
2425
if err == nil && pageRank > pageRankCutoff {
25-
emitter.Emit(pageURL, fields[1])
26+
emitter.Emit(ctx, pageURL, fields[1])
2627
}
2728
}
2829

29-
func (a amplab1) Reduce(key string, values corral.ValueIterator, emitter corral.Emitter) {
30+
func (a amplab1) Reduce(ctx context.Context, key string, values corral.ValueIterator, emitter corral.Emitter) {
3031
for value := range values.Iter() {
31-
emitter.Emit(key, value)
32+
emitter.Emit(ctx, key, value)
3233
}
3334
}
3435

3536
func main() {
3637
job := corral.NewJob(amplab1{}, amplab1{})
3738

3839
driver := corral.NewDriver(job)
39-
driver.Main()
40+
driver.Main(context.Background())
4041
}

examples/amplab2/amplab2.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package main
22

33
import (
4+
"context"
45
"fmt"
56
"strconv"
67
"strings"
@@ -19,7 +20,7 @@ func min(a, b int) int {
1920
return b
2021
}
2122

22-
func (a amplab2) Map(key, value string, emitter corral.Emitter) {
23+
func (a amplab2) Map(ctx context.Context, key, value string, emitter corral.Emitter) {
2324
fields := strings.Split(value, ",")
2425
if len(fields) != 9 {
2526
fmt.Printf("Invalid record: '%s'\n", value)
@@ -28,23 +29,23 @@ func (a amplab2) Map(key, value string, emitter corral.Emitter) {
2829

2930
sourceIP := fields[0]
3031
adRevenue := fields[3]
31-
emitter.Emit(sourceIP[:min(subStrX, len(sourceIP))], adRevenue)
32+
emitter.Emit(ctx, sourceIP[:min(subStrX, len(sourceIP))], adRevenue)
3233
}
3334

34-
func (a amplab2) Reduce(key string, values corral.ValueIterator, emitter corral.Emitter) {
35+
func (a amplab2) Reduce(ctx context.Context, key string, values corral.ValueIterator, emitter corral.Emitter) {
3536
totalRevenue := 0.0
3637
for value := range values.Iter() {
3738
adRevenue, err := strconv.ParseFloat(value, 64)
3839
if err == nil {
3940
totalRevenue += adRevenue
4041
}
4142
}
42-
emitter.Emit(key, fmt.Sprintf("%f", totalRevenue))
43+
emitter.Emit(ctx, key, fmt.Sprintf("%f", totalRevenue))
4344
}
4445

4546
func main() {
4647
job := corral.NewJob(amplab2{}, amplab2{})
4748

4849
driver := corral.NewDriver(job)
49-
driver.Main()
50+
driver.Main(context.Background())
5051
}

0 commit comments

Comments
 (0)