Skip to content

Commit 63093db

Browse files
Factored tree drawing code out of TS SVG.
Added some basic test coverage also.
1 parent 4ac7a82 commit 63093db

File tree

3 files changed

+281
-70
lines changed

3 files changed

+281
-70
lines changed

python/tests/test_drawing.py

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,68 @@ def test_simple_tree_sequence(self):
615615
ts = tskit.load_text(nodes, edges, strict=False)
616616
self.verify_text_rendering(ts.draw_text(), ts_drawing)
617617

618+
def test_tree_height_scale(self):
619+
tree = msprime.simulate(4, random_seed=2).first()
620+
with self.assertRaises(ValueError):
621+
tree.draw_text(tree_height_scale="time")
622+
623+
t1 = tree.draw_text(tree_height_scale="rank")
624+
t2 = tree.draw_text()
625+
self.assertEqual(t1, t2)
626+
627+
for bad_scale in [0, "", "NOT A SCALE"]:
628+
with self.assertRaises(ValueError):
629+
tree.draw_text(tree_height_scale=bad_scale)
630+
631+
def test_max_tree_height(self):
632+
nodes = io.StringIO("""\
633+
id is_sample population individual time metadata
634+
0 1 0 -1 0.00000000000000
635+
1 1 0 -1 0.00000000000000
636+
2 1 0 -1 0.00000000000000
637+
3 1 0 -1 0.00000000000000
638+
4 0 0 -1 0.02445014598813
639+
5 0 0 -1 0.11067965364865
640+
6 0 0 -1 1.75005250750382
641+
7 0 0 -1 2.31067154311640
642+
8 0 0 -1 3.57331354884652
643+
9 0 0 -1 9.08308317451295
644+
""")
645+
edges = io.StringIO("""\
646+
id left right parent child
647+
0 0.00000000 1.00000000 4 0
648+
1 0.00000000 1.00000000 4 1
649+
2 0.00000000 1.00000000 5 2
650+
3 0.00000000 1.00000000 5 3
651+
4 0.79258618 0.90634460 6 4
652+
5 0.79258618 0.90634460 6 5
653+
6 0.05975243 0.79258618 7 4
654+
7 0.90634460 0.91029435 7 4
655+
8 0.05975243 0.79258618 7 5
656+
9 0.90634460 0.91029435 7 5
657+
10 0.91029435 1.00000000 8 4
658+
11 0.91029435 1.00000000 8 5
659+
12 0.00000000 0.05975243 9 4
660+
13 0.00000000 0.05975243 9 5
661+
""")
662+
ts = tskit.load_text(nodes, edges, strict=False)
663+
tree = (
664+
" 9 \n"
665+
" ┏━┻━┓ \n"
666+
" ┃ ┃ \n"
667+
" ┃ ┃ \n"
668+
" ┃ ┃ \n"
669+
" ┃ ┃ \n"
670+
" ┃ ┃ \n"
671+
" ┃ ┃ \n"
672+
" ┃ 5 \n"
673+
" ┃ ┏┻┓\n"
674+
" 4 ┃ ┃\n"
675+
"┏┻┓ ┃ ┃\n"
676+
"0 1 2 3\n")
677+
t = ts.first()
678+
self.verify_text_rendering(t.draw_text(max_tree_height="ts"), tree)
679+
618680

