Skip to content

Commit 7302968

Browse files
committed
Simplify and use individual_nodes method
1 parent e0dc5f6 commit 7302968

File tree

2 files changed

+191
-109
lines changed

2 files changed

+191
-109
lines changed

bio2zarr/tskit.py

Lines changed: 84 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -9,14 +9,77 @@
99
logger = logging.getLogger(__name__)
1010

1111

12+
import numpy as np
13+
import tskit
14+
15+
def individual_nodes(ts):
    """
    Return the node IDs of each individual in a tree sequence as a 2D array.

    Every individual in the tree sequence gets one row (row ``i`` is
    individual ``i``), not only the individuals referenced by sample nodes.
    An individual whose nodes are all non-samples passes validation and is
    included.

    Parameters
    ----------
    ts : tskit.TreeSequence
        The tree sequence to convert.

    Returns
    -------
    numpy.ndarray
        ``int32`` array of shape ``(num_individuals, max_ploidy)`` containing
        node IDs. Values of -1 (``tskit.NULL``) mark unused slots for
        individuals with ploidy less than the maximum.

    Raises
    ------
    ValueError
        If the tree sequence has no individuals, if no sample node (or only
        some sample nodes) refers to an individual, if an individual has no
        nodes, or if an individual mixes sample and non-sample nodes.
    """
    if ts.num_individuals == 0:
        raise ValueError("Tree sequence has no individuals")

    individuals = np.unique(ts.nodes_individual[ts.samples()])
    # A tree sequence without any sample nodes yields an empty array; report
    # it as a ValueError rather than failing with IndexError on the [0] below.
    if len(individuals) == 0:
        raise ValueError("No samples refer to individuals")
    # np.unique sorts its result, so if NULL (-1) is present it is first.
    if individuals[0] == tskit.NULL:
        if len(individuals) == 1:
            raise ValueError("No samples refer to individuals")
        raise ValueError("Sample nodes must all be associated with individuals")

    # Single pass: validate each individual and remember its nodes so that
    # the Individual objects are only constructed once.
    nodes_per_individual = []
    for i in range(ts.num_individuals):
        ind = ts.individual(i)
        if len(ind.nodes) == 0:
            raise ValueError(f"Individual {i} not associated with any nodes")

        sample_status = {ts.node(u).is_sample() for u in ind.nodes}
        if len(sample_status) != 1:
            raise ValueError(
                f"Individual {ind.id} has nodes that are sample and non-samples"
            )
        nodes_per_individual.append(ind.nodes)

    max_ploidy = max(len(nodes) for nodes in nodes_per_individual)

    # Pad with NULL (-1) so rows of lower ploidy have explicit empty slots.
    result = np.full((ts.num_individuals, max_ploidy), tskit.NULL, dtype=np.int32)
    for i, nodes in enumerate(nodes_per_individual):
        result[i, : len(nodes)] = nodes

    return result
74+
1275
class TskitFormat(vcz.Source):
13-
def __init__(self, ts_path, contig_id=None, ploidy=None, isolated_as_missing=False):
76+
def __init__(self, ts_path, contig_id=None, isolated_as_missing=False):
1477
self._path = ts_path
1578
self.ts = tskit.load(ts_path)
1679
self.contig_id = contig_id if contig_id is not None else "1"
1780
self.isolated_as_missing = isolated_as_missing
1881

19-
self._make_sample_mapping(ploidy)
82+
self._make_sample_mapping()
2083
self.positions = self.ts.sites_position
2184

2285
@property
@@ -43,62 +106,19 @@ def root_attrs(self):
43106
def contigs(self):
44107
return [vcz.Contig(id=self.contig_id)]
45108

46-
def _make_sample_mapping(self, ploidy):
109+
def _make_sample_mapping(self):
47110
ts = self.ts
48-
self.individual_ploidies = []
49-
self.max_ploidy = 0
50-
51-
if ts.num_individuals > 0 and ploidy is not None:
52-
raise ValueError(
53-
"Cannot specify ploidy when individuals are present in tables"
54-
)
55-
56-
# Find all sample nodes that reference individuals
57-
individuals = np.unique(ts.nodes_individual[ts.samples()])
58-
if len(individuals) == 1 and individuals[0] == tskit.NULL:
59-
# No samples refer to individuals
60-
individuals = None
61-
else:
62-
# np.unique sorts the argument, so if NULL (-1) is present it
63-
# will be the first value.
64-
if individuals[0] == tskit.NULL:
65-
raise ValueError(
66-
"Sample nodes must either all be associated with individuals "
67-
"or not associated with any individuals"
68-
)
69-
70-
if individuals is not None:
71-
self.sample_ids = []
72-
for i in individuals:
73-
if i < 0 or i >= self.ts.num_individuals:
74-
raise ValueError("Invalid individual IDs provided.")
75-
ind = self.ts.individual(i)
76-
if len(ind.nodes) == 0:
77-
raise ValueError(f"Individual {i} not associated with a node")
78-
is_sample = {ts.node(u).is_sample() for u in ind.nodes}
79-
if len(is_sample) != 1:
80-
raise ValueError(
81-
f"Individual {ind.id} has nodes that are sample and "
82-
"non-samples"
83-
)
84-
self.sample_ids.extend(ind.nodes)
85-
self.individual_ploidies.append(len(ind.nodes))
86-
self.max_ploidy = max(self.max_ploidy, len(ind.nodes))
87-
else:
88-
if ploidy is None:
89-
ploidy = 1
90-
if ploidy < 1:
91-
raise ValueError("Ploidy must be >= 1")
92-
if ts.num_samples % ploidy != 0:
93-
raise ValueError("Sample size must be divisible by ploidy")
94-
self.individual_ploidies = np.full(
95-
ts.num_samples // ploidy, ploidy, dtype=np.int32
96-
)
97-
self.max_ploidy = ploidy
98-
self.sample_ids = (ts.nodes_flags & tskit.NODE_IS_SAMPLE).nonzero()[0]
99-
100-
self._num_samples = len(self.individual_ploidies)
101-
111+
112+
# Use individual_nodes to get the mapping between individuals and nodes
113+
try:
114+
# Get a 2D array of node IDs for each individual
115+
self.node_ids_array = individual_nodes(ts)
116+
self._num_samples = ts.num_individuals
117+
self.max_ploidy = self.node_ids_array.shape[1]
118+
119+
except ValueError as e:
120+
raise ValueError(f"Error mapping individuals to nodes: {e}") from e
121+
102122
self._samples = [vcz.Sample(id=f"tsk_{j}") for j in range(self.num_samples)]
103123

