Skip to content

Commit 21c6b58

Browse files
gilbsgilbsrdleal
authored andcommitted
InOrderTraverse fixes
- Fix build (missing import and trivial wrong symbol) - Pass end of interval along with start of interval to VisitFunc - Make VisitFunc pass the interval before the value to match Update prototype - Add tests
1 parent 86f1717 commit 21c6b58

File tree

2 files changed

+254
-13
lines changed

2 files changed

+254
-13
lines changed

interval/search.go

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
package interval
22

3+
import (
4+
"errors"
5+
)
6+
37
// Find returns the value which interval key exactly matches with the given start and end interval.
48
// It returns true as the second return value if an exaclty matching interval key is found in the tree;
59
// otherwise, false.
@@ -484,13 +488,13 @@ func maxEnd[V, T any](n *node[V, T], searchEnd T, cmp CmpFunc[T], visit func(*no
484488
var StopTraversal = errors.New("stop tree traversal")
485489

486490
// VisitFunc is called on all values. Returning non-nil error will stop iteration.
487-
// If the returned error is [StopTraversal], the iteration is interrupted, but no error is returned to the caller.
488-
type VisitFunc[V, T any] func(V, T) error
491+
// If the returned error is [StopTraversal], the iteration is interrupted, but no error is returned to the caller.
492+
type VisitFunc[V, T any] func(T, T, V) error
489493

490494
// InOrderTraverse traverses the tree in order and applies VisitFunc to each node. It's safe for concurrent use. To prevent deadlock, avoid calling other tree methods within visitFunc.
491495
func (st *SearchTree[V, T]) InOrderTraverse(visitFunc VisitFunc[V, T]) error {
492-
tree.mu.RLock()
493-
defer tree.mu.RUnlock()
496+
st.mu.RLock()
497+
defer st.mu.RUnlock()
494498

495499
var inOrder func(n *node[V, T]) error
496500
inOrder = func(n *node[V, T]) error {
@@ -504,7 +508,7 @@ func (st *SearchTree[V, T]) InOrderTraverse(visitFunc VisitFunc[V, T]) error {
504508
}
505509

506510
// Visit current node
507-
err := visitFunc(n.Interval.Val, n.Interval.Start)
511+
err := visitFunc(n.Interval.Start, n.Interval.End, n.Interval.Val)
508512
if err != nil {
509513
return err
510514
}
@@ -513,18 +517,18 @@ func (st *SearchTree[V, T]) InOrderTraverse(visitFunc VisitFunc[V, T]) error {
513517
return inOrder(n.Right)
514518
}
515519

516-
err := inOrder(tree.root)
520+
err := inOrder(st.root)
517521
// Do not percolate StopTraversal error to the caller.
518522
if errors.Is(err, StopTraversal) {
519-
return nil
523+
return nil
520524
}
521525
return err
522526
}
523527

524528
// InOrderTraverse traverses the tree in order and applies VisitFunc to each node. It's safe for concurrent use. To prevent deadlock, avoid calling other tree methods within visitFunc.
525-
func (st *MultiValueSearchTree[V, T]) InOrderTraverse(visitFunc VisitFunc[V, T]) error {
526-
tree.mu.RLock()
527-
defer tree.mu.RUnlock()
529+
func (st *MultiValueSearchTree[V, T]) InOrderTraverse(visitFunc VisitFunc[[]V, T]) error {
530+
st.mu.RLock()
531+
defer st.mu.RUnlock()
528532

529533
var inOrder func(n *node[V, T]) error
530534
inOrder = func(n *node[V, T]) error {
@@ -538,7 +542,7 @@ func (st *MultiValueSearchTree[V, T]) InOrderTraverse(visitFunc VisitFunc[V, T])
538542
}
539543

540544
// Visit current node
541-
err := visitFunc(n.Interval.Vals, n.Interval.Start)
545+
err := visitFunc(n.Interval.Start, n.Interval.End, n.Interval.Vals)
542546
if err != nil {
543547
return err
544548
}
@@ -547,10 +551,10 @@ func (st *MultiValueSearchTree[V, T]) InOrderTraverse(visitFunc VisitFunc[V, T])
547551
return inOrder(n.Right)
548552
}
549553

550-
err := inOrder(tree.root)
554+
err := inOrder(st.root)
551555
// Do not percolate StopTraversal error to the caller.
552556
if errors.Is(err, StopTraversal) {
553-
return nil
557+
return nil
554558
}
555559
return err
556560
}

interval/search_test.go

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,124 @@ func TestSearchTree_Select(t *testing.T) {
587587
}
588588
}
589589

590+
func TestSearchTree_InOrderTraverse(t *testing.T) {
591+
type insert struct {
592+
start int
593+
end int
594+
val string
595+
}
596+
tests := []struct {
597+
name string
598+
inserts []insert
599+
expectedVisits []string
600+
}{
601+
{
602+
name: "empty interval",
603+
inserts: []insert{},
604+
expectedVisits: []string{},
605+
},
606+
{
607+
name: "single interval",
608+
inserts: []insert{
609+
{start: 1, end: 10, val: "node1"},
610+
},
611+
expectedVisits: []string{"node1"},
612+
},
613+
{
614+
name: "multiple intervals",
615+
inserts: []insert{
616+
{start: 1, end: 10, val: "node1"},
617+
{start: 5, end: 15, val: "node2"},
618+
{start: 10, end: 20, val: "node3"},
619+
{start: 15, end: 25, val: "node4"},
620+
{start: 20, end: 30, val: "node5"},
621+
},
622+
expectedVisits: []string{"node1", "node2", "node3", "node4", "node5"},
623+
},
624+
{
625+
name: "multiple intervals with same end",
626+
inserts: []insert{
627+
{start: 1, end: 10, val: "node1"},
628+
{start: 5, end: 15, val: "node2"},
629+
{start: 10, end: 20, val: "node3"},
630+
{start: 15, end: 25, val: "node4"},
631+
{start: 20, end: 30, val: "node5"},
632+
{start: 25, end: 30, val: "node6"},
633+
},
634+
expectedVisits: []string{"node1", "node2", "node3", "node4", "node5", "node6"},
635+
},
636+
{
637+
name: "multiple intervals with same end and same start",
638+
inserts: []insert{
639+
{start: 20, end: 30, val: "node5"},
640+
{start: 25, end: 30, val: "node6"},
641+
{start: 15, end: 30, val: "node7"},
642+
},
643+
expectedVisits: []string{"node7", "node5", "node6"},
644+
},
645+
{
646+
name: "interval spanning entire range",
647+
inserts: []insert{
648+
{start: 1, end: 5, val: "node1"},
649+
{start: 5, end: 10, val: "node2"},
650+
{start: 10, end: 20, val: "node3"},
651+
{start: 0, end: 30, val: "node4"},
652+
},
653+
expectedVisits: []string{"node4", "node1", "node2", "node3"},
654+
},
655+
}
656+
657+
for _, tc := range tests {
658+
t.Run(tc.name, func(t *testing.T) {
659+
st := NewSearchTree[string](func(x, y int) int { return x - y })
660+
661+
for _, insert := range tc.inserts {
662+
st.Insert(insert.start, insert.end, insert.val)
663+
}
664+
665+
got := []string{}
666+
667+
err := st.InOrderTraverse(func(start, end int, node string) error {
668+
got = append(got, node)
669+
return nil
670+
})
671+
if err != nil {
672+
t.Fatalf("st.InOrderTraverse(): error %v", err)
673+
}
674+
675+
if !reflect.DeepEqual(got, tc.expectedVisits) {
676+
t.Errorf("st.InOderTraverse(): got unexpected value %v; want %v", got, tc.expectedVisits)
677+
}
678+
})
679+
}
680+
681+
t.Run("stop traversal", func(t *testing.T) {
682+
st := NewSearchTree[string](func(x, y int) int { return x - y })
683+
st.Insert(17, 19, "node1")
684+
st.Insert(5, 8, "node2")
685+
st.Insert(21, 24, "node3")
686+
st.Insert(4, 8, "node4")
687+
688+
want := []string{"node4", "node2", "node1"}
689+
got := []string{}
690+
err := st.InOrderTraverse(func(start, end int, node string) error {
691+
got = append(got, node)
692+
if node == "node1" {
693+
return StopTraversal
694+
}
695+
return nil
696+
})
697+
698+
if err != nil {
699+
t.Fatalf("st.InOrderTraverse(): error %v", err)
700+
}
701+
702+
if !reflect.DeepEqual(got, want) {
703+
t.Errorf("st.InOderTraverse(): got unexpected value %v; want %v", got, want)
704+
}
705+
})
706+
}
707+
590708
func TestMultiValueSearchTree_AnyIntersection(t *testing.T) {
591709
st := NewMultiValueSearchTree[string](func(x, y int) int { return x - y })
592710
defer mustBeValidTree(t, st.root)
@@ -1115,3 +1233,122 @@ func TestMultiValueSearchTree_MaxEnd(t *testing.T) {
11151233
}
11161234

11171235
}
1236+
1237+
func TestMultiValueSearchTree_InOrderTraverse(t *testing.T) {
1238+
type insert struct {
1239+
start int
1240+
end int
1241+
val string
1242+
}
1243+
tests := []struct {
1244+
name string
1245+
inserts []insert
1246+
expectedVisits [][]string
1247+
}{
1248+
{
1249+
name: "empty interval",
1250+
inserts: []insert{},
1251+
expectedVisits: [][]string{},
1252+
},
1253+
{
1254+
name: "single interval",
1255+
inserts: []insert{
1256+
{start: 1, end: 10, val: "node1"},
1257+
{start: 1, end: 10, val: "node2"},
1258+
},
1259+
expectedVisits: [][]string{{"node1", "node2"}},
1260+
},
1261+
{
1262+
name: "multiple intervals",
1263+
inserts: []insert{
1264+
{start: 1, end: 10, val: "node1"},
1265+
{start: 5, end: 15, val: "node2"},
1266+
{start: 10, end: 20, val: "node3"},
1267+
{start: 15, end: 25, val: "node4"},
1268+
{start: 20, end: 30, val: "node5"},
1269+
},
1270+
expectedVisits: [][]string{{"node1"}, {"node2"}, {"node3"}, {"node4"}, {"node5"}},
1271+
},
1272+
{
1273+
name: "multiple intervals with same end",
1274+
inserts: []insert{
1275+
{start: 1, end: 10, val: "node1"},
1276+
{start: 5, end: 15, val: "node2"},
1277+
{start: 10, end: 20, val: "node3"},
1278+
{start: 15, end: 25, val: "node4"},
1279+
{start: 20, end: 30, val: "node5"},
1280+
{start: 25, end: 30, val: "node6"},
1281+
},
1282+
expectedVisits: [][]string{{"node1"}, {"node2"}, {"node3"}, {"node4"}, {"node5"}, {"node6"}},
1283+
},
1284+
{
1285+
name: "multiple intervals with same end and same start",
1286+
inserts: []insert{
1287+
{start: 20, end: 30, val: "node5"},
1288+
{start: 25, end: 30, val: "node6"},
1289+
{start: 15, end: 30, val: "node7"},
1290+
},
1291+
expectedVisits: [][]string{{"node7"}, {"node5"}, {"node6"}},
1292+
},
1293+
{
1294+
name: "interval spanning entire range",
1295+
inserts: []insert{
1296+
{start: 1, end: 5, val: "node1"},
1297+
{start: 5, end: 10, val: "node2"},
1298+
{start: 10, end: 20, val: "node3"},
1299+
{start: 0, end: 30, val: "node4"},
1300+
},
1301+
expectedVisits: [][]string{{"node4"}, {"node1"}, {"node2"}, {"node3"}},
1302+
},
1303+
}
1304+
1305+
for _, tc := range tests {
1306+
t.Run(tc.name, func(t *testing.T) {
1307+
st := NewMultiValueSearchTree[string](func(x, y int) int { return x - y })
1308+
1309+
for _, insert := range tc.inserts {
1310+
st.Insert(insert.start, insert.end, insert.val)
1311+
}
1312+
1313+
got := [][]string{}
1314+
1315+
err := st.InOrderTraverse(func(start, end int, node []string) error {
1316+
got = append(got, node)
1317+
return nil
1318+
})
1319+
if err != nil {
1320+
t.Fatalf("st.InOrderTraverse(): error %v", err)
1321+
}
1322+
1323+
if !reflect.DeepEqual(got, tc.expectedVisits) {
1324+
t.Errorf("st.InOderTraverse(): got unexpected value %v; want %v", got, tc.expectedVisits)
1325+
}
1326+
})
1327+
}
1328+
1329+
t.Run("stop traversal", func(t *testing.T) {
1330+
st := NewMultiValueSearchTree[string](func(x, y int) int { return x - y })
1331+
st.Insert(17, 19, "node1")
1332+
st.Insert(5, 8, "node2")
1333+
st.Insert(21, 24, "node3")
1334+
st.Insert(4, 8, "node4")
1335+
1336+
want := [][]string{{"node4"}, {"node2"}, {"node1"}}
1337+
got := [][]string{}
1338+
err := st.InOrderTraverse(func(start, end int, node []string) error {
1339+
got = append(got, node)
1340+
if node[0] == "node1" {
1341+
return StopTraversal
1342+
}
1343+
return nil
1344+
})
1345+
1346+
if err != nil {
1347+
t.Fatalf("st.InOrderTraverse(): error %v", err)
1348+
}
1349+
1350+
if !reflect.DeepEqual(got, want) {
1351+
t.Errorf("st.InOderTraverse(): got unexpected value %v; want %v", got, want)
1352+
}
1353+
})
1354+
}

0 commit comments

Comments
 (0)