Skip to content

Commit 9f5ae9a

Browse files
committed
Initial ts convert
1 parent 0cbf5cb commit 9f5ae9a

File tree

2 files changed

+331
-0
lines changed

2 files changed

+331
-0
lines changed

bio2zarr/ts.py

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
import logging
2+
import pathlib
3+
4+
import numpy as np
5+
import tskit
6+
7+
from bio2zarr import constants, core, vcz
8+
9+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
10+
11+
12+
class TskitFormat:
    """
    Adapter that exposes a tskit tree sequence as per-variant iterators
    (alleles, contig index, position, genotypes) plus a Zarr schema.

    An instance is what ``convert`` in this module hands to
    ``vcz.VcfZarrWriter.init``.
    NOTE(review): the exact iterator protocol is dictated by ``vcz``, which
    is not visible here -- confirm against the writer before changing
    signatures.
    """

    def __init__(self, ts_path, contig_id=None, ploidy=None, isolated_as_missing=False):
        """
        Load the tree sequence at ``ts_path`` and derive the sample layout.

        :param ts_path: Path passed directly to ``tskit.load``.
        :param contig_id: Name of the single contig; defaults to ``"1"``.
        :param ploidy: Ploidy used to group sample nodes when the tables hold
            no individuals. Must be ``None`` if individuals are present.
        :param isolated_as_missing: Forwarded to ``TreeSequence.variants``.
        :raises ValueError: If the sample/individual/ploidy configuration is
            inconsistent (see ``_make_sample_mapping``).
        """
        self.path = ts_path
        self.ts = tskit.load(ts_path)
        self.contig_id = contig_id if contig_id is not None else "1"
        self.isolated_as_missing = isolated_as_missing
        self.root_attrs = {}

        self._make_sample_mapping(ploidy)
        self.contigs = [vcz.Contig(id=self.contig_id)]
        self.num_records = self.ts.num_sites
        self.positions = self.ts.sites_position

    def _make_sample_mapping(self, ploidy):
        """
        Populate ``sample_ids``, ``individual_ploidies``, ``max_ploidy``,
        ``num_samples`` and ``samples``.

        If sample nodes reference individuals, each individual becomes one
        output sample whose ploidy is its number of nodes. Otherwise the
        sample nodes are grouped consecutively into samples of ``ploidy``
        nodes each (``ploidy`` defaults to 1).

        :raises ValueError: If ``ploidy`` is given while the tables contain
            individuals, if only some sample nodes reference individuals, if
            an individual mixes sample and non-sample nodes, or if the number
            of samples is not divisible by ``ploidy``.
        """
        ts = self.ts
        self.individual_ploidies = []
        self.max_ploidy = 0

        if ts.num_individuals > 0 and ploidy is not None:
            raise ValueError(
                "Cannot specify ploidy when individuals are present in tables"
            )

        # Find all sample nodes that reference individuals
        individuals = np.unique(ts.tables.nodes.individual[ts.samples()])
        if len(individuals) == 0 or (
            len(individuals) == 1 and individuals[0] == tskit.NULL
        ):
            # No samples refer to individuals. The empty case (a tree
            # sequence with zero samples) is handled here so that we do not
            # index into an empty array below.
            individuals = None
        elif individuals[0] == tskit.NULL:
            # np.unique sorts the argument, so if NULL (-1) is present it
            # will be the first value.
            raise ValueError(
                "Sample nodes must either all be associated with individuals "
                "or not associated with any individuals"
            )

        if individuals is not None:
            self.sample_ids = []
            for i in individuals:
                if i < 0 or i >= self.ts.num_individuals:
                    raise ValueError("Invalid individual IDs provided.")
                ind = self.ts.individual(i)
                if len(ind.nodes) == 0:
                    raise ValueError(f"Individual {i} not associated with a node")
                is_sample = {ts.node(u).is_sample() for u in ind.nodes}
                if len(is_sample) != 1:
                    raise ValueError(
                        f"Individual {ind.id} has nodes that are sample and "
                        "non-samples"
                    )
                self.sample_ids.extend(ind.nodes)
                self.individual_ploidies.append(len(ind.nodes))
                self.max_ploidy = max(self.max_ploidy, len(ind.nodes))
        else:
            if ploidy is None:
                ploidy = 1
            if ploidy < 1:
                raise ValueError("Ploidy must be >= 1")
            if ts.num_samples % ploidy != 0:
                raise ValueError("Sample size must be divisible by ploidy")
            self.individual_ploidies = np.full(
                ts.num_samples // ploidy, ploidy, dtype=np.int32
            )
            self.max_ploidy = ploidy
            self.sample_ids = np.arange(ts.num_samples, dtype=np.int32)

        self.num_samples = len(self.individual_ploidies)

        self.samples = [vcz.Sample(id=f"tsk_{j}") for j in range(self.num_samples)]

    def _variants(self, start, stop):
        """
        Yield tskit variants for the sites with index in ``[start, stop)``.

        Site indexes are translated to a genomic interval using the cached
        site positions. An empty or out-of-range request yields nothing
        instead of indexing past the end of the position array.
        """
        if start >= stop or start >= self.num_records:
            return
        yield from self.ts.variants(
            samples=self.sample_ids,
            isolated_as_missing=self.isolated_as_missing,
            left=self.positions[start],
            # ``None`` lets tskit run to the end of the sequence.
            right=self.positions[stop] if stop < self.num_records else None,
        )

    def iter_alleles(self, start, stop, num_alleles):
        """
        Yield one object array of length ``num_alleles`` per site in
        ``[start, stop)``, padded with ``constants.STR_FILL``.

        :raises ValueError: If a site has more than ``num_alleles`` alleles.
        """
        for variant in self._variants(start, stop):
            # Explicit check rather than ``assert`` so it survives ``-O``.
            if len(variant.alleles) > num_alleles:
                raise ValueError(
                    f"Site has {len(variant.alleles)} alleles but only "
                    f"{num_alleles} were allocated"
                )
            alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
            for i, allele in enumerate(variant.alleles):
                alleles[i] = allele
            yield alleles

    def iter_contig(self, start, stop):
        """Yield the contig index (always 0) for each site in ``[start, stop)``."""
        for _ in range(start, stop):
            yield 0

    def iter_field(self, field_name, shape, start, stop):
        """
        Yield values of ``field_name`` for the sites in ``[start, stop)``.

        Only ``"position"`` is supported; positions are yielded as ``int``.
        ``shape`` is accepted for interface compatibility and is unused.

        :raises ValueError: If ``field_name`` is not ``"position"``.
        """
        if field_name != "position":
            raise ValueError(f"Unknown field {field_name}")
        # Use the position array cached in __init__ instead of going through
        # ``ts.tables`` on every call.
        for pos in self.positions[start:stop]:
            yield int(pos)

    def iter_genotypes(self, shape, start, stop):
        """
        Yield ``(gt, phased)`` for each site in ``[start, stop)``.

        ``gt`` is a fresh int8 array of the given ``shape``, indexed as
        ``gt[sample, ploidy_slot]``. Slots beyond a sample's own ploidy hold
        the fill value -2. ``phased`` is a fresh boolean array of shape
        ``shape[:-1]`` that is True for every sample.
        """
        ploidies = np.asarray(self.individual_ploidies, dtype=np.int64)
        num_individuals = len(ploidies)
        # Loop-invariant scatter indexes: the k-th genotype returned by tskit
        # belongs to sample rows[k], ploidy slot cols[k].
        rows = np.repeat(np.arange(num_individuals), ploidies)
        first_node = np.cumsum(ploidies) - ploidies
        cols = np.arange(ploidies.sum()) - np.repeat(first_node, ploidies)

        for variant in self._variants(start, stop):
            # Allocate per variant so that consumers holding on to earlier
            # results are not affected, and start from the fill value so
            # that unused ploidy slots are not mistaken for allele 0.
            gt = np.full(shape, -2, dtype=np.int8)  # -2 is the fill value
            gt[rows, cols] = variant.genotypes

            # In tskit, all genotypes are considered phased
            phased = np.zeros(shape[:-1], dtype=bool)
            phased[:num_individuals] = True

            yield gt, phased

    def generate_schema(
        self,
        variants_chunk_size=None,
        samples_chunk_size=None,
    ):
        """
        Build the ``vcz.VcfZarrSchema`` describing the arrays to be written.

        The allele dimension is sized by scanning every variant with the same
        sample set and ``isolated_as_missing`` setting that ``iter_alleles``
        uses, so the schema always matches the rows that are later emitted.

        :param variants_chunk_size: Passed to ``vcz.VcfZarrSchema``.
        :param samples_chunk_size: Passed to ``vcz.VcfZarrSchema``.
        :return: The populated schema instance.
        """
        n = self.num_samples
        m = self.ts.num_sites

        # Determine max number of alleles
        max_alleles = max(
            (len(variant.alleles) for variant in self._variants(0, self.num_records)),
            default=0,
        )

        logger.info(f"Scanned tskit with {n} samples and {m} variants")
        logger.info(
            f"Maximum ploidy: {self.max_ploidy}, maximum alleles: {max_alleles}"
        )

        schema_instance = vcz.VcfZarrSchema(
            format_version=vcz.ZARR_SCHEMA_FORMAT_VERSION,
            samples_chunk_size=samples_chunk_size,
            variants_chunk_size=variants_chunk_size,
            fields=[],
        )

        logger.info(
            "Generating schema with chunks="
            f"{schema_instance.variants_chunk_size, schema_instance.samples_chunk_size}"
        )

        array_specs = [
            vcz.ZarrArraySpec.new(
                vcf_field="position",
                name="variant_position",
                dtype="i4",
                shape=[m],
                dimensions=["variants"],
                chunks=[schema_instance.variants_chunk_size],
                description="Position of each variant",
            ),
            vcz.ZarrArraySpec.new(
                vcf_field=None,
                name="variant_allele",
                dtype="O",
                shape=[m, max_alleles],
                dimensions=["variants", "alleles"],
                chunks=[schema_instance.variants_chunk_size, max_alleles],
                description="Alleles for each variant",
            ),
            vcz.ZarrArraySpec.new(
                vcf_field=None,
                name="variant_contig",
                dtype=core.min_int_dtype(0, len(self.contigs)),
                shape=[m],
                dimensions=["variants"],
                chunks=[schema_instance.variants_chunk_size],
                description="Contig/chromosome index for each variant",
            ),
            vcz.ZarrArraySpec.new(
                vcf_field=None,
                name="call_genotype_phased",
                dtype="bool",
                shape=[m, n],
                dimensions=["variants", "samples"],
                chunks=[
                    schema_instance.variants_chunk_size,
                    schema_instance.samples_chunk_size,
                ],
                description="Whether the genotype is phased",
            ),
            vcz.ZarrArraySpec.new(
                vcf_field=None,
                name="call_genotype",
                dtype="i1",
                shape=[m, n, self.max_ploidy],
                dimensions=["variants", "samples", "ploidy"],
                chunks=[
                    schema_instance.variants_chunk_size,
                    schema_instance.samples_chunk_size,
                    self.max_ploidy,
                ],
                description="Genotype for each variant and sample",
            ),
            vcz.ZarrArraySpec.new(
                vcf_field=None,
                name="call_genotype_mask",
                dtype="bool",
                shape=[m, n, self.max_ploidy],
                dimensions=["variants", "samples", "ploidy"],
                chunks=[
                    schema_instance.variants_chunk_size,
                    schema_instance.samples_chunk_size,
                    self.max_ploidy,
                ],
                description="Mask for each genotype call",
            ),
        ]
        schema_instance.fields = array_specs
        return schema_instance
