Skip to content

Commit b532598

Browse files
timothymillarjeromekelleher
authored andcommitted
Add display_pedigree method #1097
1 parent c897aff commit b532598

File tree

4 files changed

+271
-3
lines changed

4 files changed

+271
-3
lines changed

docs/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ Utilities
132132
convert_call_to_index
133133
convert_probability_to_call
134134
display_genotypes
135+
display_pedigree
135136
filter_partial_calls
136137
infer_call_ploidy
137138
infer_sample_ploidy

sgkit/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from pkg_resources import DistributionNotFound, get_distribution # type: ignore[import]
22

3-
from .display import display_genotypes
3+
from .display import display_genotypes, display_pedigree
44
from .distance.api import pairwise_distance
55
from .io.dataset import load_dataset, save_dataset
66
from .io.vcfzarr_reader import read_scikit_allel_vcfzarr
@@ -88,6 +88,7 @@
8888
"count_variant_genotypes",
8989
"create_genotype_dosage_dataset",
9090
"display_genotypes",
91+
"display_pedigree",
9192
"filter_partial_calls",
9293
"genee",
9394
"genomic_relationship",

sgkit/display.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
from typing import Any, Hashable, Mapping, Tuple
1+
from typing import Any, Dict, Hashable, Mapping, Optional, Tuple
22

33
import numpy as np
44
import pandas as pd
55
import xarray as xr
66

7+
from sgkit import variables
8+
from sgkit.stats.pedigree import parent_indices
79
from sgkit.typing import ArrayLike
10+
from sgkit.utils import define_variable_if_absent
811

912

1013
class GenotypeDisplay:
@@ -209,3 +212,77 @@ def display_genotypes(
209212
max_variants,
210213
max_samples,
211214
)
215+
216+
217+
def display_pedigree(
218+
ds: xr.Dataset,
219+
parent: Hashable = variables.parent,
220+
graph_attrs: Optional[Dict[Hashable, str]] = None,
221+
node_attrs: Optional[Dict[Hashable, ArrayLike]] = None,
222+
edge_attrs: Optional[Dict[Hashable, ArrayLike]] = None,
223+
) -> Any:
224+
"""Display a pedigree dataset as a directed acyclic graph.
225+
226+
Parameters
227+
----------
228+
ds
229+
Dataset containing pedigree structure.
230+
parent
231+
Input variable name holding parents of each sample as defined by
232+
:data:`sgkit.variables.parent_spec`.
233+
If the variable is not present in ``ds``, it will be computed
234+
using :func:`parent_indices`.
235+
graph_attrs
236+
Key-value pairs to pass through to graphviz as graph attributes.
237+
node_attrs
238+
Key-value pairs to pass through to graphviz as node attributes.
239+
Values will be broadcast to have shape (samples, ).
240+
edge_attrs
241+
Key-value pairs to pass through to graphviz as edge attributes.
242+
Values will be broadcast to have shape (samples, parents).
243+
244+
Raises
245+
------
246+
RuntimeError
247+
If the `Graphviz library <https://graphviz.readthedocs.io/en/stable/>`_ is not installed.
248+
249+
Returns
250+
-------
251+
A digraph representation of the pedigree.
252+
"""
253+
try:
254+
from graphviz import Digraph
255+
except ImportError: # pragma: no cover
256+
raise RuntimeError(
257+
"Visualizing pedigrees requires the `graphviz` python library and the `graphviz` system library to be installed."
258+
)
259+
ds = define_variable_if_absent(ds, variables.parent, parent, parent_indices)
260+
variables.validate(ds, {parent: variables.parent_spec})
261+
parent = ds[parent].values
262+
n_samples, n_parent_types = parent.shape
263+
graph_attrs = graph_attrs or {}
264+
node_attrs = node_attrs or {}
265+
edge_attrs = edge_attrs or {}
266+
# default to using samples coordinates for labels
267+
if ("label" not in node_attrs) and ("samples" in ds.coords):
268+
node_attrs["label"] = ds.samples.values
269+
# numpy broadcasting
270+
node_attrs = {k: np.broadcast_to(v, n_samples) for k, v in node_attrs.items()}
271+
edge_attrs = {k: np.broadcast_to(v, parent.shape) for k, v in edge_attrs.items()}
272+
# initialize graph
273+
graph = Digraph()
274+
graph.attr(**graph_attrs)
275+
# add nodes
276+
for i in range(n_samples):
277+
d = {k: str(v[i]) for k, v in node_attrs.items()}
278+
graph.node(str(i), **d)
279+
# add edges
280+
for i in range(n_samples):
281+
for j in range(n_parent_types):
282+
p = parent[i, j]
283+
if p >= 0:
284+
d = {}
285+
for k, v in edge_attrs.items():
286+
d[k] = str(v[i, j])
287+
graph.edge(str(p), str(i), **d)
288+
return graph

