Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 65 additions & 87 deletions tests/test_variantdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,135 +834,118 @@ def test_ancestral_missing_info(tmp_path, caplog):
assert inf_var.site.ancestral_state == var.site.ancestral_state


@pytest.mark.skip("Depends on sgkit methods")
class TestVariantDataErrors:
@staticmethod
def simulate_genotype_call_dataset(*args, **kwargs):
# roll our own simulate_genotype_call_dataset to hack around bug in sgkit where
# duplicate alleles are created. Doesn't need to be efficient: just for testing
if "seed" not in kwargs:
kwargs["seed"] = 123
ds = sgkit.simulate_genotype_call_dataset(*args, **kwargs) ## NOQA
variant_alleles = ds["variant_allele"].values
allowed_alleles = np.array(
["A", "T", "C", "G", "N"], dtype=variant_alleles.dtype
def simulate_vcz(path, n_sample, n_variant, n_ploidy=2):
# use b2z to create a zarr dataset we can modify
ts = msprime.sim_ancestry(
n_sample, sequence_length=1000, ploidy=n_ploidy, random_seed=123
)
for row in range(len(variant_alleles)):
alleles = variant_alleles[row]
if len(set(alleles)) != len(alleles):
# Just use a set that we know is unique
variant_alleles[row] = allowed_alleles[0 : len(alleles)]
ds["variant_allele"] = ds["variant_allele"].dims, variant_alleles
return ds
# Add enough mutations that we get at least n_variant sites
ts = msprime.sim_mutations(ts, rate=5e-3, random_seed=123)
assert ts.num_sites >= n_variant
# Filter down to the first n_variant sites
ts = ts.keep_intervals([(0, ts.site(n_variant - 1).position + 1)]).trim()
assert ts.num_sites == n_variant
ts2z.convert(ts, path)
z = zarr.open(path, mode="a")
z.create_dataset(
name="variant_ancestral_allele",
data=ts.sites_ancestral_state,
dtype="S1",
chunks=True,
)
return z

def test_phased(self, tmp_path):
# Check the default runs without error
path = tmp_path / "data.vcz"
self.simulate_vcz(path, n_variant=3, n_sample=3)
vdata = tsinfer.VariantData(path, "variant_ancestral_allele")
assert vdata.ploidy == 2

def test_bad_zarr_spec(self):
ds = zarr.group()
ds["call_genotype"] = zarr.array(np.zeros(10, dtype=np.int8))
z = zarr.group()
z["call_genotype"] = zarr.array(np.zeros(10, dtype=np.int8))
with pytest.raises(
ValueError, match="Expecting a VCF Zarr object with 3D call_genotype array"
):
tsinfer.VariantData(ds, np.zeros(10, dtype="<U1"))
tsinfer.VariantData(z, np.zeros(10, dtype="<U1"))

def test_missing_phase(self, tmp_path):
path = tmp_path / "data.zarr"
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
tsutil.save_dataset(ds, path)
path = tmp_path / "data.vcz"
z = self.simulate_vcz(path, n_variant=3, n_sample=3)
del z["call_genotype_phased"]
with pytest.raises(
ValueError, match="The call_genotype_phased array is missing"
):
tsinfer.VariantData(path, "variant_ancestral_allele")

def test_phased(self, tmp_path):
path = tmp_path / "data.zarr"
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
ds["call_genotype_phased"] = (
ds["call_genotype"].dims,
np.ones(ds["call_genotype"].shape, dtype=bool),
)
tsutil.save_dataset(ds, path)
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))

def test_ploidy1_missing_phase(self, tmp_path):
path = tmp_path / "data.zarr"
# Ploidy==1 is always ok
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
tsutil.save_dataset(ds, path)
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))
z = self.simulate_vcz(path, n_variant=3, n_sample=3, n_ploidy=1)
del z["call_genotype_phased"]
tsinfer.VariantData(path, "variant_ancestral_allele")

def test_ploidy1_unphased(self, tmp_path):
path = tmp_path / "data.zarr"
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds["call_genotype_phased"] = (
ds["call_genotype"].dims,
np.zeros(ds["call_genotype"].shape, dtype=bool),
)
tsutil.save_dataset(ds, path)
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))
z = self.simulate_vcz(path, n_variant=3, n_sample=3, n_ploidy=1)
z["call_genotype_phased"][:] = False
tsinfer.VariantData(path, "variant_ancestral_allele")

def test_duplicate_positions(self, tmp_path):
path = tmp_path / "data.zarr"
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
ds["variant_position"][2] = ds["variant_position"][1]
tsutil.save_dataset(ds, path)
z = self.simulate_vcz(path, n_variant=3, n_sample=3)
z["variant_position"][2] = z["variant_position"][1]
with pytest.raises(ValueError, match="duplicate or out-of-order values"):
tsinfer.VariantData(path, "variant_ancestral_allele")

def test_bad_order_positions(self, tmp_path):
path = tmp_path / "data.zarr"
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
ds["variant_position"][0] = ds["variant_position"][2] - 0.5
tsutil.save_dataset(ds, path)
z = self.simulate_vcz(path, n_variant=3, n_sample=3)
z["variant_position"][0] = z["variant_position"][2] - 0.5
with pytest.raises(ValueError, match="duplicate or out-of-order values"):
tsinfer.VariantData(path, "variant_ancestral_allele")

