Skip to content

Commit d024103

Browse files
Fix issues with max_tree_height with rank scale.
For rank scale in the draw_svg, the max_tree_height parameter was not being interpreted correctly. Closes #383
1 parent 7f45b25 commit d024103

File tree

2 files changed

+111
-45
lines changed

2 files changed

+111
-45
lines changed

python/tests/test_drawing.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,6 +1163,47 @@ def test_one_mutation_label_colour(self):
11631163
self.verify_basic_svg(svg)
11641164
self.assertEqual(svg.count('stroke="{}"'.format(colour)), 1)
11651165

1166+
def test_bad_tree_height_scale(self):
1167+
t = self.get_binary_tree()
1168+
for bad_scale in ["te", "asdf", "", [], b'23']:
1169+
with self.assertRaises(ValueError):
1170+
t.draw_svg(tree_height_scale=bad_scale)
1171+
1172+
def test_bad_max_tree_height(self):
1173+
t = self.get_binary_tree()
1174+
for bad_height in ["te", "asdf", "", [], b'23']:
1175+
with self.assertRaises(ValueError):
1176+
t.draw_svg(max_tree_height=bad_height)
1177+
1178+
def test_height_scale_time_and_max_tree_height(self):
1179+
ts = msprime.simulate(5, recombination_rate=2, random_seed=2)
1180+
t = ts.first()
1181+
# The default should be the same as tree.
1182+
svg1 = t.draw_svg(max_tree_height="tree")
1183+
self.verify_basic_svg(svg1)
1184+
svg2 = t.draw_svg()
1185+
self.assertEqual(svg1, svg2)
1186+
svg3 = t.draw_svg(max_tree_height="ts")
1187+
self.assertNotEqual(svg1, svg3)
1188+
svg4 = t.draw_svg(max_tree_height=max(ts.tables.nodes.time))
1189+
self.assertEqual(svg3, svg4)
1190+
1191+
def test_height_scale_rank_and_max_tree_height(self):
1192+
# Make sure the rank height scale and max_tree_height interact properly.
1193+
ts = msprime.simulate(5, recombination_rate=2, random_seed=2)
1194+
t = ts.first()
1195+
# The default should be the same as tree.
1196+
svg1 = t.draw_svg(max_tree_height="tree", tree_height_scale="rank")
1197+
self.verify_basic_svg(svg1)
1198+
svg2 = t.draw_svg(tree_height_scale="rank")
1199+
self.assertEqual(svg1, svg2)
1200+
svg3 = t.draw_svg("tmp.svg", max_tree_height="ts", tree_height_scale="rank")
1201+
self.assertNotEqual(svg1, svg3)
1202+
self.verify_basic_svg(svg3)
1203+
# Numeric max tree height not supported for rank scale.
1204+
with self.assertRaises(ValueError):
1205+
t.draw_svg(max_tree_height=2, tree_height_scale="rank")
1206+
11661207
#
11671208
# TODO: update the tests below here to check the new SVG based interface.
11681209
#

python/tskit/drawing.py

