Skip to content

Commit 6eb4caf

Browse files
committed
Dynamically pick call_genotype dtype
1 parent 304a85d commit 6eb4caf

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

bio2zarr/tskit.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -191,7 +191,7 @@ def generate_schema(
191191
vcz.ZarrArraySpec(
192192
source=None,
193193
name="call_genotype",
194-
dtype="i1",
194+
dtype=core.min_int_dtype(constants.INT_FILL, max_alleles - 1),
195195
dimensions=["variants", "samples", "ploidy"],
196196
description="Genotype for each variant and sample",
197197
compressor=vcz.DEFAULT_ZARR_COMPRESSOR_GENOTYPES.get_config(),

tests/test_ts.py

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -395,3 +395,43 @@ def insert_branch_sites(ts, m=1):
395395
# Individual 2 should have missing values (-1) when isolated_as_missing=True
396396
expected_gt_missing = np.array([[1], [0], [-1]])
397397
assert np.array_equal(gt_missing, expected_gt_missing)
398+
399+
def test_genotype_dtype_selection(self, tmp_path):
    """Check that the ``call_genotype`` dtype tracks the number of alleles.

    Two tree sequences with identical topology (four sample nodes under one
    parent, a single site at position 10) are written to ``tmp_path``:

    * one derived state at the site -> schema dtype must be ``"i1"``;
    * 32768 distinct derived states at the site -> schema dtype must be
      ``"i4"``.
    """

    def write_tree_sequence(derived_states, filename):
        # Build the shared four-sample topology, attach one mutation on
        # node 0 per derived state, and dump the result under tmp_path.
        collection = tskit.TableCollection(sequence_length=100)
        for _ in range(4):
            collection.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
        parent_node = collection.nodes.add_row(flags=0, time=1)
        for child_node in range(4):
            collection.edges.add_row(
                left=0, right=100, parent=parent_node, child=child_node
            )
        site = collection.sites.add_row(position=10, ancestral_state="A")
        for state in derived_states:
            collection.mutations.add_row(site=site, node=0, derived_state=state)
        collection.sort()
        destination = tmp_path / filename
        collection.tree_sequence().dump(destination)
        return destination

    def call_genotype_dtype(path, individual_nodes):
        # Generate the schema for the dumped file and pull out the dtype of
        # the field named "call_genotype".
        schema = ts.TskitFormat(path, individual_nodes).generate_schema()
        spec = next(
            field for field in schema.fields if field.name == "call_genotype"
        )
        return spec.dtype

    # Two individuals, each made up of two of the four sample nodes.
    ind_nodes = np.array([[0, 1], [2, 3]])

    small_path = write_tree_sequence(["T"], "small_alleles.trees")
    assert call_genotype_dtype(small_path, ind_nodes) == "i1"

    # NOTE(review): presumably 32768 derived states push the largest allele
    # index past the int16 maximum (32767), which is why "i4" is expected -
    # confirm against core.min_int_dtype.
    large_path = write_tree_sequence(
        (f"ALLELE_{i}" for i in range(32768)), "large_alleles.trees"
    )
    assert call_genotype_dtype(large_path, ind_nodes) == "i4"

0 commit comments

Comments
 (0)