Skip to content

Commit b1f33d3

Browse files
Bump up test coverage
Test cleaning
1 parent 3c22c35 commit b1f33d3

File tree

2 files changed

+51
-38
lines changed

2 files changed

+51
-38
lines changed

bio2zarr/tskit.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,7 @@ def __init__(
4545
if self._num_samples < 1:
4646
raise ValueError("individuals_nodes must have at least one sample")
4747
self.max_ploidy = individuals_nodes.shape[1]
48-
if sample_ids is None:
49-
sample_ids = [f"tsk_{j}" for j in range(self._num_samples)]
50-
elif len(sample_ids) != self._num_samples:
48+
if len(sample_ids) != self._num_samples:
5149
raise ValueError(
5250
f"Length of sample_ids ({len(sample_ids)}) does not match "
5351
f"number of samples ({self._num_samples})"

tests/test_tskit.py

Lines changed: 50 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ def test_missing_dependency():
2929
)
3030

3131

32+
def tskit_model_mapping(ind_nodes, ind_names=None):
33+
if ind_names is None:
34+
ind_names = ["tsk{j}" for j in range(len(ind_nodes))]
35+
return tskit.VcfModelMapping(ind_nodes, ind_names)
36+
37+
3238
def add_mutations(ts):
3339
# Add some mutations to the tree sequence. This guarantees that
3440
# we have variation at all sites > 0.
@@ -88,15 +94,6 @@ def insert_branch_sites(ts, m=1):
8894
return tables.tree_sequence()
8995

9096

91-
@pytest.fixture()
92-
def fx_ts_isolated_samples():
93-
tables = tskit.Tree.generate_balanced(2, span=10).tree_sequence.dump_tables()
94-
# This also tests sample nodes that are not a single block at
95-
# the start of the nodes table.
96-
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE)
97-
return insert_branch_sites(tables.tree_sequence())
98-
99-
10097
class TestSimpleTs:
10198
@pytest.fixture()
10299
def conversion(self, tmp_path):
@@ -193,17 +190,28 @@ class TestTskitFormat:
193190
"""Unit tests for TskitFormat without using full conversion."""
194191

195192
@pytest.fixture()
196-
def fx_simple_ts(self, tmp_path):
193+
def fx_simple_ts(self):
197194
return simple_ts(add_individuals=True)
198195

199196
@pytest.fixture()
200-
def fx_ts_2_diploids(self, tmp_path):
197+
def fx_ts_2_diploids(self):
201198
ts = msprime.sim_ancestry(2, sequence_length=10, random_seed=42)
202199
return add_mutations(ts)
203200

204201
@pytest.fixture()
205-
def fx_no_individuals_ts(self, tmp_path):
206-
return simple_ts(add_individuals=False)
202+
def fx_ts_isolated_samples(self):
203+
tables = tskit.Tree.generate_balanced(2, span=10).tree_sequence.dump_tables()
204+
# This also tests sample nodes that are not a single block at
205+
# the start of the nodes table.
206+
tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE)
207+
return insert_branch_sites(tables.tree_sequence())
208+
209+
def test_path_or_ts_input(self, tmp_path, fx_simple_ts):
210+
f1 = tsk.TskitFormat(fx_simple_ts)
211+
ts_path = tmp_path / "trees.ts"
212+
fx_simple_ts.dump(ts_path)
213+
f2 = tsk.TskitFormat(ts_path)
214+
f1.ts.tables.assert_equals(f2.ts.tables)
207215

208216
def test_small_position_dtype(self):
209217
tables = tskit.TableCollection(sequence_length=100)
@@ -311,6 +319,23 @@ def test_iter_field(self, fx_simple_ts):
311319
with pytest.raises(ValueError, match="Unknown field"):
312320
list(format_obj.iter_field("unknown_field", None, 0, 3))
313321

322+
def test_zero_samples(self, fx_simple_ts):
323+
model_mapping = tskit_model_mapping(np.array([]))
324+
with pytest.raises(ValueError, match="at least one sample"):
325+
tsk.TskitFormat(fx_simple_ts, model_mapping=model_mapping)
326+
327+
def test_no_valid_samples(self, fx_simple_ts):
328+
model_mapping = fx_simple_ts.map_to_vcf_model()
329+
model_mapping.individuals_nodes[:] = -1
330+
with pytest.raises(ValueError, match="at least one valid sample"):
331+
tsk.TskitFormat(fx_simple_ts, model_mapping=model_mapping)
332+
333+
def test_model_size_mismatch(self, fx_simple_ts):
334+
model_mapping = fx_simple_ts.map_to_vcf_model()
335+
model_mapping.individuals_name = ["x"]
336+
with pytest.raises(ValueError, match="match number of samples"):
337+
tsk.TskitFormat(fx_simple_ts, model_mapping=model_mapping)
338+
314339
@pytest.mark.parametrize(
315340
("ind_nodes", "expected_gts"),
316341
[
@@ -347,10 +372,7 @@ def test_iter_field(self, fx_simple_ts):
347372
],
348373
)
349374
def test_iter_alleles_and_genotypes(self, fx_simple_ts, ind_nodes, expected_gts):
350-
model_mapping = tskit.VcfModelMapping(
351-
ind_nodes, ["tsk{j}" for j in range(len(ind_nodes))]
352-
)
353-
375+
model_mapping = tskit_model_mapping(ind_nodes)
354376
format_obj = tsk.TskitFormat(fx_simple_ts, model_mapping=model_mapping)
355377