233+
234+
235+
def convert(
    ts_path,
    zarr_path,
    *,
    contig_id=None,
    ploidy=None,
    isolated_as_missing=False,
    variants_chunk_size=None,
    samples_chunk_size=None,
    worker_processes=1,
    show_progress=False,
):
    """
    Convert the tree sequence file at ``ts_path`` to a Zarr store at
    ``zarr_path``.

    ``contig_id``, ``ploidy`` and ``isolated_as_missing`` are forwarded to
    ``TskitFormat``; the chunk sizes are forwarded to its
    ``generate_schema``. ``worker_processes`` and ``show_progress`` are
    passed on to the ``vcz.VcfZarrWriter`` encoding steps.
    """
    source = TskitFormat(
        ts_path,
        contig_id=contig_id,
        ploidy=ploidy,
        isolated_as_missing=isolated_as_missing,
    )
    schema = source.generate_schema(
        variants_chunk_size=variants_chunk_size,
        samples_chunk_size=samples_chunk_size,
    )
    writer = vcz.VcfZarrWriter(TskitFormat, pathlib.Path(zarr_path))
    writer.init(
        source,
        # Rough heuristic to split work up enough to keep utilisation high
        target_num_partitions=max(1, worker_processes * 4),
        schema=schema,
    )
    writer.encode_all_partitions(
        worker_processes=worker_processes,
        show_progress=show_progress,
    )
    writer.finalise(show_progress)
    writer.create_index()