sgkit/tests/test_display.py

Lines changed: 190 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
import xarray as xr
66

7-
from sgkit import display_genotypes
7+
from sgkit import display_genotypes, display_pedigree
88
from sgkit.display import genotype_as_bytes
99
from sgkit.testing import simulate_genotype_call_dataset
1010

@@ -417,3 +417,192 @@ def test_genotype_as_bytes(genotype, phased, max_allele_chars, expect):
417417
expect,
418418
genotype_as_bytes(genotype, phased, max_allele_chars),
419419
)
420+
421+
422+
def pedigree_Hamilton_Kerr():
423+
ds = xr.Dataset()
424+
ds["sample_id"] = "samples", ["S1", "S2", "S3", "S4", "S5", "S6", "S7", "S8"]
425+
ds["parent_id"] = ["samples", "parents"], [
426+
[".", "."],
427+
[".", "."],
428+
[".", "S2"],
429+
["S1", "."],
430+
["S1", "S3"],
431+
["S1", "S3"],
432+
["S6", "S2"],
433+
["S6", "S2"],
434+
]
435+
ds["stat_Hamilton_Kerr_tau"] = ["samples", "parents"], [
436+
[1, 1],
437+
[2, 2],
438+
[0, 2],
439+
[2, 0],
440+
[1, 1],
441+
[2, 2],
442+
[2, 2],
443+
[2, 2],
444+
]
445+
ds["stat_Hamilton_Kerr_lambda"] = ["samples", "parents"], [
446+
[0.0, 0.0],
447+
[0.167, 0.167],
448+
[0.0, 0.167],
449+
[0.041, 0.0],
450+
[0.0, 0.0],
451+
[0.918, 0.041],
452+
[0.167, 0.167],
453+
[0.167, 0.167],
454+
]
455+
return ds
456+
457+
458+
def test_display_pedigree__no_coords():
459+
ds = pedigree_Hamilton_Kerr()
460+
graph = display_pedigree(ds)
461+
expect = """ digraph {
462+
\t0
463+
\t1
464+
\t2
465+
\t3
466+
\t4
467+
\t5
468+
\t6
469+
\t7
470+
\t1 -> 2
471+
\t0 -> 3
472+
\t0 -> 4
473+
\t2 -> 4
474+
\t0 -> 5
475+
\t2 -> 5
476+
\t5 -> 6
477+
\t1 -> 6
478+
\t5 -> 7
479+
\t1 -> 7
480+
}
481+
"""
482+
assert str(graph) == dedent(expect)
483+
484+
485+
def test_display_pedigree__samples_coords():
486+
ds = pedigree_Hamilton_Kerr()
487+
ds = ds.assign_coords(samples=ds.sample_id)
488+
graph = display_pedigree(ds)
489+
expect = """ digraph {
490+
\t0 [label=S1]
491+
\t1 [label=S2]
492+
\t2 [label=S3]
493+
\t3 [label=S4]
494+
\t4 [label=S5]
495+
\t5 [label=S6]
496+
\t6 [label=S7]
497+
\t7 [label=S8]
498+
\t1 -> 2
499+
\t0 -> 3
500+
\t0 -> 4
501+
\t2 -> 4
502+
\t0 -> 5
503+
\t2 -> 5
504+
\t5 -> 6
505+
\t1 -> 6
506+
\t5 -> 7
507+
\t1 -> 7
508+
}
509+
"""
510+
assert str(graph) == dedent(expect)
511+
512+
513+
def test_display_pedigree__samples_coords_reorder():
514+
ds = pedigree_Hamilton_Kerr()
515+
ds = ds.sel(samples=[7, 3, 5, 0, 4, 1, 2, 6])
516+
ds = ds.assign_coords(samples=ds.sample_id)
517+
graph = display_pedigree(ds)
518+
expect = """ digraph {
519+
\t0 [label=S8]
520+
\t1 [label=S4]
521+
\t2 [label=S6]
522+
\t3 [label=S1]
523+
\t4 [label=S5]
524+
\t5 [label=S2]
525+
\t6 [label=S3]
526+
\t7 [label=S7]
527+
\t2 -> 0
528+
\t5 -> 0
529+
\t3 -> 1
530+
\t3 -> 2
531+
\t6 -> 2
532+
\t3 -> 4
533+
\t6 -> 4
534+
\t5 -> 6
535+
\t2 -> 7
536+
\t5 -> 7
537+
}
538+
"""
539+
assert str(graph) == dedent(expect)
540+
541+
542+
def test_display_pedigree__samples_labels():
543+
ds = pedigree_Hamilton_Kerr()
544+
graph = display_pedigree(ds, node_attrs=dict(label=ds.sample_id))
545+
expect = """ digraph {
546+
\t0 [label=S1]
547+
\t1 [label=S2]
548+
\t2 [label=S3]
549+
\t3 [label=S4]
550+
\t4 [label=S5]
551+
\t5 [label=S6]
552+
\t6 [label=S7]
553+
\t7 [label=S8]
554+
\t1 -> 2
555+
\t0 -> 3
556+
\t0 -> 4
557+
\t2 -> 4
558+
\t0 -> 5
559+
\t2 -> 5
560+
\t5 -> 6
561+
\t1 -> 6
562+
\t5 -> 7
563+
\t1 -> 7
564+
}
565+
"""
566+
assert str(graph) == dedent(expect)
567+
568+
569+
def test_display_pedigree__broadcast():
570+
ds = pedigree_Hamilton_Kerr()
571+
inbreeding = np.array([0.0, 0.077, 0.231, 0.041, 0.0, 0.197, 0.196, 0.196])
572+
label = (ds.sample_id.str + "\n").str + inbreeding.astype("U")
573+
edges = xr.where(
574+
ds.stat_Hamilton_Kerr_tau == 2,
575+
"black:black",
576+
"black",
577+
)
578+
graph = display_pedigree(
579+
ds,
580+
graph_attrs=dict(splines="false", outputorder="edgesfirst"),
581+
node_attrs=dict(
582+
style="filled", fillcolor="black", fontcolor="white", label=label
583+
),
584+
edge_attrs=dict(arrowhead="crow", color=edges),
585+
)
586+
expect = """ digraph {
587+
\toutputorder=edgesfirst splines=false
588+
\t0 [label="S1\n 0.0" fillcolor=black fontcolor=white style=filled]
589+
\t1 [label="S2\n 0.077" fillcolor=black fontcolor=white style=filled]
590+
\t2 [label="S3\n 0.231" fillcolor=black fontcolor=white style=filled]
591+
\t3 [label="S4\n 0.041" fillcolor=black fontcolor=white style=filled]
592+
\t4 [label="S5\n 0.0" fillcolor=black fontcolor=white style=filled]
593+
\t5 [label="S6\n 0.197" fillcolor=black fontcolor=white style=filled]
594+
\t6 [label="S7\n 0.196" fillcolor=black fontcolor=white style=filled]
595+
\t7 [label="S8\n 0.196" fillcolor=black fontcolor=white style=filled]
596+
\t1 -> 2 [arrowhead=crow color="black:black"]
597+
\t0 -> 3 [arrowhead=crow color="black:black"]
598+
\t0 -> 4 [arrowhead=crow color=black]
599+
\t2 -> 4 [arrowhead=crow color=black]
600+
\t0 -> 5 [arrowhead=crow color="black:black"]
601+
\t2 -> 5 [arrowhead=crow color="black:black"]
602+
\t5 -> 6 [arrowhead=crow color="black:black"]
603+
\t1 -> 6 [arrowhead=crow color="black:black"]
604+
\t5 -> 7 [arrowhead=crow color="black:black"]
605+
\t1 -> 7 [arrowhead=crow color="black:black"]
606+
}
607+
"""
608+
assert str(graph) == dedent(expect)

0 commit comments

Comments
 (0)