Skip to content

Commit 5bfb239

Browse files
Make input more flexible
1 parent cb94a58 commit 5bfb239

File tree

2 files changed

+130
-108
lines changed

2 files changed

+130
-108
lines changed

bio2zarr/tskit.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,34 @@ class TskitFormat(vcz.Source):
1212
@core.requires_optional_dependency("tskit", "tskit")
1313
def __init__(
1414
self,
15-
ts_path,
16-
individuals_nodes=None,
17-
sample_ids=None,
15+
ts,
16+
*,
17+
model_mapping=None,
1818
contig_id=None,
1919
isolated_as_missing=False,
2020
):
2121
import tskit
2222

23-
self._path = ts_path
24-
self.ts = tskit.load(ts_path)
23+
self._path = None # Not sure what we're using this for?
24+
# Future versions here will need to deal with the complexities of
25+
# having lists of tree sequences for multiple chromosomes.
26+
if isinstance(ts, tskit.TreeSequence):
27+
self.ts = ts
28+
else:
29+
# input 'ts' is a path.
30+
self._path = ts
31+
self.ts = tskit.load(ts)
32+
2533
self.contig_id = contig_id if contig_id is not None else "1"
2634
self.isolated_as_missing = isolated_as_missing
2735

2836
self.positions = self.ts.sites_position
2937

30-
if individuals_nodes is None:
31-
individuals_nodes = self.ts.individuals_nodes
38+
if model_mapping is None:
39+
model_mapping = self.ts.map_to_vcf_model()
40+
41+
individuals_nodes = model_mapping.individuals_nodes
42+
sample_ids = model_mapping.individuals_name
3243