619681
class TestDrawSvg(TestTreeDraw):
620682
"""
@@ -636,67 +698,105 @@ def test_draw_file(self):
636698
with open(filename) as tmp:
637699
other_svg = tmp.read()
638700
self.assertEqual(svg, other_svg)
701+
os.unlink(filename)
702+
703+
svg = t.draw_svg(path=filename)
704+
self.assertGreater(os.path.getsize(filename), 0)
705+
with open(filename) as tmp:
706+
other_svg = tmp.read()
707+
self.verify_basic_svg(svg)
708+
self.verify_basic_svg(other_svg)
709+
710+
ts = t.tree_sequence
711+
svg = ts.draw_svg(path=filename)
712+
self.assertGreater(os.path.getsize(filename), 0)
713+
with open(filename) as tmp:
714+
other_svg = tmp.read()
715+
self.verify_basic_svg(svg)
716+
self.verify_basic_svg(other_svg)
639717
finally:
640718
os.unlink(filename)
641719

642720
def test_draw_defaults(self):
643721
t = self.get_binary_tree()
644722
svg = t.draw()
645723
self.verify_basic_svg(svg)
724+
svg = t.draw_svg()
725+
self.verify_basic_svg(svg)
646726

647727
def test_draw_nonbinary(self):
648728
t = self.get_nonbinary_tree()
649729
svg = t.draw()
650730
self.verify_basic_svg(svg)
731+
svg = t.draw_svg()
732+
self.verify_basic_svg(svg)
651733

652734
def test_draw_multiroot(self):
653735
t = self.get_multiroot_tree()
654736
svg = t.draw()
655737
self.verify_basic_svg(svg)
738+
svg = t.draw_svg()
739+
self.verify_basic_svg(svg)
656740

657741
def test_draw_mutations_over_roots(self):
658742
t = self.get_mutations_over_roots_tree()
659743
svg = t.draw()
660744
self.verify_basic_svg(svg)
745+
svg = t.draw_svg()
746+
self.verify_basic_svg(svg)
661747

662748
def test_draw_unary(self):
663749
t = self.get_unary_node_tree()
664750
svg = t.draw()
665751
self.verify_basic_svg(svg)
752+
svg = t.draw_svg()
753+
self.verify_basic_svg(svg)
666754

667755
def test_draw_empty(self):
668756
t = self.get_empty_tree()
669757
self.assertRaises(ValueError, t.draw)
758+
self.assertRaises(ValueError, t.draw_svg)
670759

671760
def test_draw_zero_roots(self):
672761
t = self.get_zero_roots_tree()
673762
self.assertRaises(ValueError, t.draw)
763+
self.assertRaises(ValueError, t.draw_svg)
674764

675765
def test_draw_zero_edge(self):
676766
t = self.get_zero_edge_tree()
677767
svg = t.draw()
678768
self.verify_basic_svg(svg)
769+
svg = t.draw_svg()
770+
self.verify_basic_svg(svg)
679771

680772
def test_width_height(self):
681773
t = self.get_binary_tree()
682774
w = 123
683775
h = 456
684776
svg = t.draw(width=w, height=h)
685777
self.verify_basic_svg(svg, w, h)
778+
svg = t.draw_svg(size=(w, h))
779+
self.verify_basic_svg(svg, w, h)
686780

687781
def test_node_labels(self):
688782
t = self.get_binary_tree()
689783
labels = {u: "XXX" for u in t.nodes()}
690784
svg = t.draw(format="svg", node_labels=labels)
691785
self.verify_basic_svg(svg)
692786
self.assertEqual(svg.count("XXX"), t.num_nodes)
787+
svg = t.draw_svg(node_label_attrs={u: {"text": labels[u]} for u in t.nodes()})
788+
self.verify_basic_svg(svg)
789+
self.assertEqual(svg.count("XXX"), t.num_nodes)
693790

694791
def test_one_node_label(self):
695792
t = self.get_binary_tree()
696793
labels = {0: "XXX"}
697794
svg = t.draw(format="svg", node_labels=labels)
698795
self.verify_basic_svg(svg)
699796
self.assertEqual(svg.count("XXX"), 1)
797+
svg = t.draw_svg(node_label_attrs={0: {"text": "XXX"}})
798+
self.verify_basic_svg(svg)
799+
self.assertEqual(svg.count("XXX"), 1)
700800

701801
def test_no_node_labels(self):
702802
t = self.get_binary_tree()
@@ -712,6 +812,9 @@ def test_one_node_colour(self):
712812
svg = t.draw(format="svg", node_colours=colours)
713813
self.verify_basic_svg(svg)
714814
self.assertEqual(svg.count('fill="{}"'.format(colour)), 1)
815+
svg = t.draw_svg(node_attrs={0: {'fill': colour}})
816+
self.verify_basic_svg(svg)
817+
self.assertEqual(svg.count('fill="{}"'.format(colour)), 1)
715818

716819
def test_all_nodes_colour(self):
717820
t = self.get_binary_tree()
@@ -721,6 +824,12 @@ def test_all_nodes_colour(self):
721824
for colour in colours.values():
722825
self.assertEqual(svg.count('fill="{}"'.format(colour)), 1)
723826

827+
svg = t.draw_svg(node_attrs={u: {'fill': colours[u]} for u in t.nodes()})
828+
self.verify_basic_svg(svg)
829+
self.assertEqual(svg.count('fill="{}"'.format(colour)), 1)
830+
for colour in colours.values():
831+
self.assertEqual(svg.count('fill="{}"'.format(colour)), 1)
832+
724833
def test_unplotted_node(self):
725834
t = self.get_binary_tree()
726835
colour = None
@@ -737,7 +846,15 @@ def test_one_edge_colour(self):
737846
svg = t.draw(format="svg", edge_colours=colours)
738847
self.verify_basic_svg(svg)
739848
self.assertEqual(svg.count('stroke="{}"'.format(colour)), 2)
849+
svg = t.draw_svg(edge_attrs={0: {"stroke": colour}})
850+
self.verify_basic_svg(svg)
851+
# We're mapping to a path here, so only see it once. The old code
852+
# drew two lines.
853+
self.assertEqual(svg.count('stroke="{}"'.format(colour)), 1)
740854

855+
#
856+
# TODO: update the tests below here to check the new SVG based interface.
857+
#
741858
def test_all_edges_colour(self):
742859
t = self.get_binary_tree()
743860
colours = {u: "rgb({u}, {u}, {u})".format(u=u) for u in t.nodes() if u != t.root}
@@ -804,7 +921,7 @@ def test_unplotted_mutation(self):
804921
mutations_in_tree = list(t.mutations())
805922
self.assertEqual(svg.count('<rect'), len(mutations_in_tree) - 1)
806923

807-
def test_max_timescale(self):
924+
def test_max_tree_height(self):
808925
nodes = io.StringIO("""\
809926
id is_sample time
810927
0 1 0
@@ -847,3 +964,18 @@ def test_max_timescale(self):
847964
str_pos = svg2.find('>3<')
848965
snippet2 = svg2[svg2.rfind("<", 0, str_pos):str_pos]
849966
self.assertEqual(snippet1, snippet2)
967+
968+
def test_draw_simple_ts(self):
969+
ts = msprime.simulate(5, recombination_rate=1, random_seed=1)
970+
svg = ts.draw_svg()
971+
self.verify_basic_svg(svg, width=200 * ts.num_trees)
972+
973+
def test_tree_height_scale(self):
974+
ts = msprime.simulate(4, random_seed=2)
975+
svg = ts.draw_svg(tree_height_scale="time")
976+
self.verify_basic_svg(svg)
977+
svg = ts.draw_svg(tree_height_scale="rank")
978+
self.verify_basic_svg(svg)
979+
for bad_scale in [0, "", "NOT A SCALE"]:
980+
with self.assertRaises(ValueError):
981+
ts.draw_svg(tree_height_scale=bad_scale)

0 commit comments

Comments
 (0)