Skip to content

Commit cb94a58

Browse files
Patch up some details in tskit model to align with VCF
1 parent 7f985a2 commit cb94a58

File tree

3 files changed

+102
-57
lines changed

3 files changed

+102
-57
lines changed

bio2zarr/tskit.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ def __init__(
3030
if individuals_nodes is None:
3131
individuals_nodes = self.ts.individuals_nodes
3232

33+
self.is_phased = True
34+
if individuals_nodes.shape[1] == 1:
35+
# For simplicity we defined haploids as unphased to do the same thing as the
36+
# VCF conversion code. We should just omit the array for haploids anyway.
37+
self.is_phased = False
38+
3339
self._num_samples = individuals_nodes.shape[0]
3440
if self._num_samples < 1:
3541
raise ValueError("individuals_nodes must have at least one sample")
@@ -90,7 +96,7 @@ def iter_field(self, field_name, shape, start, stop):
9096

9197
def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
9298
# All genotypes in tskit are considered phased
93-
phased = np.ones(shape[:-1], dtype=bool)
99+
phased = np.full(shape[:-1], self.is_phased, dtype=bool)
94100

95101
for variant in self.ts.variants(
96102
isolated_as_missing=self.isolated_as_missing,
@@ -101,14 +107,15 @@ def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
101107
):
102108
gt = np.full(shape, constants.INT_FILL, dtype=np.int8)
103109
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
104-
variant_length = 0
110+
# length is the length of the REF allele unless other fields
111+
# are included.
112+
variant_length = len(variant.alleles[0])
105113
for i, allele in enumerate(variant.alleles):
106114
# None is returned by tskit in the case of a missing allele
107115
if allele is None:
108116
continue
109117
assert i < num_alleles
110118
alleles[i] = allele
111-
variant_length = max(variant_length, len(allele))
112119
gt[self.sample_indices, self.ploidy_indices] = variant.genotypes[
113120
self.genotype_indices
114121
]

tests/test_cli.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ def test_vcf_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, response
647647
@pytest.mark.parametrize(("progress", "flag"), [(True, "-P"), (False, "-Q")])
648648
@mock.patch("bio2zarr.tskit.convert")
649649
def test_convert_tskit(self, mocked, tmp_path, progress, flag):
650-
ts_path = "tests/data/ts/example.trees"
650+
ts_path = "tests/data/tskit/example.trees"
651651
zarr_path = tmp_path / "zarr"
652652
runner = ct.CliRunner()
653653
result = runner.invoke(
@@ -669,7 +669,7 @@ def test_convert_tskit(self, mocked, tmp_path, progress, flag):
669669
@pytest.mark.parametrize("response", ["y", "Y", "yes"])
670670
@mock.patch("bio2zarr.tskit.convert")
671671
def test_tskit_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, response):
672-
ts_path = "tests/data/ts/example.trees"
672+
ts_path = "tests/data/tskit/example.trees"
673673
zarr_path = tmp_path / "zarr"
674674
zarr_path.mkdir()
675675
runner = ct.CliRunner()
@@ -691,7 +691,7 @@ def test_tskit_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, respon
691691
@pytest.mark.parametrize("response", ["n", "N", "No"])
692692
@mock.patch("bio2zarr.tskit.convert")
693693
def test_tskit_convert_overwrite_zarr_confirm_no(self, mocked, tmp_path, response):
694-
ts_path = "tests/data/ts/example.trees"
694+
ts_path = "tests/data/tskit/example.trees"
695695
zarr_path = tmp_path / "zarr"
696696
zarr_path.mkdir()
697697
runner = ct.CliRunner()
@@ -708,7 +708,7 @@ def test_tskit_convert_overwrite_zarr_confirm_no(self, mocked, tmp_path, respons
708708
@pytest.mark.parametrize("force_arg", ["-f", "--force"])
709709
@mock.patch("bio2zarr.tskit.convert")
710710
def test_tskit_convert_overwrite_zarr_force(self, mocked, tmp_path, force_arg):
711-
ts_path = "tests/data/ts/example.trees"
711+
ts_path = "tests/data/tskit/example.trees"
712712
zarr_path = tmp_path / "zarr"
713713
zarr_path.mkdir()
714714
runner = ct.CliRunner()
@@ -728,7 +728,7 @@ def test_tskit_convert_overwrite_zarr_force(self, mocked, tmp_path, force_arg):
728728

729729
@mock.patch("bio2zarr.tskit.convert")
730730
def test_tskit_convert_with_options(self, mocked, tmp_path):
731-
ts_path = "tests/data/ts/example.trees"
731+
ts_path = "tests/data/tskit/example.trees"
732732
zarr_path = tmp_path / "zarr"
733733
runner = ct.CliRunner()
734734
result = runner.invoke(
@@ -1028,7 +1028,7 @@ def test_part_size_multiple_vcfs(self):
10281028

10291029
class TestTskitEndToEnd:
10301030
def test_convert(self, tmp_path):
1031-
ts_path = "tests/data/ts/example.trees"
1031+
ts_path = "tests/data/tskit/example.trees"
10321032
zarr_path = tmp_path / "zarr"
10331033
runner = ct.CliRunner()
10341034
result = runner.invoke(

tests/test_tskit.py

Lines changed: 86 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4,42 +4,51 @@
44

55
import numpy as np
66
import pytest
7+
import sgkit as sg
78
import tskit
9+
import xarray.testing as xt
810
import zarr
911

10-
from bio2zarr import tskit as ts
12+
from bio2zarr import tskit as tsk
13+
from bio2zarr import vcf
14+
15+
16+
def simple_ts(add_individuals=False):
17+
tables = tskit.TableCollection(sequence_length=100)
18+
for _ in range(4):
19+
ind = -1
20+
if add_individuals:
21+
ind = tables.individuals.add_row()
22+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=ind)
23+
tables.nodes.add_row(flags=0, time=1) # MRCA for 0,1
24+
tables.nodes.add_row(flags=0, time=1) # MRCA for 2,3
25+
tables.edges.add_row(left=0, right=100, parent=4, child=0)
26+
tables.edges.add_row(left=0, right=100, parent=4, child=1)
27+
tables.edges.add_row(left=0, right=100, parent=5, child=2)
28+
tables.edges.add_row(left=0, right=100, parent=5, child=3)
29+
site_id = tables.sites.add_row(position=10, ancestral_state="A")
30+
tables.mutations.add_row(site=site_id, node=4, derived_state="TTTT")
31+
site_id = tables.sites.add_row(position=20, ancestral_state="CCC")
32+
tables.mutations.add_row(site=site_id, node=5, derived_state="G")
33+
site_id = tables.sites.add_row(position=30, ancestral_state="G")
34+
tables.mutations.add_row(site=site_id, node=0, derived_state="AA")
35+
36+
tables.sort()
37+
return tables.tree_sequence()
1138

1239

1340
class TestTskit:
1441
def test_simple_tree_sequence(self, tmp_path):
15-
tables = tskit.TableCollection(sequence_length=100)
16-
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
17-
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
18-
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
19-
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
20-
tables.nodes.add_row(flags=0, time=1) # MRCA for 0,1
21-
tables.nodes.add_row(flags=0, time=1) # MRCA for 2,3
22-
tables.edges.add_row(left=0, right=100, parent=4, child=0)
23-
tables.edges.add_row(left=0, right=100, parent=4, child=1)
24-
tables.edges.add_row(left=0, right=100, parent=5, child=2)
25-
tables.edges.add_row(left=0, right=100, parent=5, child=3)
26-
site_id = tables.sites.add_row(position=10, ancestral_state="A")
27-
tables.mutations.add_row(site=site_id, node=4, derived_state="TTTT")
28-
site_id = tables.sites.add_row(position=20, ancestral_state="CCC")
29-
tables.mutations.add_row(site=site_id, node=5, derived_state="G")
30-
site_id = tables.sites.add_row(position=30, ancestral_state="G")
31-
tables.mutations.add_row(site=site_id, node=0, derived_state="AA")
32-
tables.sort()
33-
tree_sequence = tables.tree_sequence()
42+
tree_sequence = simple_ts()
3443
tree_sequence.dump(tmp_path / "test.trees")
3544

3645
# Manually specify the individuals_nodes, other tests use
37-
# ts individuals.
46+
# tsk individuals.
3847
ind_nodes = np.array([[0, 1], [2, 3]])
3948

4049
with tempfile.TemporaryDirectory() as tempdir:
4150
zarr_path = os.path.join(tempdir, "test_output.zarr")
42-
ts.convert(
51+
tsk.convert(
4352
tmp_path / "test.trees",
4453
zarr_path,
4554
individuals_nodes=ind_nodes,
@@ -59,7 +68,7 @@ def test_simple_tree_sequence(self, tmp_path):
5968
lengths = zroot["variant_length"][:]
6069
assert lengths.shape == (3,)
6170
assert lengths.dtype == np.int8
62-
assert np.array_equal(lengths, [4, 3, 2])
71+
assert np.array_equal(lengths, [1, 3, 1])
6372

6473
genotypes = zroot["call_genotype"][:]
6574
assert genotypes.shape == (3, 2, 2)
@@ -91,7 +100,7 @@ def test_simple_tree_sequence(self, tmp_path):
91100
region_index = zroot["region_index"][:]
92101
assert region_index.shape == (1, 6)
93102
assert region_index.dtype == np.int8
94-
assert np.array_equal(region_index, [[0, 0, 10, 30, 31, 3]])
103+
assert np.array_equal(region_index, [[0, 0, 10, 30, 30, 3]])
95104

96105
assert set(zroot.array_keys()) == {
97106
"variant_position",
@@ -112,7 +121,7 @@ def test_missing_dependency(self):
112121
side_effect=ImportError("No module named 'tskit'"),
113122
):
114123
with pytest.raises(ImportError) as exc_info:
115-
ts.convert(
124+
tsk.convert(
116125
"UNUSED_PATH",
117126
"UNUSED_PATH",
118127
)
@@ -193,15 +202,15 @@ def test_position_dtype_selection(self, tmp_path):
193202
ts_large.dump(ts_path_large)
194203

195204
ind_nodes = np.array([[0], [1]])
196-
format_obj_small = ts.TskitFormat(ts_path_small, individuals_nodes=ind_nodes)
205+
format_obj_small = tsk.TskitFormat(ts_path_small, individuals_nodes=ind_nodes)
197206
schema_small = format_obj_small.generate_schema()
198207

199208
position_field = next(
200209
f for f in schema_small.fields if f.name == "variant_position"
201210
)
202211
assert position_field.dtype == "i1"
203212

204-
format_obj_large = ts.TskitFormat(ts_path_large, individuals_nodes=ind_nodes)
213+
format_obj_large = tsk.TskitFormat(ts_path_large, individuals_nodes=ind_nodes)
205214
schema_large = format_obj_large.generate_schema()
206215

207216
position_field = next(
@@ -213,14 +222,14 @@ def test_initialization(self, simple_ts):
213222
ts_path, tree_sequence = simple_ts
214223

215224
# Test with default parameters
216-
format_obj = ts.TskitFormat(ts_path)
225+
format_obj = tsk.TskitFormat(ts_path)
217226
assert format_obj.path == ts_path
218227
assert format_obj.ts.num_sites == tree_sequence.num_sites
219228
assert format_obj.contig_id == "1"
220229
assert not format_obj.isolated_as_missing
221230

222231
# Test with custom parameters
223-
format_obj = ts.TskitFormat(
232+
format_obj = tsk.TskitFormat(
224233
ts_path,
225234
sample_ids=["ind1", "ind2"],
226235
contig_id="chr1",
@@ -234,7 +243,7 @@ def test_initialization(self, simple_ts):
234243

235244
def test_basic_properties(self, simple_ts):
236245
ts_path, _ = simple_ts
237-
format_obj = ts.TskitFormat(ts_path)
246+
format_obj = tsk.TskitFormat(ts_path)
238247

239248
assert format_obj.num_records == format_obj.ts.num_sites
240249
assert format_obj.num_samples == 2 # Two individuals
@@ -251,7 +260,7 @@ def test_basic_properties(self, simple_ts):
251260
def test_custom_sample_ids(self, simple_ts):
252261
ts_path, _ = simple_ts
253262
custom_ids = ["sample_X", "sample_Y"]
254-
format_obj = ts.TskitFormat(ts_path, sample_ids=custom_ids)
263+
format_obj = tsk.TskitFormat(ts_path, sample_ids=custom_ids)
255264

256265
assert format_obj.num_samples == 2
257266
assert len(format_obj.samples) == 2
@@ -262,11 +271,11 @@ def test_sample_id_length_mismatch(self, simple_ts):
262271
ts_path, _ = simple_ts
263272
# Wrong number of sample IDs
264273
with pytest.raises(ValueError, match="Length of sample_ids.*does not match"):
265-
ts.TskitFormat(ts_path, sample_ids=["only_one_id"])
274+
tsk.TskitFormat(ts_path, sample_ids=["only_one_id"])
266275

267276
def test_schema_generation(self, simple_ts):
268277
ts_path, _ = simple_ts
269-
format_obj = ts.TskitFormat(ts_path)
278+
format_obj = tsk.TskitFormat(ts_path)
270279

271280
schema = format_obj.generate_schema()
272281
assert schema.dimensions["variants"].size == 3
@@ -289,13 +298,13 @@ def test_schema_generation(self, simple_ts):
289298

290299
def test_iter_contig(self, simple_ts):
291300
ts_path, _ = simple_ts
292-
format_obj = ts.TskitFormat(ts_path)
301+
format_obj = tsk.TskitFormat(ts_path)
293302
contig_indices = list(format_obj.iter_contig(1, 3))
294303
assert contig_indices == [0, 0]
295304

296305
def test_iter_field(self, simple_ts):
297306
ts_path, _ = simple_ts
298-
format_obj = ts.TskitFormat(ts_path)
307+
format_obj = tsk.TskitFormat(ts_path)
299308
positions = list(format_obj.iter_field("position", None, 0, 3))
300309
assert positions == [10, 20, 30]
301310
positions = list(format_obj.iter_field("position", None, 1, 3))
@@ -341,7 +350,7 @@ def test_iter_field(self, simple_ts):
341350
def test_iter_alleles_and_genotypes(self, simple_ts, ind_nodes, expected_gts):
342351
ts_path, _ = simple_ts
343352

344-
format_obj = ts.TskitFormat(ts_path, individuals_nodes=ind_nodes)
353+
format_obj = tsk.TskitFormat(ts_path, individuals_nodes=ind_nodes)
345354

346355
shape = (2, 2) # (num_samples, max_ploidy)
347356
results = list(format_obj.iter_alleles_and_genotypes(0, 3, shape, 2))
@@ -350,7 +359,7 @@ def test_iter_alleles_and_genotypes(self, simple_ts, ind_nodes, expected_gts):
350359

351360
for i, variant_data in enumerate(results):
352361
if i == 0:
353-
assert variant_data.variant_length == 2
362+
assert variant_data.variant_length == 1
354363
assert np.array_equal(variant_data.alleles, ("A", "TT"))
355364
elif i == 1:
356365
assert variant_data.variant_length == 3
@@ -371,7 +380,7 @@ def test_iter_alleles_and_genotypes_errors(self, simple_ts):
371380

372381
# Test with node ID that doesn't exist in tree sequence (out of range)
373382
invalid_nodes = np.array([[10, 11], [12, 13]], dtype=np.int32)
374-
format_obj = ts.TskitFormat(ts_path, individuals_nodes=invalid_nodes)
383+
format_obj = tsk.TskitFormat(ts_path, individuals_nodes=invalid_nodes)
375384
shape = (2, 2)
376385
with pytest.raises(
377386
tskit.LibraryError, match="out of bounds"
@@ -383,23 +392,23 @@ def test_iter_alleles_and_genotypes_errors(self, simple_ts):
383392
with pytest.raises(
384393
ValueError, match="individuals_nodes must have at least one sample"
385394
):
386-
format_obj = ts.TskitFormat(ts_path, individuals_nodes=empty_nodes)
395+
format_obj = tsk.TskitFormat(ts_path, individuals_nodes=empty_nodes)
387396

388397
# Test with all invalid nodes (-1)
389398
all_invalid = np.full((2, 2), -1, dtype=np.int32)
390399
with pytest.raises(
391400
ValueError, match="individuals_nodes must have at least one valid sample"
392401
):
393-
format_obj = ts.TskitFormat(ts_path, individuals_nodes=all_invalid)
402+
format_obj = tsk.TskitFormat(ts_path, individuals_nodes=all_invalid)
394403

395404
def test_isolated_as_missing(self, tmp_path):
396-
def insert_branch_sites(ts, m=1):
405+
def insert_branch_sites(tsk, m=1):
397406
if m == 0:
398-
return ts
399-
tables = ts.dump_tables()
407+
return tsk
408+
tables = tsk.dump_tables()
400409
tables.sites.clear()
401410
tables.mutations.clear()
402-
for tree in ts.trees():
411+
for tree in tsk.trees():
403412
left, right = tree.interval
404413
delta = (right - left) / (m * len(list(tree.nodes())))
405414
x = left
@@ -422,7 +431,7 @@ def insert_branch_sites(ts, m=1):
422431
ts_path = tmp_path / "isolated_sample.trees"
423432
tree_sequence.dump(ts_path)
424433
ind_nodes = np.array([[0], [1], [3]])
425-
format_obj_default = ts.TskitFormat(
434+
format_obj_default = tsk.TskitFormat(
426435
ts_path, individuals_nodes=ind_nodes, isolated_as_missing=False
427436
)
428437
shape = (3, 1) # (num_samples, max_ploidy)
@@ -438,7 +447,7 @@ def insert_branch_sites(ts, m=1):
438447
expected_gt_default = np.array([[1], [0], [0]])
439448
assert np.array_equal(variant_data_default.genotypes, expected_gt_default)
440449

441-
format_obj_missing = ts.TskitFormat(
450+
format_obj_missing = tsk.TskitFormat(
442451
ts_path, individuals_nodes=ind_nodes, isolated_as_missing=True
443452
)
444453
results_missing = list(
@@ -469,7 +478,7 @@ def test_genotype_dtype_selection(self, tmp_path):
469478
tree_sequence.dump(ts_path)
470479

471480
ind_nodes = np.array([[0, 1], [2, 3]])
472-
format_obj = ts.TskitFormat(ts_path, individuals_nodes=ind_nodes)
481+
format_obj = tsk.TskitFormat(ts_path, individuals_nodes=ind_nodes)
473482
schema = format_obj.generate_schema()
474483
call_genotype_spec = next(s for s in schema.fields if s.name == "call_genotype")
475484
assert call_genotype_spec.dtype == "i1"
@@ -489,7 +498,36 @@ def test_genotype_dtype_selection(self, tmp_path):
489498
ts_path = tmp_path / "large_alleles.trees"
490499
tree_sequence.dump(ts_path)
491500

492-
format_obj = ts.TskitFormat(ts_path, individuals_nodes=ind_nodes)
501+
format_obj = tsk.TskitFormat(ts_path, individuals_nodes=ind_nodes)
493502
schema = format_obj.generate_schema()
494503
call_genotype_spec = next(s for s in schema.fields if s.name == "call_genotype")
495504
assert call_genotype_spec.dtype == "i4"
505+
506+
507+
@pytest.mark.parametrize(
508+
"ts",
509+
[
510+
simple_ts(add_individuals=True),
511+
],
512+
)
513+
def test_against_tskit_vcf_output(ts, tmp_path):
514+
vcf_path = tmp_path / "ts.vcf"
515+
ts_path = tmp_path / "ts.trees"
516+
ts.dump(ts_path)
517+
with open(vcf_path, "w") as f:
518+
ts.write_vcf(f)
519+
520+
tskit_zarr = tmp_path / "tskit.zarr"
521+
vcf_zarr = tmp_path / "vcf.zarr"
522+
tsk.convert(ts_path, tskit_zarr)
523+
524+
vcf.convert([vcf_path], vcf_zarr)
525+
ds1 = sg.load_dataset(tskit_zarr)
526+
ds2 = (
527+
sg.load_dataset(vcf_zarr)
528+
.drop_dims("filters")
529+
.drop_vars(
530+
["variant_id", "variant_id_mask", "variant_quality", "contig_length"]
531+
)
532+
)
533+
xt.assert_equal(ds1, ds2)

0 commit comments

Comments
 (0)