104124
def iter_contig(self, start, stop):
@@ -111,25 +131,11 @@ def iter_field(self, field_name, shape, start, stop):
111131
else:
112132
raise ValueError(f"Unknown field {field_name}")
113133

114-
def iter_alleles(self, start, stop, num_alleles):
115-
for variant in self.ts.variants(
116-
samples=self.sample_ids,
117-
isolated_as_missing=self.isolated_as_missing,
118-
left=self.positions[start],
119-
right=self.positions[stop] if stop < self.num_records else None,
120-
):
121-
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
122-
for i, allele in enumerate(variant.alleles):
123-
assert i < num_alleles
124-
alleles[i] = allele
125-
yield alleles
126-
127134
def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
128-
# In tskit, all genotypes are considered phased
135+
# All genotypes in tskit are considered phased
129136
phased = np.ones(shape[:-1], dtype=bool)
130137

131138
for variant in self.ts.variants(
132-
samples=self.sample_ids,
133139
isolated_as_missing=self.isolated_as_missing,
134140
left=self.positions[start],
135141
right=self.positions[stop] if stop < self.num_records else None,
@@ -143,13 +149,12 @@ def iter_alleles_and_genotypes(self, start, stop, shape, num_alleles):
143149
assert i < num_alleles
144150
alleles[i] = allele
145151

146-
genotypes = variant.genotypes
147-
sample_index = 0
148-
for i, ploidy in enumerate(self.individual_ploidies):
149-
for j in range(ploidy):
150-
if j < self.max_ploidy:
151-
gt[i, j] = genotypes[sample_index + j]
152-
sample_index += ploidy
152+
# For each individual, get genotypes for their nodes
153+
for i in range(self.num_samples):
154+
for j in range(self.max_ploidy):
155+
node_id = self.node_ids_array[i, j]
156+
if node_id >= 0: # Skip -1 entries (unused slots)
157+
gt[i, j] = variant.genotypes[node_id]
153158

154159
yield alleles, (gt, phased)
155160

@@ -250,7 +255,6 @@ def convert(
250255
zarr_path,
251256
*,
252257
contig_id=None,
253-
ploidy=None,
254258
isolated_as_missing=False,
255259
variants_chunk_size=None,
256260
samples_chunk_size=None,
@@ -260,7 +264,6 @@ def convert(
260264
tskit_format = TskitFormat(
261265
ts_path,
262266
contig_id=contig_id,
263-
ploidy=ploidy,
264267
isolated_as_missing=isolated_as_missing,
265268
)
266269
schema_instance = tskit_format.generate_schema(

tests/test_ts.py

Lines changed: 107 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -154,32 +154,12 @@ def test_sample_mapping_with_individuals(self, simple_ts):
154154
assert format_obj.max_ploidy == 2
155155
assert format_obj.individual_ploidies == [2, 2]
156156

157-
# Should raise error if ploidy specified with individuals
158-
with pytest.raises(
159-
ValueError, match="Cannot specify ploidy when individuals are present"
160-
):
161-
ts.TskitFormat(ts_path, ploidy=2)
162-
163-
def test_sample_mapping_without_individuals(self, no_individuals_ts):
164-
ts_path, tree_sequence = no_individuals_ts
165-
166-
# Default ploidy should be 1
167-
format_obj = ts.TskitFormat(ts_path)
168-
assert format_obj.num_samples == 4
169-
assert format_obj.max_ploidy == 1
170-
assert list(format_obj.individual_ploidies) == [1, 1, 1, 1]
171-
172-
# Explicitly set ploidy to 2
173-
format_obj = ts.TskitFormat(ts_path, ploidy=2)
174-
assert format_obj.num_samples == 2
175-
assert format_obj.max_ploidy == 2
176-
assert list(format_obj.individual_ploidies) == [2, 2]
177-
178-
with pytest.raises(ValueError, match="Ploidy must be >= 1"):
179-
ts.TskitFormat(ts_path, ploidy=0)
180-
181-
with pytest.raises(ValueError, match="Sample size must be divisible by ploidy"):
182-
ts.TskitFormat(ts_path, ploidy=3)
157+
def test_no_individuals(self, no_individuals_ts):
158+
"""Test that tree sequences without individuals raise an error."""
159+
ts_path, _ = no_individuals_ts
160+
161+
with pytest.raises(ValueError, match="Tree sequence has no individuals"):
162+
ts.TskitFormat(ts_path)
183163

184164
def test_schema_generation(self, simple_ts):
185165
ts_path, _ = simple_ts
@@ -307,8 +287,7 @@ def test_variable_ploidy(self, tmp_path):
307287
format_obj = ts.TskitFormat(ts_path)
308288

309289
assert format_obj.max_ploidy == 3
310-
assert format_obj.individual_ploidies == [2, 3]
311-
290+
312291
shape = (2, 3) # (num_samples, max_ploidy)
313292
results = list(format_obj.iter_alleles_and_genotypes(0, 2, shape, 2))
314293

@@ -387,3 +366,103 @@ def insert_branch_sites(ts, m=1):
387366
# Individual 2 should have missing values (-1) when isolated_as_missing=True
388367
expected_gt_missing = np.array([[1], [0], [-1]])
389368
assert np.array_equal(gt_missing, expected_gt_missing)
369+
370+
371+
class TestIndividualNodes:
    """Tests for ``individual_nodes``: output layout and validation errors."""

    @staticmethod
    def _build_ts(num_individuals, nodes):
        """
        Build a minimal tree sequence containing only individuals and nodes.

        ``nodes`` is a sequence of ``(flags, individual)`` pairs, one per node
        in node-ID order. Use ``tskit.NULL`` for a node with no individual.
        """
        tables = tskit.TableCollection(sequence_length=100)
        for _ in range(num_individuals):
            tables.individuals.add_row(flags=0, location=(0, 0), metadata=b"")
        for flags, individual in nodes:
            tables.nodes.add_row(flags=flags, time=0, individual=individual)
        return tables.tree_sequence()

    def test_basic_individual_nodes(self):
        # Two diploid individuals
        sample = tskit.NODE_IS_SAMPLE
        tree_sequence = self._build_ts(
            2, [(sample, 0), (sample, 0), (sample, 1), (sample, 1)]
        )

        result = ts.individual_nodes(tree_sequence)
        assert result.shape == (2, 2)
        assert np.array_equal(result, [[0, 1], [2, 3]])

    def test_variable_ploidy(self):
        sample = tskit.NODE_IS_SAMPLE
        tree_sequence = self._build_ts(
            3,
            [
                (sample, 0),  # Diploid individual
                (sample, 0),
                (sample, 1),  # Haploid individual
                (sample, 2),  # Triploid individual
                (sample, 2),
                (sample, 2),
            ],
        )

        result = ts.individual_nodes(tree_sequence)

        assert result.shape == (3, 3)
        expected = np.array(
            [
                [0, 1, -1],  # Diploid
                [2, -1, -1],  # Haploid
                [3, 4, 5],  # Triploid
            ]
        )
        assert np.array_equal(result, expected)

    def test_no_individuals(self):
        tree_sequence = self._build_ts(0, [(tskit.NODE_IS_SAMPLE, tskit.NULL)])

        with pytest.raises(ValueError, match="Tree sequence has no individuals"):
            ts.individual_nodes(tree_sequence)

    def test_no_samples_with_individuals(self):
        # The only sample node has no individual reference
        tree_sequence = self._build_ts(1, [(tskit.NODE_IS_SAMPLE, tskit.NULL)])

        with pytest.raises(ValueError, match="No samples refer to individuals"):
            ts.individual_nodes(tree_sequence)

    def test_mixed_individual_references(self):
        # One sample node with an individual reference and one without
        sample = tskit.NODE_IS_SAMPLE
        tree_sequence = self._build_ts(1, [(sample, 0), (sample, tskit.NULL)])

        with pytest.raises(
            ValueError, match="Sample nodes must all be associated with individuals"
        ):
            ts.individual_nodes(tree_sequence)

    def test_individual_with_no_nodes(self):
        # Two individuals, but only the first one has a node
        tree_sequence = self._build_ts(2, [(tskit.NODE_IS_SAMPLE, 0)])

        with pytest.raises(
            ValueError, match="Individual 1 not associated with any nodes"
        ):
            ts.individual_nodes(tree_sequence)

    def test_mixed_sample_status(self):
        # One individual owning both a sample and a non-sample node
        tree_sequence = self._build_ts(1, [(tskit.NODE_IS_SAMPLE, 0), (0, 0)])

        with pytest.raises(
            ValueError, match="has nodes that are sample and non-samples"
        ):
            ts.individual_nodes(tree_sequence)

0 commit comments

Comments
 (0)