diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index da2c4d98..46002159 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -94,11 +94,11 @@ jobs: python -m venv env-tskit source env-tskit/bin/activate python -m pip install . - python -m bio2zarr tskit2zarr convert tests/data/ts/example.trees ts.vcz > ts.txt 2>&1 || echo $? > ts_exit.txt + python -m bio2zarr tskit2zarr convert tests/data/tskit/example.trees ts.vcz > ts.txt 2>&1 || echo $? > ts_exit.txt test "$(cat ts_exit.txt)" = "1" grep -q "This process requires the optional tskit module. Install it with: pip install bio2zarr\[tskit\]" ts.txt python -m pip install '.[tskit]' - python -m bio2zarr tskit2zarr convert tests/data/ts/example.trees ts.vcz + python -m bio2zarr tskit2zarr convert tests/data/tskit/example.trees ts.vcz deactivate python -m venv env-vcf diff --git a/CHANGELOG.md b/CHANGELOG.md index ab7fae43..0909fd68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,9 @@ - Make format-specific dependencies optional (#385) +- Change default number of worker processes to zero (#404) to simplify + debugging + Breaking changes - Remove explicit sample, contig and filter lists from the schema. diff --git a/bio2zarr/cli.py b/bio2zarr/cli.py index 6d41ed27..e8dade34 100644 --- a/bio2zarr/cli.py +++ b/bio2zarr/cli.py @@ -8,7 +8,7 @@ import numcodecs import tabulate -from . import plink, provenance, vcf_utils +from . import core, plink, provenance, vcf_utils from . import tskit as tskit_mod from . 
import vcf as vcf_mod @@ -89,7 +89,12 @@ def list_commands(self, ctx): version = click.version_option(version=f"{provenance.__version__}") worker_processes = click.option( - "-p", "--worker-processes", type=int, default=1, help="Number of worker processes" + "-p", + "--worker-processes", + type=int, + default=core.DEFAULT_WORKER_PROCESSES, + help="Number of worker processes", + show_default=True, ) column_chunk_size = click.option( diff --git a/bio2zarr/core.py b/bio2zarr/core.py index 723dd2b7..2ab2d848 100644 --- a/bio2zarr/core.py +++ b/bio2zarr/core.py @@ -130,12 +130,20 @@ def du(path): return total +# We set the default number of worker processes to 0 because it avoids +# complexity in the call chain and makes things easier to debug by +# default. However, it does use the SynchronousExecutor here, which +# is technically not recommended by the Python docs. +DEFAULT_WORKER_PROCESSES = 0 + + class SynchronousExecutor(cf.Executor): - # Arguably we should use workers=0 as the default and use this + # Since https://github.com/sgkit-dev/bio2zarr/issues/404 we + # set worker_processes=0 as the default and use this # executor implementation. However, the docs are fairly explicit # about saying we shouldn't instantiate Future objects directly, - # so it's best to keep this as a semi-secret debugging interface - # for now. + # so we may need to revisit this if obscure problems start to + # arise. 
def submit(self, fn, /, *args, **kwargs): future = cf.Future() future.set_result(fn(*args, **kwargs)) diff --git a/bio2zarr/plink.py b/bio2zarr/plink.py index 6ae2b873..500bd422 100644 --- a/bio2zarr/plink.py +++ b/bio2zarr/plink.py @@ -291,7 +291,7 @@ def convert( *, variants_chunk_size=None, samples_chunk_size=None, - worker_processes=1, + worker_processes=core.DEFAULT_WORKER_PROCESSES, show_progress=False, ): plink_format = PlinkFormat(prefix) diff --git a/bio2zarr/tskit.py b/bio2zarr/tskit.py index 2e442461..7a632772 100644 --- a/bio2zarr/tskit.py +++ b/bio2zarr/tskit.py @@ -12,31 +12,46 @@ class TskitFormat(vcz.Source): @core.requires_optional_dependency("tskit", "tskit") def __init__( self, - ts_path, - individuals_nodes=None, - sample_ids=None, + ts, + *, + model_mapping=None, contig_id=None, isolated_as_missing=False, ): import tskit - self._path = ts_path - self.ts = tskit.load(ts_path) + self._path = None + # Future versions here will need to deal with the complexities of + # having lists of tree sequences for multiple chromosomes. + if isinstance(ts, tskit.TreeSequence): + self.ts = ts + else: + # input 'ts' is a path. 
+ self._path = ts + logger.info(f"Loading from {ts}") + self.ts = tskit.load(ts) + logger.info( + f"Input has {self.ts.num_individuals} individuals and " + f"{self.ts.num_sites} sites" + ) + self.contig_id = contig_id if contig_id is not None else "1" self.isolated_as_missing = isolated_as_missing self.positions = self.ts.sites_position - if individuals_nodes is None: - individuals_nodes = self.ts.individuals_nodes + if model_mapping is None: + model_mapping = self.ts.map_to_vcf_model() + + individuals_nodes = model_mapping.individuals_nodes + sample_ids = model_mapping.individuals_name self._num_samples = individuals_nodes.shape[0] + logger.info(f"Converting for {self._num_samples} samples") if self._num_samples < 1: raise ValueError("individuals_nodes must have at least one sample") self.max_ploidy = individuals_nodes.shape[1] - if sample_ids is None: - sample_ids = [f"tsk_{j}" for j in range(self._num_samples)] - elif len(sample_ids) != self._num_samples: + if len(sample_ids) != self._num_samples: raise ValueError( f"Length of sample_ids ({len(sample_ids)}) does not match " f"number of samples ({self._num_samples})" @@ -91,6 +106,7 @@ def iter_field(self, field_name, shape, start, stop): def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles): # All genotypes in tskit are considered phased phased = np.ones(shape[:-1], dtype=bool) + logger.debug(f"Getting genotpes start={start} stop={stop}") for variant in self.ts.variants( isolated_as_missing=self.isolated_as_missing, @@ -101,14 +117,15 @@ def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles): ): gt = np.full(shape, constants.INT_FILL, dtype=np.int8) alleles = np.full(num_alleles, constants.STR_FILL, dtype="O") - variant_length = 0 + # length is the length of the REF allele unless other fields + # are included. 
+ variant_length = len(variant.alleles[0]) for i, allele in enumerate(variant.alleles): # None is returned by tskit in the case of a missing allele if allele is None: continue assert i < num_alleles alleles[i] = allele - variant_length = max(variant_length, len(allele)) gt[self.sample_indices, self.ploidy_indices] = variant.genotypes[ self.genotype_indices ] @@ -231,22 +248,35 @@ def generate_schema( def convert( - ts_path, - zarr_path, + ts_or_path, + vcz_path, *, - individuals_nodes=None, - sample_ids=None, + model_mapping=None, contig_id=None, isolated_as_missing=False, variants_chunk_size=None, samples_chunk_size=None, - worker_processes=1, + worker_processes=core.DEFAULT_WORKER_PROCESSES, show_progress=False, ): + """ + Convert a :class:`tskit.TreeSequence` (or path to a tree sequence + file) to VCF Zarr format stored at the specified path. + + .. todo:: Document parameters + """ + # FIXME there are some tricky details here in how we're handling + # parallelism that we'll need to tackle properly, and maybe + # review the current structures a bit. Basically, it looks like + # we're pickling/unpickling the format object when we have + # multiple workers, and this results in several copies of the + # tree sequence object being passed around. This is fine most + # of the time, but results in lots of memory being used when + # we're dealing with really massive files. 
+ # See https://github.com/sgkit-dev/bio2zarr/issues/403 tskit_format = TskitFormat( - ts_path, - individuals_nodes=individuals_nodes, - sample_ids=sample_ids, + ts_or_path, + model_mapping=model_mapping, contig_id=contig_id, isolated_as_missing=isolated_as_missing, ) @@ -254,7 +284,7 @@ def convert( variants_chunk_size=variants_chunk_size, samples_chunk_size=samples_chunk_size, ) - zarr_path = pathlib.Path(zarr_path) + zarr_path = pathlib.Path(vcz_path) vzw = vcz.VcfZarrWriter(TskitFormat, zarr_path) # Rough heuristic to split work up enough to keep utilisation high target_num_partitions = max(1, worker_processes * 4) diff --git a/bio2zarr/vcf.py b/bio2zarr/vcf.py index c64023af..29654c73 100644 --- a/bio2zarr/vcf.py +++ b/bio2zarr/vcf.py @@ -285,7 +285,12 @@ def scan_vcf(path, target_num_partitions): return metadata, vcf.raw_header -def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1): +def scan_vcfs( + paths, + show_progress, + target_num_partitions, + worker_processes=core.DEFAULT_WORKER_PROCESSES, +): logger.info( f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}" f" partitions." 
@@ -1051,6 +1056,10 @@ def iter_genotypes(self, shape, start, stop): phased = value[:, -1] if value is not None else None sanitised_genotypes = sanitise_value_int_2d(shape, genotypes) sanitised_phased = sanitise_value_int_1d(shape[:-1], phased) + # Force haploids to always be phased + # https://github.com/sgkit-dev/bio2zarr/issues/399 + if sanitised_genotypes.shape[1] == 1: + sanitised_phased[:] = True yield sanitised_genotypes, sanitised_phased def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles): @@ -1294,7 +1303,7 @@ def init( vcfs, *, column_chunk_size=16, - worker_processes=1, + worker_processes=core.DEFAULT_WORKER_PROCESSES, target_num_partitions=None, show_progress=False, compressor=None, @@ -1446,7 +1455,9 @@ def process_partition(self, partition_index): f"{num_records} records last_pos={last_position}" ) - def explode(self, *, worker_processes=1, show_progress=False): + def explode( + self, *, worker_processes=core.DEFAULT_WORKER_PROCESSES, show_progress=False + ): self.load_metadata() num_records = self.metadata.num_records if np.isinf(num_records): @@ -1514,7 +1525,7 @@ def explode( vcfs, *, column_chunk_size=16, - worker_processes=1, + worker_processes=core.DEFAULT_WORKER_PROCESSES, show_progress=False, compressor=None, ): @@ -1539,7 +1550,7 @@ def explode_init( *, column_chunk_size=16, target_num_partitions=1, - worker_processes=1, + worker_processes=core.DEFAULT_WORKER_PROCESSES, show_progress=False, compressor=None, ): @@ -1601,7 +1612,7 @@ def convert( *, variants_chunk_size=None, samples_chunk_size=None, - worker_processes=1, + worker_processes=core.DEFAULT_WORKER_PROCESSES, local_alleles=None, show_progress=False, icf_path=None, @@ -1645,7 +1656,7 @@ def encode( dimension_separator=None, max_memory=None, local_alleles=None, - worker_processes=1, + worker_processes=core.DEFAULT_WORKER_PROCESSES, show_progress=False, ): # Rough heuristic to split work up enough to keep utilisation high @@ -1683,7 +1694,7 @@ def encode_init( 
max_variant_chunks=None, dimension_separator=None, max_memory=None, - worker_processes=1, + worker_processes=core.DEFAULT_WORKER_PROCESSES, show_progress=False, ): icf_store = IntermediateColumnarFormat(icf_path) diff --git a/docs/_config.yml b/docs/_config.yml index 053ce11d..433a594c 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -24,7 +24,7 @@ html: extra_footer: |
Documentation available under the terms of the - CC0 1.0 + CC0 1.0 license.
@@ -32,6 +32,7 @@ sphinx: extra_extensions: - sphinx_click.ext - sphinx.ext.todo + - sphinx.ext.autodoc config: html_show_copyright: false # This is needed to make sure that text is output in single block from @@ -40,3 +41,6 @@ sphinx: todo_include_todos: true myst_enable_extensions: - colon_fence + intersphinx_mapping: + python: ["https://docs.python.org/3/", null] + tskit: ["https://tskit.dev/tskit/docs/stable", null] diff --git a/docs/_toc.yml b/docs/_toc.yml index 007cfea5..718b3512 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -9,6 +9,10 @@ chapters: - file: plink2zarr/overview sections: - file: plink2zarr/cli_ref +- file: tskit2zarr/overview + sections: + - file: tskit2zarr/python_api + - file: tskit2zarr/cli_ref - file: vcfpartition/overview sections: - file: vcfpartition/cli_ref diff --git a/docs/plink2zarr/cli_ref.md b/docs/plink2zarr/cli_ref.md index 2e13db25..a01ad1c3 100644 --- a/docs/plink2zarr/cli_ref.md +++ b/docs/plink2zarr/cli_ref.md @@ -14,4 +14,4 @@ .. click:: bio2zarr.cli:convert_plink :prog: plink2zarr convert :nested: full - +``` diff --git a/docs/tskit2zarr/cli_ref.md b/docs/tskit2zarr/cli_ref.md new file mode 100644 index 00000000..9e478809 --- /dev/null +++ b/docs/tskit2zarr/cli_ref.md @@ -0,0 +1,18 @@ +(sec-tskit2zarr-cli-ref)= +# CLI Reference + +% A note on cross references... There's some weird long-standing problem with +% cross referencing program values in Sphinx, which means that we can't use +% the built-in labels generated by sphinx-click. We can make our own explicit +% targets, but these have to have slightly weird names to avoid conflicting +% with what sphinx-click is doing. So, hence the cmd- prefix. +% Based on: https://github.com/skypilot-org/skypilot/pull/2834 + +```{eval-rst} + +.. _cmd-tskit2zarr-convert: +.. 
click:: bio2zarr.cli:convert_tskit + :prog: tskit2zarr convert + :nested: full + +``` diff --git a/docs/tskit2zarr/overview.md b/docs/tskit2zarr/overview.md new file mode 100644 index 00000000..a808dd73 --- /dev/null +++ b/docs/tskit2zarr/overview.md @@ -0,0 +1,10 @@ +(sec-tskit2zarr)= +# tskit2zarr + +Convert tskit data to the +[VCF Zarr specification](https://github.com/sgkit-dev/vcf-zarr-spec/) +reliably in parallel. + +See {ref}`sec-tskit2zarr-cli-ref` for detailed documentation on +command line options. + diff --git a/docs/tskit2zarr/python_api.md b/docs/tskit2zarr/python_api.md new file mode 100644 index 00000000..31979af5 --- /dev/null +++ b/docs/tskit2zarr/python_api.md @@ -0,0 +1,37 @@ +(sec-tskit2zarr-python-api)= +# Python API + +Basic usage: +```python +import bio2zarr.tskit as ts2z + +ts2z.convert(ts_path, vcz_path, worker_processes=8) +``` + +This will convert the [tskit](https://tskit.dev) tree sequence stored +at ``ts_path`` to VCF Zarr stored at ``vcz_path`` using 8 worker processes. +The details of how we map from the +tskit {ref}`tskit:sec_data_model` to VCF Zarr are taken care of by +{meth}`tskit.TreeSequence.map_to_vcf_model` +method, which is called with no +parameters by default if the ``model_mapping`` parameter to +{func}`~bio2zarr.tskit.convert` is not specified. + +For more control over the properties of the output, for example +to pick a specific subset of individuals, you can use +{meth}`~tskit.TreeSequence.map_to_vcf_model` +to return the required mapping: + +```python +model_mapping = ts.map_to_vcf_model(individuals=[0, 1]) +ts2z.convert(ts, vcz_path, model_mapping=model_mapping) +``` + + +## API reference + +```{eval-rst} + +.. 
autofunction:: bio2zarr.tskit.convert + +``` diff --git a/pyproject.toml b/pyproject.toml index 4831cf35..0ba72320 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,17 +64,16 @@ dev = [ "pytest-xdist", "sgkit>=0.8.0", "tqdm", - "tskit @ git+https://github.com/tskit-dev/tskit.git@main#subdirectory=python", + "tskit>=0.6.4", "bed_reader", "cyvcf2" ] -# TODO Using dev version of tskit for CI, FIXME before release -tskit = ["tskit @ git+https://github.com/tskit-dev/tskit.git@main#subdirectory=python"] +tskit = ["tskit>=0.6.4"] vcf = ["cyvcf2"] all = [ - "tskit @ git+https://github.com/tskit-dev/tskit.git@main#subdirectory=python", + "tskit>=0.6.4", "cyvcf2" - ] +] [tool.setuptools] diff --git a/tests/data/ts/example.trees b/tests/data/ts/example.trees deleted file mode 100644 index 4910ec22..00000000 Binary files a/tests/data/ts/example.trees and /dev/null differ diff --git a/tests/data/tskit/example.trees b/tests/data/tskit/example.trees new file mode 100644 index 00000000..66e48465 Binary files /dev/null and b/tests/data/tskit/example.trees differ diff --git a/tests/test_cli.py b/tests/test_cli.py index 841db800..5113f356 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -7,19 +7,19 @@ import pytest from bio2zarr import __main__ as main -from bio2zarr import cli, provenance +from bio2zarr import cli, core, provenance DEFAULT_EXPLODE_ARGS = dict( column_chunk_size=64, compressor=None, - worker_processes=1, + worker_processes=core.DEFAULT_WORKER_PROCESSES, show_progress=True, ) DEFAULT_DEXPLODE_PARTITION_ARGS = dict() DEFAULT_DEXPLODE_INIT_ARGS = dict( - worker_processes=1, + worker_processes=core.DEFAULT_WORKER_PROCESSES, column_chunk_size=64, compressor=None, show_progress=True, @@ -30,7 +30,7 @@ variants_chunk_size=None, samples_chunk_size=None, max_variant_chunks=None, - worker_processes=1, + worker_processes=core.DEFAULT_WORKER_PROCESSES, max_memory=None, show_progress=True, ) @@ -57,7 +57,7 @@ variants_chunk_size=None, samples_chunk_size=None, 
show_progress=True, - worker_processes=1, + worker_processes=core.DEFAULT_WORKER_PROCESSES, local_alleles=False, ) @@ -67,14 +67,14 @@ variants_chunk_size=None, samples_chunk_size=None, show_progress=True, - worker_processes=1, + worker_processes=core.DEFAULT_WORKER_PROCESSES, ) DEFAULT_PLINK_CONVERT_ARGS = dict( variants_chunk_size=None, samples_chunk_size=None, show_progress=True, - worker_processes=1, + worker_processes=core.DEFAULT_WORKER_PROCESSES, ) @@ -647,7 +647,7 @@ def test_vcf_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, response @pytest.mark.parametrize(("progress", "flag"), [(True, "-P"), (False, "-Q")]) @mock.patch("bio2zarr.tskit.convert") def test_convert_tskit(self, mocked, tmp_path, progress, flag): - ts_path = "tests/data/ts/example.trees" + ts_path = "tests/data/tskit/example.trees" zarr_path = tmp_path / "zarr" runner = ct.CliRunner() result = runner.invoke( @@ -669,7 +669,7 @@ def test_convert_tskit(self, mocked, tmp_path, progress, flag): @pytest.mark.parametrize("response", ["y", "Y", "yes"]) @mock.patch("bio2zarr.tskit.convert") def test_tskit_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, response): - ts_path = "tests/data/ts/example.trees" + ts_path = "tests/data/tskit/example.trees" zarr_path = tmp_path / "zarr" zarr_path.mkdir() runner = ct.CliRunner() @@ -691,7 +691,7 @@ def test_tskit_convert_overwrite_zarr_confirm_yes(self, mocked, tmp_path, respon @pytest.mark.parametrize("response", ["n", "N", "No"]) @mock.patch("bio2zarr.tskit.convert") def test_tskit_convert_overwrite_zarr_confirm_no(self, mocked, tmp_path, response): - ts_path = "tests/data/ts/example.trees" + ts_path = "tests/data/tskit/example.trees" zarr_path = tmp_path / "zarr" zarr_path.mkdir() runner = ct.CliRunner() @@ -708,7 +708,7 @@ def test_tskit_convert_overwrite_zarr_confirm_no(self, mocked, tmp_path, respons @pytest.mark.parametrize("force_arg", ["-f", "--force"]) @mock.patch("bio2zarr.tskit.convert") def 
test_tskit_convert_overwrite_zarr_force(self, mocked, tmp_path, force_arg): - ts_path = "tests/data/ts/example.trees" + ts_path = "tests/data/tskit/example.trees" zarr_path = tmp_path / "zarr" zarr_path.mkdir() runner = ct.CliRunner() @@ -728,7 +728,7 @@ def test_tskit_convert_overwrite_zarr_force(self, mocked, tmp_path, force_arg): @mock.patch("bio2zarr.tskit.convert") def test_tskit_convert_with_options(self, mocked, tmp_path): - ts_path = "tests/data/ts/example.trees" + ts_path = "tests/data/tskit/example.trees" zarr_path = tmp_path / "zarr" runner = ct.CliRunner() result = runner.invoke( @@ -1028,7 +1028,7 @@ def test_part_size_multiple_vcfs(self): class TestTskitEndToEnd: def test_convert(self, tmp_path): - ts_path = "tests/data/ts/example.trees" + ts_path = "tests/data/tskit/example.trees" zarr_path = tmp_path / "zarr" runner = ct.CliRunner() result = runner.invoke( diff --git a/tests/test_simulated_data.py b/tests/test_simulated_data.py index ad8386f8..d8d52a44 100644 --- a/tests/test_simulated_data.py +++ b/tests/test_simulated_data.py @@ -1,5 +1,5 @@ -import sys - +import msprime +import numpy as np import numpy.testing as nt import pysam import pytest @@ -9,10 +9,6 @@ def run_simulation(num_samples=2, ploidy=1, seed=42, sequence_length=100_000): - # Import here to avoid problems on OSX (see below) - # https://github.com/sgkit-dev/bio2zarr/issues/336 - import msprime - ts = msprime.sim_ancestry( num_samples, population_size=10**4, @@ -37,6 +33,10 @@ def assert_ts_ds_equal(ts, ds, ploidy=1): ts.genotype_matrix().reshape((ts.num_sites, ts.num_individuals, ploidy)), ds.call_genotype.values, ) + nt.assert_array_equal( + ds.call_genotype_phased.values, + np.ones((ts.num_sites, ts.num_individuals), dtype=bool), + ) nt.assert_equal(ds.variant_allele[:, 0].values, "A") nt.assert_equal(ds.variant_allele[:, 1].values, "T") nt.assert_equal(ds.variant_position, ts.sites_position) @@ -52,8 +52,6 @@ def write_vcf(ts, vcf_path, contig_id="1", indexed=False): return 
vcf_path -# https://github.com/sgkit-dev/bio2zarr/issues/336 -@pytest.mark.skipif(sys.platform == "darwin", reason="msprime OSX pip packages broken") class TestTskitRoundTripVcf: @pytest.mark.parametrize("ploidy", [1, 2, 3, 4]) def test_ploidy(self, ploidy, tmp_path): @@ -127,8 +125,6 @@ def test_mixed_indexed(self, num_contigs, tmp_path): self.validate_tss_vcf_list(contig_ids, tss, vcfs, tmp_path) -# https://github.com/sgkit-dev/bio2zarr/issues/336 -@pytest.mark.skipif(sys.platform == "darwin", reason="msprime OSX pip packages broken") class TestIncompatibleContigs: def test_different_lengths(self, tmp_path): vcfs = [] diff --git a/tests/test_ts.py b/tests/test_ts.py deleted file mode 100644 index e56924c9..00000000 --- a/tests/test_ts.py +++ /dev/null @@ -1,495 +0,0 @@ -import os -import tempfile -from unittest import mock - -import numpy as np -import pytest -import tskit -import zarr - -from bio2zarr import tskit as ts - - -class TestTskit: - def test_simple_tree_sequence(self, tmp_path): - tables = tskit.TableCollection(sequence_length=100) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - tables.nodes.add_row(flags=0, time=1) # MRCA for 0,1 - tables.nodes.add_row(flags=0, time=1) # MRCA for 2,3 - tables.edges.add_row(left=0, right=100, parent=4, child=0) - tables.edges.add_row(left=0, right=100, parent=4, child=1) - tables.edges.add_row(left=0, right=100, parent=5, child=2) - tables.edges.add_row(left=0, right=100, parent=5, child=3) - site_id = tables.sites.add_row(position=10, ancestral_state="A") - tables.mutations.add_row(site=site_id, node=4, derived_state="TTTT") - site_id = tables.sites.add_row(position=20, ancestral_state="CCC") - tables.mutations.add_row(site=site_id, node=5, derived_state="G") - site_id = tables.sites.add_row(position=30, ancestral_state="G") - 
tables.mutations.add_row(site=site_id, node=0, derived_state="AA") - tables.sort() - tree_sequence = tables.tree_sequence() - tree_sequence.dump(tmp_path / "test.trees") - - # Manually specify the individuals_nodes, other tests use - # ts individuals. - ind_nodes = np.array([[0, 1], [2, 3]]) - - with tempfile.TemporaryDirectory() as tempdir: - zarr_path = os.path.join(tempdir, "test_output.zarr") - ts.convert( - tmp_path / "test.trees", - zarr_path, - individuals_nodes=ind_nodes, - show_progress=False, - ) - zroot = zarr.open(zarr_path, mode="r") - pos = zroot["variant_position"][:] - assert pos.shape == (3,) - assert pos.dtype == np.int8 - assert np.array_equal(pos, [10, 20, 30]) - - alleles = zroot["variant_allele"][:] - assert alleles.shape == (3, 2) - assert alleles.dtype == "O" - assert np.array_equal(alleles, [["A", "TTTT"], ["CCC", "G"], ["G", "AA"]]) - - lengths = zroot["variant_length"][:] - assert lengths.shape == (3,) - assert lengths.dtype == np.int8 - assert np.array_equal(lengths, [4, 3, 2]) - - genotypes = zroot["call_genotype"][:] - assert genotypes.shape == (3, 2, 2) - assert genotypes.dtype == np.int8 - assert np.array_equal( - genotypes, [[[1, 1], [0, 0]], [[0, 0], [1, 1]], [[1, 0], [0, 0]]] - ) - - phased = zroot["call_genotype_phased"][:] - assert phased.shape == (3, 2) - assert phased.dtype == "bool" - assert np.all(phased) - - contigs = zroot["contig_id"][:] - assert contigs.shape == (1,) - assert contigs.dtype == "O" - assert np.array_equal(contigs, ["1"]) - - contig = zroot["variant_contig"][:] - assert contig.shape == (3,) - assert contig.dtype == np.int8 - assert np.array_equal(contig, [0, 0, 0]) - - samples = zroot["sample_id"][:] - assert samples.shape == (2,) - assert samples.dtype == "O" - assert np.array_equal(samples, ["tsk_0", "tsk_1"]) - - region_index = zroot["region_index"][:] - assert region_index.shape == (1, 6) - assert region_index.dtype == np.int8 - assert np.array_equal(region_index, [[0, 0, 10, 30, 31, 3]]) - - assert 
set(zroot.array_keys()) == { - "variant_position", - "variant_allele", - "variant_length", - "call_genotype", - "call_genotype_phased", - "call_genotype_mask", - "contig_id", - "variant_contig", - "sample_id", - "region_index", - } - - def test_missing_dependency(self): - with mock.patch( - "importlib.import_module", - side_effect=ImportError("No module named 'tskit'"), - ): - with pytest.raises(ImportError) as exc_info: - ts.convert( - "UNUSED_PATH", - "UNUSED_PATH", - ) - assert ( - "This process requires the optional tskit module. Install " - "it with: pip install bio2zarr[tskit]" in str(exc_info.value) - ) - - -class TestTskitFormat: - """Unit tests for TskitFormat without using full conversion.""" - - @pytest.fixture() - def simple_ts(self, tmp_path): - tables = tskit.TableCollection(sequence_length=100) - tables.individuals.add_row() - tables.individuals.add_row() - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1) - tables.nodes.add_row(flags=0, time=1) # MRCA for 0,1 - tables.nodes.add_row(flags=0, time=1) # MRCA for 2,3 - tables.edges.add_row(left=0, right=100, parent=4, child=0) - tables.edges.add_row(left=0, right=100, parent=4, child=1) - tables.edges.add_row(left=0, right=100, parent=5, child=2) - tables.edges.add_row(left=0, right=100, parent=5, child=3) - site_id = tables.sites.add_row(position=10, ancestral_state="A") - tables.mutations.add_row(site=site_id, node=4, derived_state="TT") - site_id = tables.sites.add_row(position=20, ancestral_state="CCC") - tables.mutations.add_row(site=site_id, node=5, derived_state="G") - site_id = tables.sites.add_row(position=30, ancestral_state="G") - tables.mutations.add_row(site=site_id, node=0, derived_state="A") - tables.sort() - tree_sequence = tables.tree_sequence() - 
ts_path = tmp_path / "test.trees" - tree_sequence.dump(ts_path) - return ts_path, tree_sequence - - @pytest.fixture() - def no_individuals_ts(self, tmp_path): - tables = tskit.TableCollection(sequence_length=100) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - tables.nodes.add_row(flags=0, time=1) # MRCA for 0,1 - tables.nodes.add_row(flags=0, time=1) # MRCA for 2,3 - tables.edges.add_row(left=0, right=100, parent=4, child=0) - tables.edges.add_row(left=0, right=100, parent=4, child=1) - tables.edges.add_row(left=0, right=100, parent=5, child=2) - tables.edges.add_row(left=0, right=100, parent=5, child=3) - site_id = tables.sites.add_row(position=10, ancestral_state="A") - tables.mutations.add_row(site=site_id, node=4, derived_state="T") - site_id = tables.sites.add_row(position=20, ancestral_state="C") - tables.mutations.add_row(site=site_id, node=5, derived_state="G") - tables.sort() - tree_sequence = tables.tree_sequence() - ts_path = tmp_path / "no_individuals.trees" - tree_sequence.dump(ts_path) - return ts_path, tree_sequence - - def test_position_dtype_selection(self, tmp_path): - tables = tskit.TableCollection(sequence_length=100) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - tables.sites.add_row(position=10, ancestral_state="A") - tables.sites.add_row(position=20, ancestral_state="C") - ts_small = tables.tree_sequence() - ts_path_small = tmp_path / "small_positions.trees" - ts_small.dump(ts_path_small) - - tables = tskit.TableCollection(sequence_length=3_000_000_000) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - tables.sites.add_row(position=10, ancestral_state="A") - tables.sites.add_row(position=np.iinfo(np.int32).max + 1, ancestral_state="C") - ts_large = tables.tree_sequence() - ts_path_large = tmp_path / "large_positions.trees" - 
ts_large.dump(ts_path_large) - - ind_nodes = np.array([[0], [1]]) - format_obj_small = ts.TskitFormat(ts_path_small, individuals_nodes=ind_nodes) - schema_small = format_obj_small.generate_schema() - - position_field = next( - f for f in schema_small.fields if f.name == "variant_position" - ) - assert position_field.dtype == "i1" - - format_obj_large = ts.TskitFormat(ts_path_large, individuals_nodes=ind_nodes) - schema_large = format_obj_large.generate_schema() - - position_field = next( - f for f in schema_large.fields if f.name == "variant_position" - ) - assert position_field.dtype == "i8" - - def test_initialization(self, simple_ts): - ts_path, tree_sequence = simple_ts - - # Test with default parameters - format_obj = ts.TskitFormat(ts_path) - assert format_obj.path == ts_path - assert format_obj.ts.num_sites == tree_sequence.num_sites - assert format_obj.contig_id == "1" - assert not format_obj.isolated_as_missing - - # Test with custom parameters - format_obj = ts.TskitFormat( - ts_path, - sample_ids=["ind1", "ind2"], - contig_id="chr1", - isolated_as_missing=True, - ) - assert format_obj.contig_id == "chr1" - assert format_obj.isolated_as_missing - assert format_obj.path == ts_path - assert format_obj.samples[0].id == "ind1" - assert format_obj.samples[1].id == "ind2" - - def test_basic_properties(self, simple_ts): - ts_path, _ = simple_ts - format_obj = ts.TskitFormat(ts_path) - - assert format_obj.num_records == format_obj.ts.num_sites - assert format_obj.num_samples == 2 # Two individuals - assert len(format_obj.samples) == 2 - assert format_obj.samples[0].id == "tsk_0" - assert format_obj.samples[1].id == "tsk_1" - - assert format_obj.root_attrs == {} - - contigs = format_obj.contigs - assert len(contigs) == 1 - assert contigs[0].id == "1" - - def test_custom_sample_ids(self, simple_ts): - ts_path, _ = simple_ts - custom_ids = ["sample_X", "sample_Y"] - format_obj = ts.TskitFormat(ts_path, sample_ids=custom_ids) - - assert format_obj.num_samples == 2 - 
assert len(format_obj.samples) == 2 - assert format_obj.samples[0].id == "sample_X" - assert format_obj.samples[1].id == "sample_Y" - - def test_sample_id_length_mismatch(self, simple_ts): - ts_path, _ = simple_ts - # Wrong number of sample IDs - with pytest.raises(ValueError, match="Length of sample_ids.*does not match"): - ts.TskitFormat(ts_path, sample_ids=["only_one_id"]) - - def test_schema_generation(self, simple_ts): - ts_path, _ = simple_ts - format_obj = ts.TskitFormat(ts_path) - - schema = format_obj.generate_schema() - assert schema.dimensions["variants"].size == 3 - assert schema.dimensions["samples"].size == 2 - assert schema.dimensions["ploidy"].size == 2 - assert schema.dimensions["alleles"].size == 2 # A/T, C/G, G/A -> max is 2 - field_names = [field.name for field in schema.fields] - assert "variant_position" in field_names - assert "variant_allele" in field_names - assert "variant_length" in field_names - assert "variant_contig" in field_names - assert "call_genotype" in field_names - assert "call_genotype_phased" in field_names - assert "call_genotype_mask" in field_names - schema = format_obj.generate_schema( - variants_chunk_size=10, samples_chunk_size=5 - ) - assert schema.dimensions["variants"].chunk_size == 10 - assert schema.dimensions["samples"].chunk_size == 5 - - def test_iter_contig(self, simple_ts): - ts_path, _ = simple_ts - format_obj = ts.TskitFormat(ts_path) - contig_indices = list(format_obj.iter_contig(1, 3)) - assert contig_indices == [0, 0] - - def test_iter_field(self, simple_ts): - ts_path, _ = simple_ts - format_obj = ts.TskitFormat(ts_path) - positions = list(format_obj.iter_field("position", None, 0, 3)) - assert positions == [10, 20, 30] - positions = list(format_obj.iter_field("position", None, 1, 3)) - assert positions == [20, 30] - with pytest.raises(ValueError, match="Unknown field"): - list(format_obj.iter_field("unknown_field", None, 0, 3)) - - @pytest.mark.parametrize( - ("ind_nodes", "expected_gts"), - [ - # 
Standard case: diploid samples with sequential node IDs - ( - np.array([[0, 1], [2, 3]]), - [[[1, 1], [0, 0]], [[0, 0], [1, 1]], [[1, 0], [0, 0]]], - ), - # Mixed ploidy: first sample diploid, second haploid - ( - np.array([[0, 1], [2, -1]]), - [[[1, 1], [0, -2]], [[0, 0], [1, -2]], [[1, 0], [0, -2]]], - ), - # Reversed order: nodes are not in sequential order - ( - np.array([[2, 3], [0, 1]]), - [[[0, 0], [1, 1]], [[1, 1], [0, 0]], [[0, 0], [1, 0]]], - ), - # Duplicate nodes: same node used multiple times - ( - np.array([[0, 0], [2, 2]]), - [[[1, 1], [0, 0]], [[0, 0], [1, 1]], [[1, 1], [0, 0]]], - ), - # Non-sample node: using node 4 which is an internal node (MRCA for 0,1) - ( - np.array([[0, 4], [2, 3]]), - [[[1, 1], [0, 0]], [[0, 0], [1, 1]], [[1, 0], [0, 0]]], - ), - # One individual with zero ploidy - ( - np.array([[0, 1], [-1, -1]]), - [[[1, 1], [-2, -2]], [[0, 0], [-2, -2]], [[1, 0], [-2, -2]]], - ), - ], - ) - def test_iter_alleles_and_genotypes(self, simple_ts, ind_nodes, expected_gts): - ts_path, _ = simple_ts - - format_obj = ts.TskitFormat(ts_path, individuals_nodes=ind_nodes) - - shape = (2, 2) # (num_samples, max_ploidy) - results = list(format_obj.iter_alleles_and_genotypes(0, 3, shape, 2)) - - assert len(results) == 3 - - for i, variant_data in enumerate(results): - if i == 0: - assert variant_data.variant_length == 2 - assert np.array_equal(variant_data.alleles, ("A", "TT")) - elif i == 1: - assert variant_data.variant_length == 3 - assert np.array_equal(variant_data.alleles, ("CCC", "G")) - elif i == 2: - assert variant_data.variant_length == 1 - assert np.array_equal(variant_data.alleles, ("G", "A")) - - assert np.array_equal( - variant_data.genotypes, expected_gts[i] - ), f"Mismatch at variant {i}, expected {expected_gts[i]}, " - f"got {variant_data.genotypes}" - assert np.all(variant_data.phased) - - def test_iter_alleles_and_genotypes_errors(self, simple_ts): - """Test error cases for iter_alleles_and_genotypes with invalid inputs.""" - 
ts_path, _ = simple_ts - - # Test with node ID that doesn't exist in tree sequence (out of range) - invalid_nodes = np.array([[10, 11], [12, 13]], dtype=np.int32) - format_obj = ts.TskitFormat(ts_path, individuals_nodes=invalid_nodes) - shape = (2, 2) - with pytest.raises( - tskit.LibraryError, match="out of bounds" - ): # Node ID 10 doesn't exist - list(format_obj.iter_alleles_and_genotypes(0, 1, shape, 2)) - - # Test with empty ind_nodes array (no samples) - empty_nodes = np.zeros((0, 2), dtype=np.int32) - with pytest.raises( - ValueError, match="individuals_nodes must have at least one sample" - ): - format_obj = ts.TskitFormat(ts_path, individuals_nodes=empty_nodes) - - # Test with all invalid nodes (-1) - all_invalid = np.full((2, 2), -1, dtype=np.int32) - with pytest.raises( - ValueError, match="individuals_nodes must have at least one valid sample" - ): - format_obj = ts.TskitFormat(ts_path, individuals_nodes=all_invalid) - - def test_isolated_as_missing(self, tmp_path): - def insert_branch_sites(ts, m=1): - if m == 0: - return ts - tables = ts.dump_tables() - tables.sites.clear() - tables.mutations.clear() - for tree in ts.trees(): - left, right = tree.interval - delta = (right - left) / (m * len(list(tree.nodes()))) - x = left - for u in tree.nodes(): - if tree.parent(u) != tskit.NULL: - for _ in range(m): - site = tables.sites.add_row(position=x, ancestral_state="0") - tables.mutations.add_row( - site=site, node=u, derived_state="1" - ) - x += delta - return tables.tree_sequence() - - tables = tskit.Tree.generate_balanced(2, span=10).tree_sequence.dump_tables() - # This also tests sample nodes that are not a single block at - # the start of the nodes table. 
- tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE) - tree_sequence = insert_branch_sites(tables.tree_sequence()) - - ts_path = tmp_path / "isolated_sample.trees" - tree_sequence.dump(ts_path) - ind_nodes = np.array([[0], [1], [3]]) - format_obj_default = ts.TskitFormat( - ts_path, individuals_nodes=ind_nodes, isolated_as_missing=False - ) - shape = (3, 1) # (num_samples, max_ploidy) - results_default = list( - format_obj_default.iter_alleles_and_genotypes(0, 1, shape, 2) - ) - - assert len(results_default) == 1 - variant_data_default = results_default[0] - assert np.array_equal(variant_data_default.alleles, ("0", "1")) - - # Sample 2 should have the ancestral state (0) when isolated_as_missing=False - expected_gt_default = np.array([[1], [0], [0]]) - assert np.array_equal(variant_data_default.genotypes, expected_gt_default) - - format_obj_missing = ts.TskitFormat( - ts_path, individuals_nodes=ind_nodes, isolated_as_missing=True - ) - results_missing = list( - format_obj_missing.iter_alleles_and_genotypes(0, 1, shape, 2) - ) - - assert len(results_missing) == 1 - variant_data_missing = results_missing[0] - assert variant_data_missing.variant_length == 1 - assert np.array_equal(variant_data_missing.alleles, ("0", "1")) - - # Individual 2 should have missing values (-1) when isolated_as_missing=True - expected_gt_missing = np.array([[1], [0], [-1]]) - assert np.array_equal(variant_data_missing.genotypes, expected_gt_missing) - - def test_genotype_dtype_selection(self, tmp_path): - tables = tskit.TableCollection(sequence_length=100) - for _ in range(4): - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - mrca = tables.nodes.add_row(flags=0, time=1) - for i in range(4): - tables.edges.add_row(left=0, right=100, parent=mrca, child=i) - site_id = tables.sites.add_row(position=10, ancestral_state="A") - tables.mutations.add_row(site=site_id, node=0, derived_state="T") - tables.sort() - tree_sequence = tables.tree_sequence() - ts_path = tmp_path / 
"small_alleles.trees" - tree_sequence.dump(ts_path) - - ind_nodes = np.array([[0, 1], [2, 3]]) - format_obj = ts.TskitFormat(ts_path, individuals_nodes=ind_nodes) - schema = format_obj.generate_schema() - call_genotype_spec = next(s for s in schema.fields if s.name == "call_genotype") - assert call_genotype_spec.dtype == "i1" - - tables = tskit.TableCollection(sequence_length=100) - for _ in range(4): - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - mrca = tables.nodes.add_row(flags=0, time=1) - for i in range(4): - tables.edges.add_row(left=0, right=100, parent=mrca, child=i) - site_id = tables.sites.add_row(position=10, ancestral_state="A") - for i in range(32768): - tables.mutations.add_row(site=site_id, node=0, derived_state=f"ALLELE_{i}") - - tables.sort() - tree_sequence = tables.tree_sequence() - ts_path = tmp_path / "large_alleles.trees" - tree_sequence.dump(ts_path) - - format_obj = ts.TskitFormat(ts_path, individuals_nodes=ind_nodes) - schema = format_obj.generate_schema() - call_genotype_spec = next(s for s in schema.fields if s.name == "call_genotype") - assert call_genotype_spec.dtype == "i4" diff --git a/tests/test_tskit.py b/tests/test_tskit.py new file mode 100644 index 00000000..81e28f62 --- /dev/null +++ b/tests/test_tskit.py @@ -0,0 +1,554 @@ +from unittest import mock + +import msprime +import numpy as np +import numpy.testing as nt +import pytest +import sgkit as sg +import tskit +import xarray.testing as xt +import zarr + +from bio2zarr import tskit as tsk +from bio2zarr import vcf + + +def test_missing_dependency(): + with mock.patch( + "importlib.import_module", + side_effect=ImportError("No module named 'tskit'"), + ): + with pytest.raises(ImportError) as exc_info: + tsk.convert( + "UNUSED_PATH", + "UNUSED_PATH", + ) + assert ( + "This process requires the optional tskit module. 
Install " + "it with: pip install bio2zarr[tskit]" in str(exc_info.value) + ) + + +def tskit_model_mapping(ind_nodes, ind_names=None): + if ind_names is None: + ind_names = [f"tsk{j}" for j in range(len(ind_nodes))] + return tskit.VcfModelMapping(ind_nodes, ind_names) + + +def add_mutations(ts): + # Add some mutations to the tree sequence. This guarantees that + # we have variation at all sites > 0. + tables = ts.dump_tables() + samples = ts.samples() + states = "ACGT" + for j in range(1, int(ts.sequence_length) - 1): + site = tables.sites.add_row(j, ancestral_state=states[j % 4]) + tables.mutations.add_row( + site=site, + derived_state=states[(j + 1) % 4], + node=samples[j % ts.num_samples], + ) + return tables.tree_sequence() + + +def simple_ts(add_individuals=False): + tables = tskit.TableCollection(sequence_length=100) + for _ in range(4): + ind = -1 + if add_individuals: + ind = tables.individuals.add_row() + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=ind) + tables.nodes.add_row(flags=0, time=1) # MRCA for 0,1 + tables.nodes.add_row(flags=0, time=1) # MRCA for 2,3 + tables.edges.add_row(left=0, right=100, parent=4, child=0) + tables.edges.add_row(left=0, right=100, parent=4, child=1) + tables.edges.add_row(left=0, right=100, parent=5, child=2) + tables.edges.add_row(left=0, right=100, parent=5, child=3) + site_id = tables.sites.add_row(position=10, ancestral_state="A") + tables.mutations.add_row(site=site_id, node=4, derived_state="TTTT") + site_id = tables.sites.add_row(position=20, ancestral_state="CCC") + tables.mutations.add_row(site=site_id, node=5, derived_state="G") + site_id = tables.sites.add_row(position=30, ancestral_state="G") + tables.mutations.add_row(site=site_id, node=0, derived_state="AA") + + tables.sort() + return tables.tree_sequence() + + +def insert_branch_sites(ts, m=1): + if m == 0: + return ts + tables = ts.dump_tables() + tables.sites.clear() + tables.mutations.clear() + for tree in ts.trees(): + left, right = 
tree.interval + delta = (right - left) / (m * len(list(tree.nodes()))) + x = left + for u in tree.nodes(): + if tree.parent(u) != tskit.NULL: + for _ in range(m): + site = tables.sites.add_row(position=x, ancestral_state="0") + tables.mutations.add_row(site=site, node=u, derived_state="1") + x += delta + return tables.tree_sequence() + + +class TestSimpleTs: + @pytest.fixture() + def conversion(self, tmp_path): + ts = simple_ts() + zarr_path = tmp_path / "test_output.vcz" + tsk.convert(ts, zarr_path) + zroot = zarr.open(zarr_path, mode="r") + return ts, zroot + + def test_position(self, conversion): + ts, zroot = conversion + + pos = zroot["variant_position"][:] + assert pos.shape == (3,) + assert pos.dtype == np.int8 + nt.assert_array_equal(pos, [10, 20, 30]) + + def test_alleles(self, conversion): + ts, zroot = conversion + alleles = zroot["variant_allele"][:] + assert alleles.shape == (3, 2) + assert alleles.dtype == "O" + nt.assert_array_equal(alleles, [["A", "TTTT"], ["CCC", "G"], ["G", "AA"]]) + + def test_variant_length(self, conversion): + ts, zroot = conversion + lengths = zroot["variant_length"][:] + assert lengths.shape == (3,) + assert lengths.dtype == np.int8 + nt.assert_array_equal(lengths, [1, 3, 1]) + + def test_genotypes(self, conversion): + ts, zroot = conversion + genotypes = zroot["call_genotype"][:] + assert genotypes.shape == (3, 4, 1) + assert genotypes.dtype == np.int8 + nt.assert_array_equal( + genotypes, + [[[1], [1], [0], [0]], [[0], [0], [1], [1]], [[1], [0], [0], [0]]], + ) + + def test_phased(self, conversion): + ts, zroot = conversion + phased = zroot["call_genotype_phased"][:] + assert phased.shape == (3, 4) + assert phased.dtype == "bool" + assert np.all(phased) + + def test_contig_id(self, conversion): + ts, zroot = conversion + contigs = zroot["contig_id"][:] + assert contigs.shape == (1,) + assert contigs.dtype == "O" + nt.assert_array_equal(contigs, ["1"]) + + def test_variant_contig(self, conversion): + ts, zroot = conversion + 
contig = zroot["variant_contig"][:] + assert contig.shape == (3,) + assert contig.dtype == np.int8 + nt.assert_array_equal(contig, [0, 0, 0]) + + def test_sample_id(self, conversion): + ts, zroot = conversion + samples = zroot["sample_id"][:] + assert samples.shape == (4,) + assert samples.dtype == "O" + nt.assert_array_equal(samples, ["tsk_0", "tsk_1", "tsk_2", "tsk_3"]) + + def test_region_index(self, conversion): + ts, zroot = conversion + region_index = zroot["region_index"][:] + assert region_index.shape == (1, 6) + assert region_index.dtype == np.int8 + nt.assert_array_equal(region_index, [[0, 0, 10, 30, 30, 3]]) + + def test_fields(self, conversion): + ts, zroot = conversion + assert set(zroot.array_keys()) == { + "variant_position", + "variant_allele", + "variant_length", + "call_genotype", + "call_genotype_phased", + "call_genotype_mask", + "contig_id", + "variant_contig", + "sample_id", + "region_index", + } + + +class TestTskitFormat: + """Unit tests for TskitFormat without using full conversion.""" + + @pytest.fixture() + def fx_simple_ts(self): + return simple_ts(add_individuals=True) + + @pytest.fixture() + def fx_ts_2_diploids(self): + ts = msprime.sim_ancestry(2, sequence_length=10, random_seed=42) + return add_mutations(ts) + + @pytest.fixture() + def fx_ts_isolated_samples(self): + tables = tskit.Tree.generate_balanced(2, span=10).tree_sequence.dump_tables() + # This also tests sample nodes that are not a single block at + # the start of the nodes table. 
+ tables.nodes.add_row(time=0, flags=tskit.NODE_IS_SAMPLE) + return insert_branch_sites(tables.tree_sequence()) + + def test_path_or_ts_input(self, tmp_path, fx_simple_ts): + f1 = tsk.TskitFormat(fx_simple_ts) + ts_path = tmp_path / "trees.ts" + fx_simple_ts.dump(ts_path) + f2 = tsk.TskitFormat(ts_path) + f1.ts.tables.assert_equals(f2.ts.tables) + + def test_small_position_dtype(self): + tables = tskit.TableCollection(sequence_length=100) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + tables.sites.add_row(position=10, ancestral_state="A") + tables.sites.add_row(position=20, ancestral_state="C") + ts = tables.tree_sequence() + format_obj_small = tsk.TskitFormat(ts) + schema_small = format_obj_small.generate_schema() + + position_field = next( + f for f in schema_small.fields if f.name == "variant_position" + ) + assert position_field.dtype == "i1" + + def test_large_position_dtype(self): + tables = tskit.TableCollection(sequence_length=3_000_000_000) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + tables.sites.add_row(position=10, ancestral_state="A") + tables.sites.add_row(position=np.iinfo(np.int32).max + 1, ancestral_state="C") + ts = tables.tree_sequence() + + format_obj_large = tsk.TskitFormat(ts) + schema_large = format_obj_large.generate_schema() + + position_field = next( + f for f in schema_large.fields if f.name == "variant_position" + ) + assert position_field.dtype == "i8" + + def test_initialization_defaults(self, fx_simple_ts): + format_obj = tsk.TskitFormat(fx_simple_ts) + assert format_obj.path is None + assert format_obj.ts.num_sites == fx_simple_ts.num_sites + assert format_obj.contig_id == "1" + assert not format_obj.isolated_as_missing + + def test_initialization_params(self, fx_simple_ts): + format_obj = tsk.TskitFormat( + fx_simple_ts, + contig_id="chr1", + isolated_as_missing=True, + ) + assert format_obj.contig_id == "chr1" + assert format_obj.isolated_as_missing + + def test_basic_properties(self, 
fx_ts_2_diploids): + format_obj = tsk.TskitFormat(fx_ts_2_diploids) + + assert format_obj.num_records == format_obj.ts.num_sites + assert format_obj.num_samples == 2 # Two individuals + assert len(format_obj.samples) == 2 + assert format_obj.samples[0].id == "tsk_0" + assert format_obj.samples[1].id == "tsk_1" + + assert format_obj.root_attrs == {} + + contigs = format_obj.contigs + assert len(contigs) == 1 + assert contigs[0].id == "1" + + def test_custom_sample_ids(self, fx_ts_2_diploids): + custom_ids = ["sW", "sX"] + model_mapping = fx_ts_2_diploids.map_to_vcf_model(individual_names=custom_ids) + format_obj = tsk.TskitFormat(fx_ts_2_diploids, model_mapping=model_mapping) + + assert format_obj.num_samples == 2 + assert len(format_obj.samples) == 2 + assert format_obj.samples[0].id == "sW" + assert format_obj.samples[1].id == "sX" + + def test_schema_generation(self, fx_simple_ts): + format_obj = tsk.TskitFormat(fx_simple_ts) + + schema = format_obj.generate_schema() + assert schema.dimensions["variants"].size == 3 + assert schema.dimensions["samples"].size == 4 + assert schema.dimensions["ploidy"].size == 1 + assert schema.dimensions["alleles"].size == 2 # A/TTTT, CCC/G, G/AA -> max is 2 + field_names = [field.name for field in schema.fields] + assert "variant_position" in field_names + assert "variant_allele" in field_names + assert "variant_length" in field_names + assert "variant_contig" in field_names + assert "call_genotype" in field_names + assert "call_genotype_phased" in field_names + assert "call_genotype_mask" in field_names + schema = format_obj.generate_schema( + variants_chunk_size=10, samples_chunk_size=5 + ) + assert schema.dimensions["variants"].chunk_size == 10 + assert schema.dimensions["samples"].chunk_size == 5 + + def test_iter_contig(self, fx_simple_ts): + format_obj = tsk.TskitFormat(fx_simple_ts) + contig_indices = list(format_obj.iter_contig(1, 3)) + assert contig_indices == [0, 0] + + def test_iter_field(self, fx_simple_ts): + format_obj = 
tsk.TskitFormat(fx_simple_ts) + positions = list(format_obj.iter_field("position", None, 0, 3)) + assert positions == [10, 20, 30] + positions = list(format_obj.iter_field("position", None, 1, 3)) + assert positions == [20, 30] + with pytest.raises(ValueError, match="Unknown field"): + list(format_obj.iter_field("unknown_field", None, 0, 3)) + + def test_zero_samples(self, fx_simple_ts): + model_mapping = tskit_model_mapping(np.array([])) + with pytest.raises(ValueError, match="at least one sample"): + tsk.TskitFormat(fx_simple_ts, model_mapping=model_mapping) + + def test_no_valid_samples(self, fx_simple_ts): + model_mapping = fx_simple_ts.map_to_vcf_model() + model_mapping.individuals_nodes[:] = -1 + with pytest.raises(ValueError, match="at least one valid sample"): + tsk.TskitFormat(fx_simple_ts, model_mapping=model_mapping) + + def test_model_size_mismatch(self, fx_simple_ts): + model_mapping = fx_simple_ts.map_to_vcf_model() + model_mapping.individuals_name = ["x"] + with pytest.raises(ValueError, match="match number of samples"): + tsk.TskitFormat(fx_simple_ts, model_mapping=model_mapping) + + @pytest.mark.parametrize( + ("ind_nodes", "expected_gts"), + [ + # Standard case: diploid samples with sequential node IDs + ( + np.array([[0, 1], [2, 3]]), + [[[1, 1], [0, 0]], [[0, 0], [1, 1]], [[1, 0], [0, 0]]], + ), + # Mixed ploidy: first sample diploid, second haploid + ( + np.array([[0, 1], [2, -1]]), + [[[1, 1], [0, -2]], [[0, 0], [1, -2]], [[1, 0], [0, -2]]], + ), + # Reversed order: nodes are not in sequential order + ( + np.array([[2, 3], [0, 1]]), + [[[0, 0], [1, 1]], [[1, 1], [0, 0]], [[0, 0], [1, 0]]], + ), + # Duplicate nodes: same node used multiple times + ( + np.array([[0, 0], [2, 2]]), + [[[1, 1], [0, 0]], [[0, 0], [1, 1]], [[1, 1], [0, 0]]], + ), + # Non-sample node: using node 4 which is an internal node (MRCA for 0,1) + ( + np.array([[0, 4], [2, 3]]), + [[[1, 1], [0, 0]], [[0, 0], [1, 1]], [[1, 0], [0, 0]]], + ), + # One individual with zero ploidy 
+ ( + np.array([[0, 1], [-1, -1]]), + [[[1, 1], [-2, -2]], [[0, 0], [-2, -2]], [[1, 0], [-2, -2]]], + ), + ], + ) + def test_iter_alleles_and_genotypes(self, fx_simple_ts, ind_nodes, expected_gts): + model_mapping = tskit_model_mapping(ind_nodes) + format_obj = tsk.TskitFormat(fx_simple_ts, model_mapping=model_mapping) + + shape = (2, 2) # (num_samples, max_ploidy) + results = list(format_obj.iter_alleles_and_genotypes(0, 3, shape, 2)) + + assert len(results) == 3 + + for i, variant_data in enumerate(results): + if i == 0: + assert variant_data.variant_length == 1 + nt.assert_array_equal(variant_data.alleles, ("A", "TTTT")) + elif i == 1: + assert variant_data.variant_length == 3 + nt.assert_array_equal(variant_data.alleles, ("CCC", "G")) + elif i == 2: + assert variant_data.variant_length == 1 + nt.assert_array_equal(variant_data.alleles, ("G", "AA")) + + nt.assert_array_equal(variant_data.genotypes, expected_gts[i]) + assert np.all(variant_data.phased) + + def test_iter_alleles_and_genotypes_missing_node(self, fx_ts_2_diploids): + # Test with node ID that doesn't exist in tree sequence (out of range) + ind_nodes = np.array([[10, 11], [12, 13]], dtype=np.int32) + model_mapping = tskit_model_mapping(ind_nodes) + format_obj = tsk.TskitFormat(fx_ts_2_diploids, model_mapping=model_mapping) + shape = (2, 2) + with pytest.raises( + tskit.LibraryError, match="out of bounds" + ): # Node ID 10 doesn't exist + list(format_obj.iter_alleles_and_genotypes(0, 1, shape, 2)) + + def test_isolated_as_missing(self, fx_ts_isolated_samples): + ind_nodes = np.array([[0], [1], [3]]) + model_mapping = tskit_model_mapping(ind_nodes) + + format_obj_default = tsk.TskitFormat( + fx_ts_isolated_samples, + model_mapping=model_mapping, + isolated_as_missing=False, + ) + shape = (3, 1) # (num_samples, max_ploidy) + results_default = list( + format_obj_default.iter_alleles_and_genotypes(0, 1, shape, 2) + ) + + assert len(results_default) == 1 + variant_data_default = results_default[0] + 
nt.assert_array_equal(variant_data_default.alleles, ("0", "1")) + + # Sample 2 should have the ancestral state (0) when isolated_as_missing=False + expected_gt_default = np.array([[1], [0], [0]]) + nt.assert_array_equal(variant_data_default.genotypes, expected_gt_default) + + format_obj_missing = tsk.TskitFormat( + fx_ts_isolated_samples, + model_mapping=model_mapping, + isolated_as_missing=True, + ) + results_missing = list( + format_obj_missing.iter_alleles_and_genotypes(0, 1, shape, 2) + ) + + assert len(results_missing) == 1 + variant_data_missing = results_missing[0] + assert variant_data_missing.variant_length == 1 + nt.assert_array_equal(variant_data_missing.alleles, ("0", "1")) + + # Individual 2 should have missing values (-1) when isolated_as_missing=True + expected_gt_missing = np.array([[1], [0], [-1]]) + nt.assert_array_equal(variant_data_missing.genotypes, expected_gt_missing) + + def test_genotype_dtype_i1(self): + tables = tskit.TableCollection(sequence_length=100) + for _ in range(4): + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + mrca = tables.nodes.add_row(flags=0, time=1) + for i in range(4): + tables.edges.add_row(left=0, right=100, parent=mrca, child=i) + site_id = tables.sites.add_row(position=10, ancestral_state="A") + tables.mutations.add_row(site=site_id, node=0, derived_state="T") + tables.sort() + tree_sequence = tables.tree_sequence() + + format_obj = tsk.TskitFormat(tree_sequence) + schema = format_obj.generate_schema() + call_genotype_spec = next(s for s in schema.fields if s.name == "call_genotype") + assert call_genotype_spec.dtype == "i1" + + def test_genotype_dtype_i4(self): + tables = tskit.TableCollection(sequence_length=100) + for _ in range(4): + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + mrca = tables.nodes.add_row(flags=0, time=1) + for i in range(4): + tables.edges.add_row(left=0, right=100, parent=mrca, child=i) + site_id = tables.sites.add_row(position=10, ancestral_state="A") + for i in 
range(32768): + tables.mutations.add_row(site=site_id, node=0, derived_state=f"ALLELE_{i}") + + tables.sort() + tree_sequence = tables.tree_sequence() + + format_obj = tsk.TskitFormat(tree_sequence) + schema = format_obj.generate_schema() + call_genotype_spec = next(s for s in schema.fields if s.name == "call_genotype") + assert call_genotype_spec.dtype == "i4" + + +@pytest.mark.parametrize( + "ts", + [ + # Standard individuals-with-a-given-ploidy situation + add_mutations( + msprime.sim_ancestry(4, ploidy=1, sequence_length=10, random_seed=42) + ), + add_mutations( + msprime.sim_ancestry(2, ploidy=2, sequence_length=10, random_seed=42) + ), + add_mutations( + msprime.sim_ancestry(3, ploidy=12, sequence_length=10, random_seed=142) + ), + # No individuals, ploidy1 + add_mutations(msprime.simulate(4, length=10, random_seed=412)), + ], +) +def test_against_tskit_vcf_output(ts, tmp_path): + vcf_path = tmp_path / "ts.vcf" + with open(vcf_path, "w") as f: + ts.write_vcf(f) + + tskit_zarr = tmp_path / "tskit.zarr" + vcf_zarr = tmp_path / "vcf.zarr" + tsk.convert(ts, tskit_zarr, worker_processes=0) + + vcf.convert([vcf_path], vcf_zarr, worker_processes=0) + ds1 = sg.load_dataset(tskit_zarr) + ds2 = ( + sg.load_dataset(vcf_zarr) + .drop_dims("filters") + .drop_vars( + ["variant_id", "variant_id_mask", "variant_quality", "contig_length"] + ) + ) + xt.assert_equal(ds1, ds2) + + +def assert_ts_ds_equal(ts, ds, ploidy=2): + assert ds.sizes["ploidy"] == ploidy + assert ds.sizes["variants"] == ts.num_sites + assert ds.sizes["samples"] == ts.num_individuals + # Msprime guarantees that this will be true. 
+ nt.assert_array_equal( + ts.genotype_matrix().reshape((ts.num_sites, ts.num_individuals, ploidy)), + ds.call_genotype.values, + ) + nt.assert_array_equal( + ds.call_genotype_phased.values, + np.ones((ts.num_sites, ts.num_individuals), dtype=bool), + ) + # Specialised for the limited form of mutations used here + nt.assert_equal( + ds.variant_allele[:, 0].values, [site.ancestral_state for site in ts.sites()] + ) + nt.assert_equal( + ds.variant_allele[:, 1].values, + [mutation.derived_state for mutation in ts.mutations()], + ) + nt.assert_equal(ds.variant_position, ts.sites_position) + + +@pytest.mark.parametrize("worker_processes", [0, 1, 2, 15]) +def test_workers(tmp_path, worker_processes): + ts = msprime.sim_ancestry(10, sequence_length=1000, random_seed=42) + ts = add_mutations(ts) + out = tmp_path / "tskit.zarr" + tsk.convert(ts, out, worker_processes=worker_processes) + ds = sg.load_dataset(out) + assert_ts_ds_equal(ts, ds)