Lines changed: 70 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,50 @@
2424
Module responsible for visualisations.
2525
"""
2626
import collections
27-
from _tskit import NULL
27+
import numbers
2828

2929
import svgwrite
3030
import numpy as np
3131

32+
from _tskit import NULL
33+
34+
LEFT = "left"
35+
RIGHT = "right"
36+
TOP = "top"
37+
BOTTOM = "bottom"
38+
39+
40+
def check_orientation(orientation):
41+
if orientation is None:
42+
orientation = TOP
43+
else:
44+
orientation = orientation.lower()
45+
orientations = [LEFT, RIGHT, TOP, BOTTOM]
46+
if orientation not in orientations:
47+
raise ValueError(
48+
"Unknown orientiation: choose from {}".format(orientations))
49+
return orientation
50+
51+
52+
def check_max_tree_height(max_tree_height, allow_numeric=True):
53+
if max_tree_height is None:
54+
max_tree_height = "tree"
55+
is_numeric = isinstance(max_tree_height, numbers.Real)
56+
if max_tree_height not in ["tree", "ts"] and not allow_numeric:
57+
raise ValueError("max_tree_height must be 'tree' or 'ts'")
58+
if max_tree_height not in ["tree", "ts"] and (allow_numeric and not is_numeric):
59+
raise ValueError(
60+
"max_tree_height must be a numeric value or one of 'tree' or 'ts'")
61+
return max_tree_height
62+
63+
64+
def check_tree_height_scale(tree_height_scale):
65+
if tree_height_scale is None:
66+
tree_height_scale = "time"
67+
if tree_height_scale not in ["time", "log_time", "rank"]:
68+
raise ValueError("tree_height_scale must be 'time', 'log_time' or 'rank'")
69+
return tree_height_scale
70+
3271

3372
def check_format(format):
3473
if format is None:
@@ -264,28 +303,39 @@ def setup_drawing(self):
264303
self.mutation_right_labels = self.mutation_labels.add(dwg.g(text_anchor="end"))
265304

266305
def assign_y_coordinates(self, tree_height_scale, max_tree_height):
306+
tree_height_scale = check_tree_height_scale(tree_height_scale)
307+
max_tree_height = check_max_tree_height(
308+
max_tree_height, tree_height_scale != "rank")
267309
ts = self.tree.tree_sequence
268310
node_time = ts.tables.nodes.time
269-
if tree_height_scale in [None, "time", "log_time"]:
270-
if max_tree_height in [None, "tree"]:
311+
312+
if tree_height_scale == "rank":
313+
assert tree_height_scale == "rank"
314+
if max_tree_height == "tree":
315+
# We only rank the times within the tree in this case.
316+
t = np.zeros_like(node_time) + node_time[self.tree.left_root]
317+
for u in self.tree.nodes():
318+
t[u] = node_time[u]
319+
node_time = t
320+
depth = {t: 2 * j for j, t in enumerate(np.unique(node_time))}
321+
node_height = [depth[node_time[u]] for u in range(ts.num_nodes)]
322+
max_tree_height = max(depth.values())
323+
else:
324+
assert tree_height_scale in ["time", "log_time"]
325+
if max_tree_height == "tree":
271326
max_tree_height = max(self.tree.time(root) for root in self.tree.roots)
272327
elif max_tree_height == "ts":
273328
max_tree_height = ts.max_root_time
274-
if tree_height_scale == "log_time":
275-
# add 1 so that don't reach log(0) = -inf error.
276-
# just shifts entire timeset by 1 year so shouldn't affect anything
277-
node_height = np.log(ts.tables.nodes.time + 1)
278-
elif tree_height_scale in [None, "time"]:
279-
node_height = node_time
280-
else:
281-
if tree_height_scale != "rank":
282-
raise ValueError(
283-
"Only 'time', 'log_time', "
284-
"and 'rank' are supported for tree_height_scale")
285-
depth = {t: 2 * j for j, t in enumerate(np.unique(node_time))}
286-
node_height = [depth[node_time[u]] for u in range(ts.num_nodes)]
287-
if max_tree_height is None:
288-
max_tree_height = max(depth.values())
329+
330+
if tree_height_scale == "log_time":
331+
# add 1 so that don't reach log(0) = -inf error.
332+
# just shifts entire timeset by 1 year so shouldn't affect anything
333+
node_height = np.log(ts.tables.nodes.time + 1)
334+
elif tree_height_scale == "time":
335+
node_height = node_time
336+
337+
assert float(max_tree_height) == max_tree_height
338+
289339
# In pathological cases, all the roots are at 0
290340
if max_tree_height == 0:
291341
max_tree_height = 1
@@ -466,32 +516,6 @@ def __str__(self):
466516
return "".join(self.canvas.reshape(self.width * self.height))
467517

468518

469-
LEFT = "left"
470-
RIGHT = "right"
471-
TOP = "top"
472-
BOTTOM = "bottom"
473-
474-
475-
def check_orientation(orientation):
476-
if orientation is None:
477-
orientation = TOP
478-
else:
479-
orientation = orientation.lower()
480-
orientations = [LEFT, RIGHT, TOP, BOTTOM]
481-
if orientation not in orientations:
482-
raise ValueError(
483-
"Unknown orientiation: choose from {}".format(orientations))
484-
return orientation
485-
486-
487-
def check_max_tree_height(max_tree_height):
488-
if max_tree_height is None:
489-
max_tree_height = "tree"
490-
if max_tree_height not in ["tree", "ts"]:
491-
raise ValueError("max_tree_height must be 'tree' or 'ts'")
492-
return max_tree_height
493-
494-
495519
def to_np_unicode(string):
496520
"""
497521
Converts the specified string to a numpy unicode array.
@@ -574,7 +598,8 @@ def __init__(
574598
self, tree, node_labels=None, max_tree_height=None, use_ascii=False,
575599
orientation=None):
576600
self.tree = tree
577-
self.max_tree_height = check_max_tree_height(max_tree_height)
601+
self.max_tree_height = check_max_tree_height(
602+
max_tree_height, allow_numeric=False)
578603
self.use_ascii = use_ascii
579604
self.orientation = check_orientation(orientation)
580605
self.horizontal_line_char = '━'

0 commit comments

Comments
 (0)