3344
self.is_phased = True
3445
if individuals_nodes.shape[1] == 1:
@@ -241,8 +252,7 @@ def convert(
241252
ts_path,
242253
zarr_path,
243254
*,
244-
individuals_nodes=None,
245-
sample_ids=None,
255+
model_mapping=None,
246256
contig_id=None,
247257
isolated_as_missing=False,
248258
variants_chunk_size=None,
@@ -252,8 +262,7 @@ def convert(
252262
):
253263
tskit_format = TskitFormat(
254264
ts_path,
255-
individuals_nodes=individuals_nodes,
256-
sample_ids=sample_ids,
265+
model_mapping=model_mapping,
257266
contig_id=contig_id,
258267
isolated_as_missing=isolated_as_missing,
259268
)

tests/test_tskit.py

Lines changed: 110 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import os
2-
import tempfile
31
from unittest import mock
42

53
import numpy as np
@@ -13,6 +11,22 @@
1311
from bio2zarr import vcf
1412

1513

14+
def test_missing_dependency():
15+
with mock.patch(
16+
"importlib.import_module",
17+
side_effect=ImportError("No module named 'tskit'"),
18+
):
19+
with pytest.raises(ImportError) as exc_info:
20+
tsk.convert(
21+
"UNUSED_PATH",
22+
"UNUSED_PATH",
23+
)
24+
assert (
25+
"This process requires the optional tskit module. Install "
26+
"it with: pip install bio2zarr[tskit]" in str(exc_info.value)
27+
)
28+
29+
1630
def simple_ts(add_individuals=False):
1731
tables = tskit.TableCollection(sequence_length=100)
1832
for _ in range(4):
@@ -37,98 +51,96 @@ def simple_ts(add_individuals=False):
3751
return tables.tree_sequence()
3852

3953

40-
class TestTskit:
41-
def test_simple_tree_sequence(self, tmp_path):
42-
tree_sequence = simple_ts()
43-
tree_sequence.dump(tmp_path / "test.trees")
44-
45-
# Manually specify the individuals_nodes, other tests use
46-
# tsk individuals.
47-
ind_nodes = np.array([[0, 1], [2, 3]])
48-
49-
with tempfile.TemporaryDirectory() as tempdir:
50-
zarr_path = os.path.join(tempdir, "test_output.zarr")
51-
tsk.convert(
52-
tmp_path / "test.trees",
53-
zarr_path,
54-
individuals_nodes=ind_nodes,
55-
show_progress=False,
56-
)
57-
zroot = zarr.open(zarr_path, mode="r")
58-
pos = zroot["variant_position"][:]
59-
assert pos.shape == (3,)
60-
assert pos.dtype == np.int8
61-
assert np.array_equal(pos, [10, 20, 30])
62-
63-
alleles = zroot["variant_allele"][:]
64-
assert alleles.shape == (3, 2)
65-
assert alleles.dtype == "O"
66-
assert np.array_equal(alleles, [["A", "TTTT"], ["CCC", "G"], ["G", "AA"]])
67-
68-
lengths = zroot["variant_length"][:]
69-
assert lengths.shape == (3,)
70-
assert lengths.dtype == np.int8
71-
assert np.array_equal(lengths, [1, 3, 1])
72-
73-
genotypes = zroot["call_genotype"][:]
74-
assert genotypes.shape == (3, 2, 2)
75-
assert genotypes.dtype == np.int8
76-
assert np.array_equal(
77-
genotypes, [[[1, 1], [0, 0]], [[0, 0], [1, 1]], [[1, 0], [0, 0]]]
78-
)
54+
class TestSimpleTs:
55+
@pytest.fixture()
56+
def conversion(self, tmp_path):
57+
ts = simple_ts()
58+
zarr_path = tmp_path / "test_output.vcz"
59+
tsk.convert(ts, zarr_path)
60+
zroot = zarr.open(zarr_path, mode="r")
61+
return ts, zroot
62+
63+
def test_position(self, conversion):
64+
ts, zroot = conversion
65+
66+
pos = zroot["variant_position"][:]
67+
assert pos.shape == (3,)
68+
assert pos.dtype == np.int8
69+
assert np.array_equal(pos, [10, 20, 30])
70+
71+
def test_alleles(self, conversion):
72+
ts, zroot = conversion
73+
alleles = zroot["variant_allele"][:]
74+
assert alleles.shape == (3, 2)
75+
assert alleles.dtype == "O"
76+
assert np.array_equal(alleles, [["A", "TTTT"], ["CCC", "G"], ["G", "AA"]])
77+
78+
def test_variant_length(self, conversion):
79+
ts, zroot = conversion
80+
lengths = zroot["variant_length"][:]
81+
assert lengths.shape == (3,)
82+
assert lengths.dtype == np.int8
83+
assert np.array_equal(lengths, [1, 3, 1])
84+
85+
def test_genotypes(self, conversion):
86+
ts, zroot = conversion
87+
genotypes = zroot["call_genotype"][:]
88+
assert genotypes.shape == (3, 4, 1)
89+
assert genotypes.dtype == np.int8
90+
assert np.array_equal(
91+
genotypes,
92+
[[[1], [1], [0], [0]], [[0], [0], [1], [1]], [[1], [0], [0], [0]]],
93+
)
7994

80-
phased = zroot["call_genotype_phased"][:]
81-
assert phased.shape == (3, 2)
82-
assert phased.dtype == "bool"
83-
assert np.all(phased)
84-
85-
contigs = zroot["contig_id"][:]
86-
assert contigs.shape == (1,)
87-
assert contigs.dtype == "O"
88-
assert np.array_equal(contigs, ["1"])
89-
90-
contig = zroot["variant_contig"][:]
91-
assert contig.shape == (3,)
92-
assert contig.dtype == np.int8
93-
assert np.array_equal(contig, [0, 0, 0])
94-
95-
samples = zroot["sample_id"][:]
96-
assert samples.shape == (2,)
97-
assert samples.dtype == "O"
98-
assert np.array_equal(samples, ["tsk_0", "tsk_1"])
99-
100-
region_index = zroot["region_index"][:]
101-
assert region_index.shape == (1, 6)
102-
assert region_index.dtype == np.int8
103-
assert np.array_equal(region_index, [[0, 0, 10, 30, 30, 3]])
104-
105-
assert set(zroot.array_keys()) == {
106-
"variant_position",
107-
"variant_allele",
108-
"variant_length",
109-
"call_genotype",
110-
"call_genotype_phased",
111-
"call_genotype_mask",
112-
"contig_id",
113-
"variant_contig",
114-
"sample_id",
115-
"region_index",
116-
}
117-
118-
def test_missing_dependency(self):
119-
with mock.patch(
120-
"importlib.import_module",
121-
side_effect=ImportError("No module named 'tskit'"),
122-
):
123-
with pytest.raises(ImportError) as exc_info:
124-
tsk.convert(
125-
"UNUSED_PATH",
126-
"UNUSED_PATH",
127-
)
128-
assert (
129-
"This process requires the optional tskit module. Install "
130-
"it with: pip install bio2zarr[tskit]" in str(exc_info.value)
131-
)
95+
def test_phased(self, conversion):
96+
ts, zroot = conversion
97+
phased = zroot["call_genotype_phased"][:]
98+
assert phased.shape == (3, 4)
99+
assert phased.dtype == "bool"
100+
assert np.all(~phased)
101+
102+
def test_contig_id(self, conversion):
103+
ts, zroot = conversion
104+
contigs = zroot["contig_id"][:]
105+
assert contigs.shape == (1,)
106+
assert contigs.dtype == "O"
107+
assert np.array_equal(contigs, ["1"])
108+
109+
def test_variant_contig(self, conversion):
110+
ts, zroot = conversion
111+
contig = zroot["variant_contig"][:]
112+
assert contig.shape == (3,)
113+
assert contig.dtype == np.int8
114+
assert np.array_equal(contig, [0, 0, 0])
115+
116+
def test_sample_id(self, conversion):
117+
ts, zroot = conversion
118+
samples = zroot["sample_id"][:]
119+
assert samples.shape == (4,)
120+
assert samples.dtype == "O"
121+
assert np.array_equal(samples, ["tsk_0", "tsk_1", "tsk_2", "tsk_3"])
122+
123+
def test_region_index(self, conversion):
124+
ts, zroot = conversion
125+
region_index = zroot["region_index"][:]
126+
assert region_index.shape == (1, 6)
127+
assert region_index.dtype == np.int8
128+
assert np.array_equal(region_index, [[0, 0, 10, 30, 30, 3]])
129+
130+
def test_fields(self, conversion):
131+
ts, zroot = conversion
132+
assert set(zroot.array_keys()) == {
133+
"variant_position",
134+
"variant_allele",
135+
"variant_length",
136+
"call_genotype",
137+
"call_genotype_phased",
138+
"call_genotype_mask",
139+
"contig_id",
140+
"variant_contig",
141+
"sample_id",
142+
"region_index",
143+
}
132144

133145

134146
class TestTskitFormat:
@@ -463,7 +475,7 @@ def insert_branch_sites(tsk, m=1):
463475
expected_gt_missing = np.array([[1], [0], [-1]])
464476
assert np.array_equal(variant_data_missing.genotypes, expected_gt_missing)
465477

466-
def test_genotype_dtype_selection(self, tmp_path):
478+
def test_genotype_dtype_i1(self, tmp_path):
467479
tables = tskit.TableCollection(sequence_length=100)
468480
for _ in range(4):
469481
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
@@ -477,12 +489,12 @@ def test_genotype_dtype_selection(self, tmp_path):
477489
ts_path = tmp_path / "small_alleles.trees"
478490
tree_sequence.dump(ts_path)
479491

480-
ind_nodes = np.array([[0, 1], [2, 3]])
481-
format_obj = tsk.TskitFormat(ts_path, individuals_nodes=ind_nodes)
492+
format_obj = tsk.TskitFormat(ts_path)
482493
schema = format_obj.generate_schema()
483494
call_genotype_spec = next(s for s in schema.fields if s.name == "call_genotype")
484495
assert call_genotype_spec.dtype == "i1"
485496

497+
def test_genotype_dtype_i4(self, tmp_path):
486498
tables = tskit.TableCollection(sequence_length=100)
487499
for _ in range(4):
488500
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
@@ -498,7 +510,7 @@ def test_genotype_dtype_selection(self, tmp_path):
498510
ts_path = tmp_path / "large_alleles.trees"
499511
tree_sequence.dump(ts_path)
500512

501-
format_obj = tsk.TskitFormat(ts_path, individuals_nodes=ind_nodes)
513+
format_obj = tsk.TskitFormat(ts_path)
502514
schema = format_obj.generate_schema()
503515
call_genotype_spec = next(s for s in schema.fields if s.name == "call_genotype")
504516
assert call_genotype_spec.dtype == "i4"
@@ -508,6 +520,7 @@ def test_genotype_dtype_selection(self, tmp_path):
508520
"ts",
509521
[
510522
simple_ts(add_individuals=True),
523+
simple_ts(add_individuals=False),
511524
],
512525
)
513526
def test_against_tskit_vcf_output(ts, tmp_path):

0 commit comments

Comments
 (0)