Skip to content

Commit d7766f8

Browse files
authored
feat: implement parallel diffing (filecoin-project#100)
* feat: implement parallel diffing with worker limiting
1 parent 02820af commit d7766f8

File tree

5 files changed

+403
-22
lines changed

5 files changed

+403
-22
lines changed

diff_parallel.go

Lines changed: 356 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,356 @@
1+
package hamt
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"sync"
7+
8+
"github.com/ipfs/go-cid"
9+
cbor "github.com/ipfs/go-ipld-cbor"
10+
cbg "github.com/whyrusleeping/cbor-gen"
11+
"golang.org/x/sync/errgroup"
12+
"golang.org/x/xerrors"
13+
)
14+
15+
// ParallelDiff returns a set of changes that transform node 'prev' into node 'cur'. opts are applied to both prev and cur.
16+
func ParallelDiff(ctx context.Context, prevBs, curBs cbor.IpldStore, prev, cur cid.Cid, workers int64, opts ...Option) ([]*Change, error) {
17+
if prev.Equals(cur) {
18+
return nil, nil
19+
}
20+
21+
prevHamt, err := LoadNode(ctx, prevBs, prev, opts...)
22+
if err != nil {
23+
return nil, err
24+
}
25+
26+
curHamt, err := LoadNode(ctx, curBs, cur, opts...)
27+
if err != nil {
28+
return nil, err
29+
}
30+
31+
if curHamt.bitWidth != prevHamt.bitWidth {
32+
return nil, xerrors.Errorf("diffing HAMTs with differing bitWidths not supported (prev=%d, cur=%d)", prevHamt.bitWidth, curHamt.bitWidth)
33+
}
34+
35+
return doParallelDiffNode(ctx, prevHamt, curHamt, workers)
36+
}
37+
38+
func doParallelDiffNode(ctx context.Context, pre, cur *Node, workers int64) ([]*Change, error) {
39+
bp := cur.Bitfield.BitLen()
40+
if pre.Bitfield.BitLen() > bp {
41+
bp = pre.Bitfield.BitLen()
42+
}
43+
44+
initTasks := []*task{}
45+
for idx := bp; idx >= 0; idx-- {
46+
preBit := pre.Bitfield.Bit(idx)
47+
curBit := cur.Bitfield.Bit(idx)
48+
initTasks = append(initTasks, &task{
49+
idx: idx,
50+
pre: pre,
51+
preBit: preBit,
52+
cur: cur,
53+
curBit: curBit,
54+
})
55+
}
56+
57+
out := make(chan *Change, 2*workers)
58+
differ, ctx := newDiffScheduler(ctx, workers, initTasks...)
59+
differ.startWorkers(ctx, out)
60+
differ.startScheduler(ctx)
61+
62+
var changes []*Change
63+
done := make(chan struct{})
64+
go func() {
65+
defer close(done)
66+
for change := range out {
67+
changes = append(changes, change)
68+
}
69+
}()
70+
71+
err := differ.grp.Wait()
72+
close(out)
73+
<-done
74+
75+
return changes, err
76+
}
77+
78+
type task struct {
79+
idx int
80+
81+
pre *Node
82+
preBit uint
83+
84+
cur *Node
85+
curBit uint
86+
}
87+
88+
func newDiffScheduler(ctx context.Context, numWorkers int64, rootTasks ...*task) (*diffScheduler, context.Context) {
89+
grp, ctx := errgroup.WithContext(ctx)
90+
s := &diffScheduler{
91+
numWorkers: numWorkers,
92+
stack: rootTasks,
93+
in: make(chan *task, numWorkers),
94+
out: make(chan *task, numWorkers),
95+
grp: grp,
96+
}
97+
s.taskWg.Add(len(rootTasks))
98+
return s, ctx
99+
}
100+
101+
type diffScheduler struct {
102+
// number of worker routine to spawn
103+
numWorkers int64
104+
// buffer holds tasks until they are processed
105+
stack []*task
106+
// inbound and outbound tasks
107+
in, out chan *task
108+
// tracks number of inflight tasks
109+
taskWg sync.WaitGroup
110+
// launches workers and collects errors if any occur
111+
grp *errgroup.Group
112+
}
113+
114+
func (s *diffScheduler) enqueueTask(task *task) {
115+
s.taskWg.Add(1)
116+
s.in <- task
117+
}
118+
119+
func (s *diffScheduler) startScheduler(ctx context.Context) {
120+
s.grp.Go(func() error {
121+
defer func() {
122+
close(s.out)
123+
// Because the workers may have exited early (due to the context being canceled).
124+
for range s.out {
125+
s.taskWg.Done()
126+
}
127+
// Because the workers may have enqueued additional tasks.
128+
for range s.in {
129+
s.taskWg.Done()
130+
}
131+
// now, the waitgroup should be at 0, and the goroutine that was _waiting_ on it should have exited.
132+
}()
133+
go func() {
134+
s.taskWg.Wait()
135+
close(s.in)
136+
}()
137+
for {
138+
if n := len(s.stack) - 1; n >= 0 {
139+
select {
140+
case <-ctx.Done():
141+
return ctx.Err()
142+
case newJob, ok := <-s.in:
143+
if !ok {
144+
return nil
145+
}
146+
s.stack = append(s.stack, newJob)
147+
case s.out <- s.stack[n]:
148+
s.stack[n] = nil
149+
s.stack = s.stack[:n]
150+
}
151+
} else {
152+
select {
153+
case <-ctx.Done():
154+
return ctx.Err()
155+
case newJob, ok := <-s.in:
156+
if !ok {
157+
return nil
158+
}
159+
s.stack = append(s.stack, newJob)
160+
}
161+
}
162+
}
163+
})
164+
}
165+
166+
func (s *diffScheduler) startWorkers(ctx context.Context, out chan *Change) {
167+
for i := int64(0); i < s.numWorkers; i++ {
168+
s.grp.Go(func() error {
169+
for task := range s.out {
170+
if err := s.work(ctx, task, out); err != nil {
171+
return err
172+
}
173+
}
174+
return nil
175+
})
176+
}
177+
}
178+
179+
func (s *diffScheduler) work(ctx context.Context, todo *task, results chan *Change) error {
180+
defer s.taskWg.Done()
181+
idx := todo.idx
182+
preBit := todo.preBit
183+
pre := todo.pre
184+
curBit := todo.curBit
185+
cur := todo.cur
186+
187+
switch {
188+
case preBit == 1 && curBit == 1:
189+
// index for pre and cur will be unique to each, calculate it here.
190+
prePointer := pre.getPointer(byte(pre.indexForBitPos(idx)))
191+
curPointer := cur.getPointer(byte(cur.indexForBitPos(idx)))
192+
switch {
193+
// both pointers are shards, recurse down the tree.
194+
case prePointer.isShard() && curPointer.isShard():
195+
preChild, err := prePointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash)
196+
if err != nil {
197+
return err
198+
}
199+
curChild, err := curPointer.loadChild(ctx, cur.store, cur.bitWidth, cur.hash)
200+
if err != nil {
201+
return err
202+
}
203+
204+
bp := curChild.Bitfield.BitLen()
205+
if preChild.Bitfield.BitLen() > bp {
206+
bp = preChild.Bitfield.BitLen()
207+
}
208+
for idx := bp; idx >= 0; idx-- {
209+
preBit := preChild.Bitfield.Bit(idx)
210+
curBit := curChild.Bitfield.Bit(idx)
211+
s.enqueueTask(&task{
212+
idx: idx,
213+
pre: preChild,
214+
preBit: preBit,
215+
cur: curChild,
216+
curBit: curBit,
217+
})
218+
}
219+
220+
// check if KV's from cur exists in any children of pre's child.
221+
case prePointer.isShard() && !curPointer.isShard():
222+
childKV, err := prePointer.loadChildKVs(ctx, pre.store, pre.bitWidth, pre.hash)
223+
if err != nil {
224+
return err
225+
}
226+
parallelDiffKVs(childKV, curPointer.KVs, results)
227+
228+
// check if KV's from pre exists in any children of cur's child.
229+
case !prePointer.isShard() && curPointer.isShard():
230+
childKV, err := curPointer.loadChildKVs(ctx, cur.store, cur.bitWidth, cur.hash)
231+
if err != nil {
232+
return err
233+
}
234+
parallelDiffKVs(prePointer.KVs, childKV, results)
235+
236+
// both contain KVs, compare.
237+
case !prePointer.isShard() && !curPointer.isShard():
238+
parallelDiffKVs(prePointer.KVs, curPointer.KVs, results)
239+
}
240+
case preBit == 1 && curBit == 0:
241+
// there exists a value in previous not found in current - it was removed
242+
pointer := pre.getPointer(byte(pre.indexForBitPos(idx)))
243+
244+
if pointer.isShard() {
245+
child, err := pointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash)
246+
if err != nil {
247+
return err
248+
}
249+
err = parallelRemoveAll(ctx, child, results)
250+
if err != nil {
251+
return err
252+
}
253+
} else {
254+
for _, p := range pointer.KVs {
255+
results <- &Change{
256+
Type: Remove,
257+
Key: string(p.Key),
258+
Before: p.Value,
259+
After: nil,
260+
}
261+
}
262+
}
263+
case preBit == 0 && curBit == 1:
264+
// there exists a value in current not found in previous - it was added
265+
pointer := cur.getPointer(byte(cur.indexForBitPos(idx)))
266+
267+
if pointer.isShard() {
268+
child, err := pointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash)
269+
if err != nil {
270+
return err
271+
}
272+
err = parallelAddAll(ctx, child, results)
273+
if err != nil {
274+
return err
275+
}
276+
} else {
277+
for _, p := range pointer.KVs {
278+
results <- &Change{
279+
Type: Add,
280+
Key: string(p.Key),
281+
Before: nil,
282+
After: p.Value,
283+
}
284+
}
285+
}
286+
}
287+
return nil
288+
}
289+
290+
func parallelDiffKVs(pre, cur []*KV, out chan *Change) {
291+
preMap := make(map[string]*cbg.Deferred, len(pre))
292+
curMap := make(map[string]*cbg.Deferred, len(cur))
293+
294+
for _, kv := range pre {
295+
preMap[string(kv.Key)] = kv.Value
296+
}
297+
for _, kv := range cur {
298+
curMap[string(kv.Key)] = kv.Value
299+
}
300+
// find removed keys: keys in pre and not in cur
301+
for key, value := range preMap {
302+
if _, ok := curMap[key]; !ok {
303+
out <- &Change{
304+
Type: Remove,
305+
Key: key,
306+
Before: value,
307+
After: nil,
308+
}
309+
}
310+
}
311+
// find added keys: keys in cur and not in pre
312+
// find modified values: keys in cur and pre with different values
313+
for key, curVal := range curMap {
314+
if preVal, ok := preMap[key]; !ok {
315+
out <- &Change{
316+
Type: Add,
317+
Key: key,
318+
Before: nil,
319+
After: curVal,
320+
}
321+
} else {
322+
if !bytes.Equal(preVal.Raw, curVal.Raw) {
323+
out <- &Change{
324+
Type: Modify,
325+
Key: key,
326+
Before: preVal,
327+
After: curVal,
328+
}
329+
}
330+
}
331+
}
332+
}
333+
334+
func parallelAddAll(ctx context.Context, node *Node, out chan *Change) error {
335+
return node.ForEach(ctx, func(k string, val *cbg.Deferred) error {
336+
out <- &Change{
337+
Type: Add,
338+
Key: k,
339+
Before: nil,
340+
After: val,
341+
}
342+
return nil
343+
})
344+
}
345+
346+
func parallelRemoveAll(ctx context.Context, node *Node, out chan *Change) error {
347+
return node.ForEach(ctx, func(k string, val *cbg.Deferred) error {
348+
out <- &Change{
349+
Type: Remove,
350+
Key: k,
351+
Before: val,
352+
After: nil,
353+
}
354+
return nil
355+
})
356+
}

0 commit comments

Comments
 (0)