Skip to content

Commit 304a85d

Browse files
committed
Use i8 for positions if needed
1 parent 4d5a9b4 commit 304a85d

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

bio2zarr/tskit.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -152,11 +152,17 @@ def generate_schema(
152 152
f"{schema_instance.dimensions['samples'].chunk_size}"
153 153
)
154 154

155+
# Check if positions will fit in i4 (max ~2.1 billion)
156+
max_position = 0
157+
if self.ts.num_sites > 0:
158+
max_position = np.max(self.ts.sites_position)
159+
position_dtype = "i4" if max_position <= np.iinfo(np.int32).max else "i8"
160+
155 161
array_specs = [
156 162
vcz.ZarrArraySpec(
157 163
source="position",
158 164
name="variant_position",
159-
dtype="i4",
165+
dtype=position_dtype,
160 166
dimensions=["variants"],
161 167
description="Position of each variant",
162 168
),

tests/test_ts.py

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,40 @@ def no_individuals_ts(self, tmp_path):
115 115
tree_sequence.dump(ts_path)
116 116
return ts_path, tree_sequence
117 117

118+
def test_position_dtype_selection(self, tmp_path):
119+
tables = tskit.TableCollection(sequence_length=100)
120+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
121+
tables.sites.add_row(position=10, ancestral_state="A")
122+
tables.sites.add_row(position=20, ancestral_state="C")
123+
ts_small = tables.tree_sequence()
124+
ts_path_small = tmp_path / "small_positions.trees"
125+
ts_small.dump(ts_path_small)
126+
127+
tables = tskit.TableCollection(sequence_length=3_000_000_000)
128+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
129+
tables.sites.add_row(position=10, ancestral_state="A")
130+
tables.sites.add_row(position=np.iinfo(np.int32).max + 1, ancestral_state="C")
131+
ts_large = tables.tree_sequence()
132+
ts_path_large = tmp_path / "large_positions.trees"
133+
ts_large.dump(ts_path_large)
134+
135+
ind_nodes = np.array([[0], [1]])
136+
format_obj_small = ts.TskitFormat(ts_path_small, ind_nodes)
137+
schema_small = format_obj_small.generate_schema()
138+
139+
position_field = next(
140+
f for f in schema_small.fields if f.name == "variant_position"
141+
)
142+
assert position_field.dtype == "i4"
143+
144+
format_obj_large = ts.TskitFormat(ts_path_large, ind_nodes)
145+
schema_large = format_obj_large.generate_schema()
146+
147+
position_field = next(
148+
f for f in schema_large.fields if f.name == "variant_position"
149+
)
150+
assert position_field.dtype == "i8"
151+
118 152
def test_initialization(self, simple_ts):
119 153
ts_path, tree_sequence = simple_ts
120 154
ind_nodes = np.array([[0, 1], [2, 3]])

0 commit comments

Comments (0)