Skip to content

Commit 99130dd

Browse files
committed
Move individual nodes to the caller, use numpy assign
1 parent 7302968 commit 99130dd

File tree

2 files changed

+157
-308
lines changed

2 files changed

+157
-308
lines changed

bio2zarr/tskit.py

Lines changed: 40 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -9,79 +9,46 @@
99
logger = logging.getLogger(__name__)
1010

1111

12-
import numpy as np
13-
import tskit
14-
15-
def individual_nodes(ts):
16-
"""
17-
Convert a tree sequence with individuals to a 2D array of node IDs.
18-
19-
Parameters
20-
----------
21-
ts : tskit.TreeSequence
22-
The tree sequence to convert
23-
24-
Returns
25-
-------
26-
numpy.ndarray
27-
Array of shape (num_individuals, max_ploidy) containing node IDs.
28-
Values of -1 indicate unused slots for individuals with ploidy
29-
less than the maximum.
30-
31-
Raises
32-
------
33-
ValueError
34-
If the tree sequence has no individuals, if any sample doesn't have an individual,
35-
if individuals have nodes that are both samples and non-samples, or if an
36-
individuals has no samples.
37-
"""
38-
if ts.num_individuals == 0:
39-
raise ValueError("Tree sequence has no individuals")
40-
41-
individuals = np.unique(ts.nodes_individual[ts.samples()])
42-
if len(individuals) == 1 and individuals[0] == tskit.NULL:
43-
raise ValueError("No samples refer to individuals")
44-
45-
# np.unique sorts the argument, so if NULL (-1) is present it will be first
46-
if individuals[0] == tskit.NULL:
47-
raise ValueError(
48-
"Sample nodes must all be associated with individuals"
49-
)
50-
51-
max_ploidy = 0
52-
for i in range(ts.num_individuals):
53-
ind = ts.individual(i)
54-
max_ploidy = max(max_ploidy, len(ind.nodes))
55-
56-
# Initialize output array with -1 (indicating no node)
57-
result = np.full((ts.num_individuals, max_ploidy), -1, dtype=np.int32)
58-
59-
for i in range(ts.num_individuals):
60-
ind = ts.individual(i)
61-
if len(ind.nodes) == 0:
62-
raise ValueError(f"Individual {i} not associated with any nodes")
63-
64-
is_sample = {ts.node(u).is_sample() for u in ind.nodes}
65-
if len(is_sample) != 1:
66-
raise ValueError(
67-
f"Individual {ind.id} has nodes that are sample and non-samples"
68-
)
69-
70-
for j, node_id in enumerate(ind.nodes):
71-
result[i, j] = node_id
72-
73-
return result
74-
7512
class TskitFormat(vcz.Source):
76-
def __init__(self, ts_path, contig_id=None, isolated_as_missing=False):
13+
def __init__(
14+
self,
15+
ts_path,
16+
individual_nodes,
17+
sample_ids=None,
18+
contig_id=None,
19+
isolated_as_missing=False,
20+
):
7721
self._path = ts_path
7822
self.ts = tskit.load(ts_path)
7923
self.contig_id = contig_id if contig_id is not None else "1"
8024
self.isolated_as_missing = isolated_as_missing
8125

82-
self._make_sample_mapping()
8326
self.positions = self.ts.sites_position
8427

28+
self._num_samples = individual_nodes.shape[0]
29+
if self._num_samples < 1:
30+
raise ValueError("individual_nodes must have at least one sample")
31+
self.max_ploidy = individual_nodes.shape[1]
32+
if sample_ids is None:
33+
sample_ids = [f"tsk_{j}" for j in range(self._num_samples)]
34+
elif len(sample_ids) != self._num_samples:
35+
raise ValueError(
36+
f"Length of sample_ids ({len(sample_ids)}) does not match "
37+
f"number of samples ({self._num_samples})"
38+
)
39+
40+
self._samples = [vcz.Sample(id=sample_id) for sample_id in sample_ids]
41+
42+
self.tskit_samples = np.unique(individual_nodes[individual_nodes >= 0])
43+
if len(self.tskit_samples) < 1:
44+
raise ValueError("individual_nodes must have at least one valid sample")
45+
node_id_to_index = {node_id: i for i, node_id in enumerate(self.tskit_samples)}
46+
valid_mask = individual_nodes >= 0
47+
self.sample_indices, self.ploidy_indices = np.where(valid_mask)
48+
self.genotype_indices = np.array(
49+
[node_id_to_index[node_id] for node_id in individual_nodes[valid_mask]]
50+
)
51+
8552
@property
8653
def path(self):
8754
return self._path
@@ -106,21 +73,6 @@ def root_attrs(self):
10673
def contigs(self):
10774
return [vcz.Contig(id=self.contig_id)]
10875

109-
def _make_sample_mapping(self):
110-
ts = self.ts
111-
112-
# Use individual_nodes to get the mapping between individuals and nodes
113-
try:
114-
# Get a 2D array of node IDs for each individual
115-
self.node_ids_array = individual_nodes(ts)
116-
self._num_samples = ts.num_individuals
117-
self.max_ploidy = self.node_ids_array.shape[1]
118-
119-
except ValueError as e:
120-
raise ValueError(f"Error mapping individuals to nodes: {e}") from e
121-
122-
self._samples = [vcz.Sample(id=f"tsk_{j}") for j in range(self.num_samples)]
123-
12476
def iter_contig(self, start, stop):
12577
yield from (0 for _ in range(start, stop))
12678

@@ -139,6 +91,7 @@ def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
13991
isolated_as_missing=self.isolated_as_missing,
14092
left=self.positions[start],
14193
right=self.positions[stop] if stop < self.num_records else None,
94+
samples=self.tskit_samples,
14295
):
14396
gt = np.full(shape, constants.INT_FILL, dtype=np.int8)
14497
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
@@ -149,12 +102,9 @@ def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
149102
assert i < num_alleles
150103
alleles[i] = allele
151104

152-
# For each individual, get genotypes for their nodes
153-
for i in range(self.num_samples):
154-
for j in range(self.max_ploidy):
155-
node_id = self.node_ids_array[i, j]
156-
if node_id >= 0: # Skip -1 entries (unused slots)
157-
gt[i, j] = variant.genotypes[node_id]
105+
gt[self.sample_indices, self.ploidy_indices] = variant.genotypes[
106+
self.genotype_indices
107+
]
158108

159109
yield alleles, (gt, phased)
160110

@@ -253,7 +203,9 @@ def generate_schema(
253203
def convert(
254204
ts_path,
255205
zarr_path,
206+
individual_nodes,
256207
*,
208+
sample_ids=None,
257209
contig_id=None,
258210
isolated_as_missing=False,
259211
variants_chunk_size=None,
@@ -263,6 +215,8 @@ def convert(
263215
):
264216
tskit_format = TskitFormat(
265217
ts_path,
218+
individual_nodes,
219+
sample_ids=sample_ids,
266220
contig_id=contig_id,
267221
isolated_as_missing=isolated_as_missing,
268222
)

0 commit comments

Comments
 (0)