Skip to content

Commit 62d6447

Browse files
committed
Diff methods that use tracked and node sinking ForEach methods
1 parent a06e1cc commit 62d6447

File tree

2 files changed

+267
-10
lines changed

2 files changed

+267
-10
lines changed

diff.go

Lines changed: 251 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ const (
2424
// Change represents a change to a DAG and contains a reference to the old and
2525
// new CIDs.
2626
type Change struct {
27-
Type ChangeType
28-
Key string
29-
Before *cbg.Deferred
30-
After *cbg.Deferred
27+
Type ChangeType
28+
Key string
29+
Before *cbg.Deferred
30+
After *cbg.Deferred
31+
SelectorSuffix []int
3132
}
3233

3334
func (ch Change) String() string {
@@ -57,6 +58,29 @@ func Diff(ctx context.Context, prevBs, curBs cbor.IpldStore, prev, cur cid.Cid,
5758
return diffNode(ctx, prevHamt, curHamt, 0)
5859
}
5960

61+
// DiffTrackedWithNodeSink returns a set of changes that transform node 'prev' into node 'cur'. opts are applied to both prev and cur.
62+
// it associates selector suffixes with the emitted Change set and sinks all unique nodes encountered under the current CID to the provided CBORUnmarshaler
63+
func DiffTrackedWithNodeSink(ctx context.Context, prevBs, curBs cbor.IpldStore, prev, cur cid.Cid, b *bytes.Buffer, sink cbg.CBORUnmarshaler, trail []int, opts ...Option) ([]*Change, error) {
64+
if prev.Equals(cur) {
65+
return nil, nil
66+
}
67+
68+
prevHamt, err := LoadNode(ctx, prevBs, prev, opts...)
69+
if err != nil {
70+
return nil, err
71+
}
72+
73+
curHamt, err := LoadNode(ctx, curBs, cur, opts...)
74+
if err != nil {
75+
return nil, err
76+
}
77+
78+
if curHamt.bitWidth != prevHamt.bitWidth {
79+
return nil, xerrors.Errorf("diffing HAMTs with differing bitWidths not supported (prev=%d, cur=%d)", prevHamt.bitWidth, curHamt.bitWidth)
80+
}
81+
return diffNodeTrackedWithNodeSink(ctx, prevHamt, curHamt, 0, b, sink, trail)
82+
}
83+
6084
func diffNode(ctx context.Context, pre, cur *Node, depth int) ([]*Change, error) {
6185
// which Bitfield contains the most bits. We will start a loop from this index, calling Bitfield.Bit(idx)
6286
// on an out of range index will return zero.
@@ -176,6 +200,144 @@ func diffNode(ctx context.Context, pre, cur *Node, depth int) ([]*Change, error)
176200
return changes, nil
177201
}
178202

203+
func diffNodeTrackedWithNodeSink(ctx context.Context, pre, cur *Node, depth int, b *bytes.Buffer, sink cbg.CBORUnmarshaler, trail []int) ([]*Change, error) {
204+
// which Bitfield contains the most bits. We will start a loop from this index, calling Bitfield.Bit(idx)
205+
// on an out of range index will return zero.
206+
bp := cur.Bitfield.BitLen()
207+
if pre.Bitfield.BitLen() > bp {
208+
bp = pre.Bitfield.BitLen()
209+
}
210+
211+
if sink != nil {
212+
if b == nil {
213+
b = bytes.NewBuffer(nil)
214+
}
215+
b.Reset()
216+
if err := cur.MarshalCBOR(b); err != nil {
217+
return nil, err
218+
}
219+
if err := sink.UnmarshalCBOR(b); err != nil {
220+
return nil, err
221+
}
222+
}
223+
224+
// the changes between cur and prev
225+
var changes []*Change
226+
l := len(trail)
227+
// loop over each bit in the bitfields
228+
for idx := bp; idx >= 0; idx-- {
229+
preBit := pre.Bitfield.Bit(idx)
230+
curBit := cur.Bitfield.Bit(idx)
231+
232+
subTrail := make([]int, l, l+1)
233+
copy(subTrail, trail)
234+
subTrail = append(subTrail, idx)
235+
236+
if preBit == 1 && curBit == 1 {
237+
// index for pre and cur will be unique to each, calculate it here.
238+
prePointer := pre.getPointer(byte(pre.indexForBitPos(idx)))
239+
curPointer := cur.getPointer(byte(cur.indexForBitPos(idx)))
240+
241+
// both pointers are shards, recurse down the tree.
242+
if prePointer.isShard() && curPointer.isShard() {
243+
if prePointer.Link == curPointer.Link {
244+
continue
245+
}
246+
preChild, err := prePointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash)
247+
if err != nil {
248+
return nil, err
249+
}
250+
curChild, err := curPointer.loadChild(ctx, cur.store, cur.bitWidth, cur.hash)
251+
if err != nil {
252+
return nil, err
253+
}
254+
255+
change, err := diffNodeTrackedWithNodeSink(ctx, preChild, curChild, depth+1, b, sink, subTrail)
256+
if err != nil {
257+
return nil, err
258+
}
259+
changes = append(changes, change...)
260+
}
261+
262+
// check if KV's from cur exists in any children of pre's child.
263+
if prePointer.isShard() && !curPointer.isShard() {
264+
childKV, err := prePointer.loadChildKVs(ctx, pre.store, pre.bitWidth, pre.hash)
265+
if err != nil {
266+
return nil, err
267+
}
268+
changes = append(changes, diffKVsTracked(childKV, curPointer.KVs, idx, subTrail)...)
269+
270+
}
271+
272+
// check if KV's from pre exists in any children of cur's child.
273+
if !prePointer.isShard() && curPointer.isShard() {
274+
childKV, err := curPointer.loadChildKVs(ctx, cur.store, cur.bitWidth, cur.hash)
275+
if err != nil {
276+
return nil, err
277+
}
278+
changes = append(changes, diffKVsTracked(prePointer.KVs, childKV, idx, subTrail)...)
279+
}
280+
281+
// both contain KVs, compare.
282+
if !prePointer.isShard() && !curPointer.isShard() {
283+
changes = append(changes, diffKVsTracked(prePointer.KVs, curPointer.KVs, idx, subTrail)...)
284+
}
285+
} else if preBit == 1 && curBit == 0 {
286+
// there exists a value in previous not found in current - it was removed
287+
pointer := pre.getPointer(byte(pre.indexForBitPos(idx)))
288+
289+
if pointer.isShard() {
290+
child, err := pointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash)
291+
if err != nil {
292+
return nil, err
293+
}
294+
rm, err := removeAllTracked(ctx, child, idx, subTrail)
295+
if err != nil {
296+
return nil, err
297+
}
298+
changes = append(changes, rm...)
299+
} else {
300+
for _, p := range pointer.KVs {
301+
changes = append(changes, &Change{
302+
Type: Remove,
303+
Key: string(p.Key),
304+
Before: p.Value,
305+
After: nil,
306+
SelectorSuffix: subTrail,
307+
})
308+
}
309+
}
310+
} else if curBit == 1 && preBit == 0 {
311+
// there exists a value in current not found in previous - it was added
312+
pointer := cur.getPointer(byte(cur.indexForBitPos(idx)))
313+
314+
if pointer.isShard() {
315+
child, err := pointer.loadChild(ctx, pre.store, pre.bitWidth, pre.hash)
316+
if err != nil {
317+
return nil, err
318+
}
319+
add, err := addAllTrackWithNodeSink(ctx, child, idx, b, sink, subTrail)
320+
if err != nil {
321+
return nil, err
322+
}
323+
changes = append(changes, add...)
324+
} else {
325+
for _, p := range pointer.KVs {
326+
changes = append(changes, &Change{
327+
Type: Add,
328+
Key: string(p.Key),
329+
Before: nil,
330+
After: p.Value,
331+
SelectorSuffix: subTrail,
332+
})
333+
}
334+
}
335+
}
336+
}
337+
338+
return changes, nil
339+
}
340+
179341
func diffKVs(pre, cur []*KV, idx int) []*Change {
180342
preMap := make(map[string]*cbg.Deferred, len(pre))
181343
curMap := make(map[string]*cbg.Deferred, len(cur))
@@ -222,6 +384,55 @@ func diffKVs(pre, cur []*KV, idx int) []*Change {
222384
return changes
223385
}
224386

387+
func diffKVsTracked(pre, cur []*KV, idx int, trail []int) []*Change {
388+
preMap := make(map[string]*cbg.Deferred, len(pre))
389+
curMap := make(map[string]*cbg.Deferred, len(cur))
390+
var changes []*Change
391+
392+
for _, kv := range pre {
393+
preMap[string(kv.Key)] = kv.Value
394+
}
395+
for _, kv := range cur {
396+
curMap[string(kv.Key)] = kv.Value
397+
}
398+
// find removed keys: keys in pre and not in cur
399+
for key, value := range preMap {
400+
if _, ok := curMap[key]; !ok {
401+
changes = append(changes, &Change{
402+
Type: Remove,
403+
Key: key,
404+
Before: value,
405+
After: nil,
406+
SelectorSuffix: trail,
407+
})
408+
}
409+
}
410+
// find added keys: keys in cur and not in pre
411+
// find modified values: keys in cur and pre with different values
412+
for key, curVal := range curMap {
413+
if preVal, ok := preMap[key]; !ok {
414+
changes = append(changes, &Change{
415+
Type: Add,
416+
Key: key,
417+
Before: nil,
418+
After: curVal,
419+
SelectorSuffix: trail,
420+
})
421+
} else {
422+
if !bytes.Equal(preVal.Raw, curVal.Raw) {
423+
changes = append(changes, &Change{
424+
Type: Modify,
425+
Key: key,
426+
Before: preVal,
427+
After: curVal,
428+
SelectorSuffix: trail,
429+
})
430+
}
431+
}
432+
}
433+
return changes
434+
}
435+
225436
func addAll(ctx context.Context, node *Node, idx int) ([]*Change, error) {
226437
var changes []*Change
227438
if err := node.ForEach(ctx, func(k string, val *cbg.Deferred) error {
@@ -239,6 +450,24 @@ func addAll(ctx context.Context, node *Node, idx int) ([]*Change, error) {
239450
return changes, nil
240451
}
241452

453+
func addAllTrackWithNodeSink(ctx context.Context, node *Node, idx int, b *bytes.Buffer, sink cbg.CBORUnmarshaler, trail []int) ([]*Change, error) {
454+
var changes []*Change
455+
if err := node.ForEachTrackedWithNodeSink(ctx, trail, b, sink, func(k string, val *cbg.Deferred, selectorSuffix []int) error {
456+
changes = append(changes, &Change{
457+
Type: Add,
458+
Key: k,
459+
Before: nil,
460+
After: val,
461+
SelectorSuffix: selectorSuffix,
462+
})
463+
464+
return nil
465+
}); err != nil {
466+
return nil, err
467+
}
468+
return changes, nil
469+
}
470+
242471
func removeAll(ctx context.Context, node *Node, idx int) ([]*Change, error) {
243472
var changes []*Change
244473
if err := node.ForEach(ctx, func(k string, val *cbg.Deferred) error {
@@ -255,3 +484,21 @@ func removeAll(ctx context.Context, node *Node, idx int) ([]*Change, error) {
255484
}
256485
return changes, nil
257486
}
487+
488+
func removeAllTracked(ctx context.Context, node *Node, idx int, trail []int) ([]*Change, error) {
489+
var changes []*Change
490+
if err := node.ForEachTracked(ctx, trail, func(k string, val *cbg.Deferred, selectorSuffix []int) error {
491+
changes = append(changes, &Change{
492+
Type: Remove,
493+
Key: k,
494+
Before: val,
495+
After: nil,
496+
SelectorSuffix: selectorSuffix,
497+
})
498+
499+
return nil
500+
}); err != nil {
501+
return nil, err
502+
}
503+
return changes, nil
504+
}

hamt.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -885,24 +885,29 @@ func (n *Node) ForEach(ctx context.Context, f func(k string, val *cbg.Deferred)
885885
// This method also provides the trail of indices to the current node, which can be used to formulate a selector suffix
886886
func (n *Node) ForEachTracked(ctx context.Context, trail []int, f func(k string, val *cbg.Deferred, selectorSuffix []int) error) error {
887887
idx := 0
888+
l := len(trail)
888889
for _, p := range n.Pointers {
889890
// Seek the next set bit in the bitfield to find the actual index for this pointer
890891
for n.Bitfield.Bit(idx) == 0 {
891892
idx++
892893
}
893-
trail = append(trail, idx)
894+
895+
subTrail := make([]int, l, l+1)
896+
copy(subTrail, trail)
897+
subTrail = append(subTrail, idx)
898+
894899
if p.isShard() {
895900
chnd, err := p.loadChild(ctx, n.store, n.bitWidth, n.hash)
896901
if err != nil {
897902
return err
898903
}
899904

900-
if err := chnd.ForEachTracked(ctx, trail, f); err != nil {
905+
if err := chnd.ForEachTracked(ctx, subTrail, f); err != nil {
901906
return err
902907
}
903908
} else {
904909
for _, kv := range p.KVs {
905-
if err := f(string(kv.Key), kv.Value, trail); err != nil {
910+
if err := f(string(kv.Key), kv.Value, subTrail); err != nil {
906911
return err
907912
}
908913
}
@@ -932,24 +937,29 @@ func (n *Node) ForEachTrackedWithNodeSink(ctx context.Context, trail []int, b *b
932937
}
933938
}
934939
idx := 0
940+
l := len(trail)
935941
for _, p := range n.Pointers {
936942
// Seek the next set bit in the bitfield to find the actual index for this pointer
937943
for n.Bitfield.Bit(idx) == 0 {
938944
idx++
939945
}
940-
trail = append(trail, idx)
946+
947+
subTrail := make([]int, l, l+1)
948+
copy(subTrail, trail)
949+
subTrail = append(subTrail, idx)
950+
941951
if p.isShard() {
942952
chnd, err := p.loadChild(ctx, n.store, n.bitWidth, n.hash)
943953
if err != nil {
944954
return err
945955
}
946956

947-
if err := chnd.ForEachTrackedWithNodeSink(ctx, trail, b, sink, f); err != nil {
957+
if err := chnd.ForEachTrackedWithNodeSink(ctx, subTrail, b, sink, f); err != nil {
948958
return err
949959
}
950960
} else {
951961
for _, kv := range p.KVs {
952-
if err := f(string(kv.Key), kv.Value, trail); err != nil {
962+
if err := f(string(kv.Key), kv.Value, subTrail); err != nil {
953963
return err
954964
}
955965
}

0 commit comments

Comments
 (0)