def test_bad_ancestral_state(self, tmp_path):
path = tmp_path / "data.zarr"
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
ancestral_state = ds["variant_allele"][:, 0].values.astype(str)
ancestral_state[1] = ""
tsutil.save_dataset(ds, path)
z = self.simulate_vcz(path, n_variant=3, n_sample=3)
z["variant_ancestral_allele"][1] = ""
with pytest.raises(ValueError, match="cannot contain empty strings"):
tsinfer.VariantData(path, ancestral_state)
tsinfer.VariantData(path, "variant_ancestral_allele")

def test_ancestral_state_len_not_same_as_mask(self, tmp_path):
path = tmp_path / "data.zarr"
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
tsutil.save_dataset(ds, path)
ancestral_state = ds["variant_allele"][:, 0].values.astype(str)
site_mask = np.zeros(ds.sizes["variants"], dtype=bool)
z = self.simulate_vcz(path, n_variant=3, n_sample=3)
site_mask = np.zeros(z["variant_ancestral_allele"].shape[0], dtype=bool)
site_mask[0] = True
with pytest.raises(
ValueError,
match="Ancestral state array must be the same length as the number of"
" selected sites",
):
tsinfer.VariantData(path, ancestral_state, site_mask=site_mask)
tsinfer.VariantData(
path, z["variant_ancestral_allele"][:], site_mask=site_mask
)

def test_empty_alleles_not_at_end(self, tmp_path):
path = tmp_path / "data.zarr"
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds["variant_allele"] = (
ds["variant_allele"].dims,
np.array([["A", "", "C"], ["A", "C", ""], ["A", "C", ""]], dtype="S1"),
)
tsutil.save_dataset(ds, path)
z = self.simulate_vcz(path, n_variant=3, n_sample=3)
z["variant_allele"][0] = ["", "T"]
with pytest.raises(
ValueError, match='Bad alleles: fill value "" in middle of list'
):
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))
tsinfer.VariantData(path, "variant_ancestral_allele")

def test_unique_alleles(self, tmp_path):
path = tmp_path / "data.zarr"
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
ds["variant_allele"] = (
ds["variant_allele"].dims,
np.array([["A", "C", "T"], ["A", "C", ""], ["A", "A", ""]], dtype="S1"),
)
tsutil.save_dataset(ds, path)
z = self.simulate_vcz(path, n_variant=3, n_sample=3)
z["variant_allele"][2] = ["T", "T"]
with pytest.raises(
ValueError, match="Duplicate allele values provided at site 2"
):
tsinfer.VariantData(path, np.array(["A", "A", "A"], dtype="S1"))
tsinfer.VariantData(path, "variant_ancestral_allele")

def test_unimplemented_from_tree_sequence(self):
# NB we should reimplement something like this functionality.
Expand All @@ -972,28 +955,23 @@ def test_unimplemented_from_tree_sequence(self):

def test_all_masked(self, tmp_path):
path = tmp_path / "data.zarr"
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
tsutil.save_dataset(ds, path)
self.simulate_vcz(path, n_variant=3, n_sample=3)
with pytest.raises(ValueError, match="All sites have been masked out"):
tsinfer.VariantData(
path, ds["variant_allele"][:, 0].astype(str), site_mask=np.ones(3, bool)
path, "variant_ancestral_allele", site_mask=np.ones(3, bool)
)

def test_missing_sites_time(self, tmp_path):
path = tmp_path / "data.zarr"
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
tsutil.save_dataset(ds, path)
self.simulate_vcz(path, n_variant=3, n_sample=3)
with pytest.raises(
ValueError, match="The sites time array XX was not found in the dataset"
):
tsinfer.VariantData(
path, ds["variant_allele"][:, 0].astype(str), sites_time="XX"
)
tsinfer.VariantData(path, "variant_ancestral_allele", sites_time="XX")

def test_wrong_individuals_array_length(self, tmp_path):
path = tmp_path / "data.zarr"
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=5, phased=True)
tsutil.save_dataset(ds, path)
self.simulate_vcz(path, n_variant=3, n_sample=5)

# Create arrays with wrong length (too short)
wrong_length_time = np.arange(3, dtype=np.float32)
Expand All @@ -1007,7 +985,7 @@ def test_wrong_individuals_array_length(self, tmp_path):
):
tsinfer.VariantData(
path,
ds["variant_allele"][:, 0].values.astype(str),
"variant_ancestral_allele",
individuals_time=wrong_length_time,
)

Expand All @@ -1016,7 +994,7 @@ def test_wrong_individuals_array_length(self, tmp_path):
):
tsinfer.VariantData(
path,
ds["variant_allele"][:, 0].values.astype(str),
"variant_ancestral_allele",
individuals_location=wrong_length_location,
)

Expand All @@ -1025,7 +1003,7 @@ def test_wrong_individuals_array_length(self, tmp_path):
):
tsinfer.VariantData(
path,
ds["variant_allele"][:, 0].values.astype(str),
"variant_ancestral_allele",
individuals_population=wrong_length_population,
)

Expand All @@ -1034,7 +1012,7 @@ def test_wrong_individuals_array_length(self, tmp_path):
):
tsinfer.VariantData(
path,
ds["variant_allele"][:, 0].values.astype(str),
"variant_ancestral_allele",
individuals_flags=wrong_length_flags,
)

Expand Down