tests/test_ts.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import os
2+
import tempfile
3+
4+
import numpy as np
5+
import tskit
6+
import zarr
7+
8+
from bio2zarr import ts
9+
10+
11+
class TestTskit:
    """End-to-end check of ``ts.convert`` on a small hand-built tree sequence."""

    def test_simple_tree_sequence(self, tmp_path):
        """
        Two diploid individuals (sample nodes 0,1 and 2,3), two internal
        nodes (4 and 5) and three sites; convert and compare every array
        that is read back from the Zarr store.
        """
        tables = tskit.TableCollection(sequence_length=100)
        for _ in range(2):
            tables.individuals.add_row(flags=0, location=(0, 0), metadata=b"")
        for individual in (0, 0, 1, 1):
            tables.nodes.add_row(
                flags=tskit.NODE_IS_SAMPLE, time=0, individual=individual
            )
        # Node 4 is the MRCA for 0,1 and node 5 is the MRCA for 2,3
        for _ in range(2):
            tables.nodes.add_row(flags=0, time=1)
        for parent, child in ((4, 0), (4, 1), (5, 2), (5, 3)):
            tables.edges.add_row(left=0, right=100, parent=parent, child=child)
        # (position, ancestral state, mutated node, derived state)
        site_specs = (
            (10, "A", 4, "T"),
            (20, "C", 5, "G"),
            (30, "G", 0, "A"),
        )
        for position, ancestral, node, derived in site_specs:
            site_id = tables.sites.add_row(
                position=position, ancestral_state=ancestral
            )
            tables.mutations.add_row(site=site_id, node=node, derived_state=derived)
        tables.sort()

        trees_path = tmp_path / "test.trees"
        tables.tree_sequence().dump(trees_path)

        with tempfile.TemporaryDirectory() as tempdir:
            zarr_path = os.path.join(tempdir, "test_output.zarr")
            ts.convert(trees_path, zarr_path, show_progress=False)
            zroot = zarr.open(zarr_path, mode="r")

            positions = zroot["variant_position"]
            assert positions.shape == (3,)
            assert list(positions[:]) == [10, 20, 30]

            expected_alleles = [["A", "T"], ["C", "G"], ["G", "A"]]
            assert np.array_equal(zroot["variant_allele"][:], expected_alleles)

            expected_genotypes = [
                [[1, 1], [0, 0]],
                [[0, 0], [1, 1]],
                [[1, 0], [0, 0]],
            ]
            assert np.array_equal(zroot["call_genotype"][:], expected_genotypes)

            assert np.all(zroot["call_genotype_phased"][:])

            assert np.array_equal(zroot["contig_id"][:], ["1"])
            assert np.array_equal(zroot["variant_contig"][:], [0, 0, 0])
            assert np.array_equal(zroot["sample_id"][:], ["tsk_0", "tsk_1"])

0 commit comments

Comments
 (0)