Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ dependencies = [
"humanize",
"daiquiri",
"tskit>=1.0.0",
"zarr<3",
"zarr>=3,<4",
"lmdb",
"sortedcontainers",
"numba",
Expand Down
90 changes: 43 additions & 47 deletions tests/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

import lmdb
import msprime
import numcodecs
import numcodecs.blosc as blosc
import numpy as np
import pytest
Expand Down Expand Up @@ -2245,8 +2244,10 @@ def test_zero_sequence_length(self):
with tsinfer.SampleData(path=filename) as sample_data:
for var in ts.variants():
sample_data.add_site(var.site.position, var.genotypes)
store = zarr.LMDBStore(filename, subdir=False)
data = zarr.open(store=store, mode="w+")
from tsinfer._lmdb_store import LMDBStore

store = LMDBStore(filename, subdir=False)
data = zarr.open(store=store, mode="r+")
data.attrs["sequence_length"] = 0
store.close()
sample_data = tsinfer.load(filename)
Expand Down Expand Up @@ -2572,7 +2573,7 @@ def verify_round_trip(self, source):
dest = {}
num_rows = -1
for key, array in source.items():
dest[key] = zarr.empty_like(array)
dest[key] = zarr.empty_like(array, zarr_format=2)
if num_rows == -1:
num_rows = array.shape[0]
assert num_rows == array.shape[0]
Expand All @@ -2585,32 +2586,27 @@ def verify_round_trip(self, source):

for key, source_array in source.items():
dest_array = dest[key]
if source_array.dtype.str == "|O":
# Object arrays have to be treated differently.
assert source_array.shape == dest_array.shape
for a, b in zip(source_array, dest_array):
if isinstance(a, np.ndarray):
assert np.array_equal(a, b)
else:
assert a == b
else:
assert np.array_equal(source_array[:], dest_array[:])
assert np.array_equal(source_array[:], dest_array[:])
assert source_array.chunks == dest_array.chunks
return dest

def test_one_array(self):
self.verify_round_trip({"a": zarr.ones(10)})
self.verify_round_trip({"a": zarr.ones(10, zarr_format=2)})

def test_two_arrays(self):
self.verify_round_trip({"a": zarr.ones(10), "b": zarr.zeros(10)})
self.verify_round_trip(
{"a": zarr.ones(10, zarr_format=2), "b": zarr.zeros(10, zarr_format=2)}
)

def verify_dtypes(self, chunk_size=None):
n = 100
if chunk_size is None:
chunk_size = 100
dtypes = [np.int8, np.uint8, np.int32, np.uint32, np.float64, np.float32]
source = {
str(dtype): zarr.array(np.arange(n, dtype=dtype), chunks=(chunk_size,))
str(dtype): zarr.array(
np.arange(n, dtype=dtype), chunks=(chunk_size,), zarr_format=2
)
for dtype in dtypes
}
dest = self.verify_round_trip(source)
Expand All @@ -2634,72 +2630,72 @@ def test_mixed_dtypes_chunk_size_10000(self):
self.verify_dtypes(10000)

def test_2d_array(self):
a = zarr.array(np.arange(100).reshape((10, 10)))
a = zarr.array(np.arange(100).reshape((10, 10)), zarr_format=2)
self.verify_round_trip({"a": a})

def test_2d_array_chunk_size_1_1(self):
a = zarr.array(np.arange(100).reshape((10, 10)), chunks=(1, 1))
a = zarr.array(np.arange(100).reshape((10, 10)), chunks=(1, 1), zarr_format=2)
self.verify_round_trip({"a": a})

def test_2d_array_chunk_size_1_2(self):
a = zarr.array(np.arange(100).reshape((10, 10)), chunks=(1, 2))
a = zarr.array(np.arange(100).reshape((10, 10)), chunks=(1, 2), zarr_format=2)
self.verify_round_trip({"a": a})

def test_2d_array_chunk_size_2_1(self):
a = zarr.array(np.arange(100).reshape((10, 10)), chunks=(1, 2))
a = zarr.array(np.arange(100).reshape((10, 10)), chunks=(1, 2), zarr_format=2)
self.verify_round_trip({"a": a})

def test_2d_array_chunk_size_1_100(self):
a = zarr.array(np.arange(100).reshape((10, 10)), chunks=(1, 100))
a = zarr.array(np.arange(100).reshape((10, 10)), chunks=(1, 100), zarr_format=2)
self.verify_round_trip({"a": a})

def test_2d_array_chunk_size_100_1(self):
a = zarr.array(np.arange(100).reshape((10, 10)), chunks=(100, 1))
a = zarr.array(np.arange(100).reshape((10, 10)), chunks=(100, 1), zarr_format=2)
self.verify_round_trip({"a": a})

def test_2d_array_chunk_size_10_10(self):
a = zarr.array(np.arange(100).reshape((10, 10)), chunks=(5, 10))
a = zarr.array(np.arange(100).reshape((10, 10)), chunks=(5, 10), zarr_format=2)
self.verify_round_trip({"a": a})

def test_3d_array(self):
a = zarr.array(np.arange(27).reshape((3, 3, 3)))
a = zarr.array(np.arange(27).reshape((3, 3, 3)), zarr_format=2)
self.verify_round_trip({"a": a})

def test_3d_array_chunks_size_1_1_1(self):
a = zarr.array(np.arange(27).reshape((3, 3, 3)), chunks=(1, 1, 1))
a = zarr.array(
np.arange(27).reshape((3, 3, 3)), chunks=(1, 1, 1), zarr_format=2
)
self.verify_round_trip({"a": a})

def test_ragged_array_int32(self):
def test_json_encoded_string_array(self):
# Replaces old dtype="array:i4" test: tsinfer now JSON-encodes ragged
# int arrays as variable-length strings.
n = 10
z = zarr.empty(n, dtype="array:i4")
z = zarr.empty(n, dtype=str, zarr_format=2)
for j in range(n):
z[j] = np.arange(j)
self.filter_warnings_verify_round_trip({"z": z})
z[j] = json.dumps(list(range(j)))
self.verify_round_trip({"z": z})

def test_square_object_array_int32(self):
n = 10
z = zarr.empty(n, dtype="array:i4")
for j in range(n):
z[j] = np.arange(n)
self.filter_warnings_verify_round_trip({"z": z})

def test_json_object_array(self):
def test_json_encoded_dict_array(self):
# Replaces old dtype=object/object_codec=JSON() test: tsinfer now
# JSON-encodes metadata dicts as variable-length strings.
for chunks in [2, 5, 10, 100]:
n = 10
z = zarr.empty(
n, dtype=object, object_codec=numcodecs.JSON(), chunks=(chunks,)
)
z = zarr.empty(n, dtype=str, chunks=(chunks,), zarr_format=2)
for j in range(n):
z[j] = {str(k): k for k in range(j)}
self.filter_warnings_verify_round_trip({"z": z})
z[j] = json.dumps({str(k): k for k in range(j)})
self.verify_round_trip({"z": z})

def test_empty_string_list(self):
z = zarr.empty(1, dtype=object, object_codec=numcodecs.JSON(), chunks=(2,))
z[0] = ["", ""]
self.filter_warnings_verify_round_trip({"z": z})
z = zarr.empty(1, dtype=str, chunks=(2,), zarr_format=2)
z[0] = json.dumps(["", ""])
self.verify_round_trip({"z": z})

def test_mixed_chunk_sizes(self):
source = {"a": zarr.zeros(10, chunks=(1,)), "b": zarr.zeros(10, chunks=(2,))}
source = {
"a": zarr.zeros(10, chunks=(1,), zarr_format=2),
"b": zarr.zeros(10, chunks=(2,), zarr_format=2),
}
with pytest.raises(ValueError):
formats.BufferedItemWriter(source)

Expand Down
42 changes: 26 additions & 16 deletions tests/test_variantdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

import bio2zarr.tskit as ts2z
import msprime
import numcodecs
import numpy as np
import pytest
import tskit
Expand Down Expand Up @@ -83,13 +82,16 @@ def test_sgkit_dataset_roundtrip(tmp_path):
def test_sgkit_individual_metadata_not_clobbered(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
# Load the zarr to add metadata for testing
zarr_root = zarr.open(zarr_path)
zarr_root = zarr.open(zarr_path, mode="r+")
empty_obj = json.dumps({}).encode()
indiv_metadata = np.array([empty_obj] * ts.num_individuals, dtype=object)
indiv_metadata = [empty_obj] * ts.num_individuals
indiv_metadata[42] = json.dumps({"variant_data_sample_id": "foobar"}).encode()
zarr_root.create_dataset(
"individuals_metadata", data=indiv_metadata, object_codec=numcodecs.VLenBytes()
arr = zarr_root.create_array(
"individuals_metadata",
shape=(len(indiv_metadata),),
dtype=bytes,
)
arr[:] = indiv_metadata
zarr_root.attrs["individuals_metadata_schema"] = repr(
tskit.MetadataSchema.permissive_json()
)
Expand Down Expand Up @@ -856,8 +858,10 @@ def simulate_genotype_call_dataset(*args, **kwargs):
return ds

def test_bad_zarr_spec(self):
ds = zarr.group()
ds["call_genotype"] = zarr.array(np.zeros(10, dtype=np.int8))
ds = zarr.group(zarr_format=2)
ds["call_genotype"] = zarr.array(
np.zeros(10, dtype=np.int8), zarr_format=2
)
with pytest.raises(
ValueError, match="Expecting a VCF Zarr object with 3D call_genotype array"
):
Expand Down Expand Up @@ -1041,9 +1045,18 @@ def test_wrong_individuals_array_length(self, tmp_path):

@pytest.mark.skipif(sys.platform == "win32", reason="File permission errors on Windows")
class TestAddAncestralStateArray:
@staticmethod
def _make_store(tmp_path, positions):
    """Build a zarr-v2 group at ``tmp_path/test.zarr`` containing a single
    float64 ``variant_position`` array populated with *positions*.

    Shared fixture helper for the TestAddAncestralStateArray cases, which
    all need a store with a ``variant_position`` array of varying length.
    """
    group = zarr.group(store=str(tmp_path / "test.zarr"), zarr_format=2)
    values = np.asarray(positions, dtype=np.float64)
    # NOTE(review): zarr_format=2 is forwarded to Group.create_array here —
    # in zarr v3 the group's format should already apply; verify the kwarg
    # is actually accepted by this API.
    variant_position = group.create_array(
        "variant_position", shape=values.shape, dtype=np.float64, zarr_format=2
    )
    variant_position[:] = values
    return group

def test_add_ancestral_state_array(self, tmp_path):
store = zarr.group(store=str(tmp_path / "test.zarr"))
store.create_dataset("variant_position", data=[10, 20, 30, 40, 50])
store = self._make_store(tmp_path, [10, 20, 30, 40, 50])
array = formats.add_ancestral_state_array(store, "A" * 60)

assert "ancestral_state" in store
Expand All @@ -1055,30 +1068,27 @@ def test_add_ancestral_state_array(self, tmp_path):
assert "custom_ancestral" in store

def test_mixed_case_and_different_nucleotides(self, tmp_path):
store = zarr.group(store=str(tmp_path / "test.zarr"))
store.create_dataset("variant_position", data=[10, 20, 30, 40, 50])
store = self._make_store(tmp_path, [10, 20, 30, 40, 50])
array = formats.add_ancestral_state_array(
store, "A" * 10 + "c" + "G" * 9 + "t" + "C" * 9 + "a" + "T" * 19 + "g"
)
np.testing.assert_array_equal(array[:], np.array(["C", "T", "A", "T", "G"]))

def test_error_no_variant_position(self, tmp_path):
store = zarr.group(store=str(tmp_path / "test.zarr"))
store = zarr.group(store=str(tmp_path / "test.zarr"), zarr_format=2)
with pytest.raises(ValueError, match="must contain a 'variant_position' array"):
formats.add_ancestral_state_array(store, "A")

def test_error_fasta_too_short(self, tmp_path):
store = zarr.group(store=str(tmp_path / "test.zarr"))
store.create_dataset("variant_position", data=[10, 20, 100])
store = self._make_store(tmp_path, [10, 20, 100])
fasta_string = "A" * 50 # Only 50 bases, not enough for position 100
with pytest.raises(
ValueError, match="length of the fasta string must be at least"
):
formats.add_ancestral_state_array(store, fasta_string)

def test_empty_positions_array(self, tmp_path):
store = zarr.group(store=str(tmp_path / "test.zarr"))
store.create_dataset("variant_position", data=[])
store = self._make_store(tmp_path, [])
with pytest.raises(
ValueError,
match="variant_position array must contain at least one position",
Expand Down
7 changes: 5 additions & 2 deletions tests/tsutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def add_attribute_to_dataset(name, contents, zarr_path):

def make_ts_and_zarr(path=None, prefix="data", add_optional=False, shuffle_alleles=True):
if path is None:
in_mem_copy = zarr.group()
in_mem_copy = zarr.group(zarr_format=2)
with tempfile.TemporaryDirectory() as path:
ts, zarr_path = _make_ts_and_zarr(
Path(path),
Expand All @@ -276,7 +276,10 @@ def make_ts_and_zarr(path=None, prefix="data", add_optional=False, shuffle_allel
shuffle_alleles=shuffle_alleles,
)
# For testing only, return an in-memory copy of the dataset we just made
zarr.convenience.copy_all(zarr.open(zarr_path), in_mem_copy)
# zarr.convenience.copy_all removed in zarr v3; use our helper instead.
from tsinfer.formats import _copy_zarr_group

_copy_zarr_group(zarr.open(zarr_path, mode="r"), in_mem_copy)
return ts, in_mem_copy
else:
return _make_ts_and_zarr(
Expand Down
4 changes: 2 additions & 2 deletions tsinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@
try:
import zarr

if zarr.__version__ >= "3":
if zarr.__version__ >= "4":
raise RuntimeError(
f"zarr version {zarr.__version__} is not supported. "
"tsinfer requires zarr < 3.0. Please install zarr < 3.0."
"tsinfer requires zarr >= 3, < 4. Please install zarr >= 3, < 4."
)
except ImportError:
pass
Expand Down
Loading
Loading