Skip to content

Commit e99c448

Browse files
hyanwongmergify[bot]
authored andcommitted
Add a class for svg strings
1 parent 7d1d28e commit e99c448

File tree

6 files changed

+43
-5
lines changed

6 files changed

+43
-5
lines changed

docs/python-api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,3 +1509,9 @@ Also see the {ref}`sec_python_api_metadata` summary.
15091509
:members:
15101510
```
15111511

1512+
#### The {class}`SVGString` class
1513+
```{eval-rst}
1514+
.. autoclass:: SVGString
1515+
:members:
1516+
:private-members: _repr_svg_
1517+
```

python/CHANGELOG.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
[0.5.1] - 2022-0X-XX
33
--------------------
44

5+
**Changes**
6+
7+
- SVG drawing routines now return a special string object that is automatically
8+
rendered in a Jupyter notebook (:user:`hyanwong`, :pr:`2377`)
9+
10+
511
--------------------
612
[0.5.0] - 2022-06-22
713
--------------------

python/tests/test_drawing.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1446,6 +1446,12 @@ def test_max_time(self):
14461446
with pytest.raises(ValueError):
14471447
t.draw_text(max_time=bad_max_time)
14481448

1449+
def test_no_repr_svg(self):
1450+
tree = self.get_simple_ts().first()
1451+
output = tree.draw(format="unicode")
1452+
with pytest.raises(AttributeError, match="no attribute"):
1453+
output._repr_svg_()
1454+
14491455

14501456
class TestDrawSvg(TestTreeDraw, xmlunittest.XmlTestMixin):
14511457
"""
@@ -1490,6 +1496,15 @@ def verify_basic_svg(self, svg, width=200, height=200, num_trees=1):
14901496
cls = group.attrib["class"]
14911497
assert re.search(r"\broot\b", cls)
14921498

1499+
def test_repr_svg(self):
1500+
ts = self.get_simple_ts()
1501+
svg = ts.draw_svg()
1502+
assert str(svg) == svg._repr_svg_()
1503+
svg = ts.first().draw_svg()
1504+
assert str(svg) == svg._repr_svg_()
1505+
svg = ts.first().draw(format="svg")
1506+
assert str(svg) == svg._repr_svg_()
1507+
14931508
def test_draw_to_file(self, tmp_path):
14941509
# NB: to view output files for testing changes to drawing code, it is possible
14951510
# to save to a fixed directory using e.g. `pytest --basetemp=/tmp/svgtest ...`

python/tskit/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
all_tree_labellings,
8585
TopologyCounter,
8686
)
87+
from tskit.drawing import SVGString # NOQA
8788
from tskit.exceptions import * # NOQA
8889
from tskit.util import * # NOQA
8990
from tskit.metadata import * # NOQA

python/tskit/drawing.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,16 @@ def linear_transform(self, y):
9494
return self.plot_min - (y - self.min_time) * y_scale
9595

9696

97+
class SVGString(str):
98+
"A string containing an SVG representation"
99+
100+
def _repr_svg_(self):
101+
"""
102+
Simply return the SVG string: called by jupyter notebooks to render trees.
103+
"""
104+
return self
105+
106+
97107
def check_orientation(orientation):
98108
if orientation is None:
99109
orientation = TOP
@@ -415,7 +425,7 @@ def remap_style(original_map, new_key, none_value):
415425
mutation_attrs=mutation_attrs,
416426
order=order,
417427
)
418-
return tree.drawing.tostring()
428+
return SVGString(tree.drawing.tostring())
419429

420430
else:
421431
if width is not None:

python/tskit/trees.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1794,7 +1794,7 @@ def draw_svg(
17941794
from the same tree sequence may share some plotted mutations.
17951795
17961796
:return: An SVG representation of a tree.
1797-
:rtype: str
1797+
:rtype: SVGString
17981798
"""
17991799
draw = drawing.SvgTree(
18001800
self,
@@ -1824,7 +1824,7 @@ def draw_svg(
18241824
if path is not None:
18251825
# TODO: removed the pretty here when this is stable.
18261826
draw.drawing.saveas(path, pretty=True)
1827-
return output
1827+
return drawing.SVGString(output)
18281828

18291829
def draw(
18301830
self,
@@ -6502,7 +6502,7 @@ def draw_svg(
65026502
at each y tickmark.
65036503
65046504
:return: An SVG representation of a tree sequence.
6505-
:rtype: str
6505+
:rtype: SVGString
65066506
65076507
.. note::
65086508
Technically, x_lim[0] specifies a *minimum* value for the start of the X
@@ -6541,7 +6541,7 @@ def draw_svg(
65416541
if path is not None:
65426542
# TODO remove the 'pretty' when we are done debugging this.
65436543
draw.drawing.saveas(path, pretty=True)
6544-
return output
6544+
return drawing.SVGString(output)
65456545

65466546
def draw_text(
65476547
self,

0 commit comments

Comments
 (0)