356378
shape = (2, 2) # (num_samples, max_ploidy)
@@ -375,9 +397,7 @@ def test_iter_alleles_and_genotypes(self, fx_simple_ts, ind_nodes, expected_gts)
375397
def test_iter_alleles_and_genotypes_missing_node(self, fx_ts_2_diploids):
376398
# Test with node ID that doesn't exist in tree sequence (out of range)
377399
ind_nodes = np.array([[10, 11], [12, 13]], dtype=np.int32)
378-
model_mapping = tskit.VcfModelMapping(
379-
ind_nodes, ["tsk{j}" for j in range(len(ind_nodes))]
380-
)
400+
model_mapping = tskit_model_mapping(ind_nodes)
381401
format_obj = tsk.TskitFormat(fx_ts_2_diploids, model_mapping=model_mapping)
382402
shape = (2, 2)
383403
with pytest.raises(
@@ -387,9 +407,7 @@ def test_iter_alleles_and_genotypes_missing_node(self, fx_ts_2_diploids):
387407

388408
def test_isolated_as_missing(self, fx_ts_isolated_samples):
389409
ind_nodes = np.array([[0], [1], [3]])
390-
model_mapping = tskit.VcfModelMapping(
391-
ind_nodes, ["tsk{j}" for j in range(len(ind_nodes))]
392-
)
410+
model_mapping = tskit_model_mapping(ind_nodes)
393411

394412
format_obj_default = tsk.TskitFormat(
395413
fx_ts_isolated_samples,
@@ -427,7 +445,7 @@ def test_isolated_as_missing(self, fx_ts_isolated_samples):
427445
expected_gt_missing = np.array([[1], [0], [-1]])
428446
nt.assert_array_equal(variant_data_missing.genotypes, expected_gt_missing)
429447

430-
def test_genotype_dtype_i1(self, tmp_path):
448+
def test_genotype_dtype_i1(self):
431449
tables = tskit.TableCollection(sequence_length=100)
432450
for _ in range(4):
433451
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
@@ -438,15 +456,13 @@ def test_genotype_dtype_i1(self, tmp_path):
438456
tables.mutations.add_row(site=site_id, node=0, derived_state="T")
439457
tables.sort()
440458
tree_sequence = tables.tree_sequence()
441-
ts_path = tmp_path / "small_alleles.trees"
442-
tree_sequence.dump(ts_path)
443459

444-
format_obj = tsk.TskitFormat(ts_path)
460+
format_obj = tsk.TskitFormat(tree_sequence)
445461
schema = format_obj.generate_schema()
446462
call_genotype_spec = next(s for s in schema.fields if s.name == "call_genotype")
447463
assert call_genotype_spec.dtype == "i1"
448464

449-
def test_genotype_dtype_i4(self, tmp_path):
465+
def test_genotype_dtype_i4(self):
450466
tables = tskit.TableCollection(sequence_length=100)
451467
for _ in range(4):
452468
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
@@ -459,10 +475,8 @@ def test_genotype_dtype_i4(self, tmp_path):
459475

460476
tables.sort()
461477
tree_sequence = tables.tree_sequence()
462-
ts_path = tmp_path / "large_alleles.trees"
463-
tree_sequence.dump(ts_path)
464478

465-
format_obj = tsk.TskitFormat(ts_path)
479+
format_obj = tsk.TskitFormat(tree_sequence)
466480
schema = format_obj.generate_schema()
467481
call_genotype_spec = next(s for s in schema.fields if s.name == "call_genotype")
468482
assert call_genotype_spec.dtype == "i4"
@@ -471,6 +485,7 @@ def test_genotype_dtype_i4(self, tmp_path):
471485
@pytest.mark.parametrize(
472486
"ts",
473487
[
488+
# Standard individuals-with-a-given-ploidy situation
474489
add_mutations(
475490
msprime.sim_ancestry(4, ploidy=1, sequence_length=10, random_seed=42)
476491
),
@@ -480,20 +495,20 @@ def test_genotype_dtype_i4(self, tmp_path):
480495
add_mutations(
481496
msprime.sim_ancestry(3, ploidy=12, sequence_length=10, random_seed=142)
482497
),
498+
# No individuals, ploidy 1
499+
add_mutations(msprime.simulate(4, length=10, random_seed=412)),
483500
],
484501
)
485502
def test_against_tskit_vcf_output(ts, tmp_path):
486503
vcf_path = tmp_path / "ts.vcf"
487-
ts_path = tmp_path / "ts.trees"
488-
ts.dump(ts_path)
489504
with open(vcf_path, "w") as f:
490505
ts.write_vcf(f)
491506

492507
tskit_zarr = tmp_path / "tskit.zarr"
493508
vcf_zarr = tmp_path / "vcf.zarr"
494-
tsk.convert(ts_path, tskit_zarr)
509+
tsk.convert(ts, tskit_zarr, worker_processes=0)
495510

496-
vcf.convert([vcf_path], vcf_zarr)
511+
vcf.convert([vcf_path], vcf_zarr, worker_processes=0)
497512
ds1 = sg.load_dataset(tskit_zarr)
498513
ds2 = (
499514
sg.load_dataset(vcf_zarr)

0 commit comments

Comments
 (0)