Skip to content

Commit 661bd48

Browse files
committed
Reinstate VariantData error checks
Using tskit2zarr rather than sgkit
1 parent a72866c commit 661bd48

File tree

1 file changed: 65 additions and 87 deletions

tests/test_variantdata.py

Lines changed: 65 additions & 87 deletions
Original file line number | Diff line number | Diff line change
@@ -834,135 +834,118 @@ def test_ancestral_missing_info(tmp_path, caplog):
834834
assert inf_var.site.ancestral_state == var.site.ancestral_state
835835

836836

837-
@pytest.mark.skip("Depends on sgkit methods")
838837
class TestVariantDataErrors:
839838
@staticmethod
840-
def simulate_genotype_call_dataset(*args, **kwargs):
841-
# roll our own simulate_genotype_call_dataset to hack around bug in sgkit where
842-
# duplicate alleles are created. Doesn't need to be efficient: just for testing
843-
if "seed" not in kwargs:
844-
kwargs["seed"] = 123
845-
ds = sgkit.simulate_genotype_call_dataset(*args, **kwargs) ## NOQA
846-
variant_alleles = ds["variant_allele"].values
847-
allowed_alleles = np.array(
848-
["A", "T", "C", "G", "N"], dtype=variant_alleles.dtype
839+
def simulate_vcz(path, n_sample, n_variant, n_ploidy=2):
840+
# use b2z (via ts2z.convert) to turn a simulated tree sequence into a zarr dataset we can modify
841+
ts = msprime.sim_ancestry(
842+
n_sample, sequence_length=1000, ploidy=n_ploidy, random_seed=123
849843
)
850-
for row in range(len(variant_alleles)):
851-
alleles = variant_alleles[row]
852-
if len(set(alleles)) != len(alleles):
853-
# Just use a set that we know is unique
854-
variant_alleles[row] = allowed_alleles[0 : len(alleles)]
855-
ds["variant_allele"] = ds["variant_allele"].dims, variant_alleles
856-
return ds
844+
# Add enough mutations that we get at least n_variant sites
845+
ts = msprime.sim_mutations(ts, rate=1e-3, random_seed=123)
846+
assert ts.num_sites >= n_variant
847+
# Filter down to the first n_variant sites
848+
ts = ts.keep_intervals([(0, ts.site(n_variant - 1).position + 1)]).trim()
849+
assert ts.num_sites == n_variant
850+
ts2z.convert(ts, path)
851+
z = zarr.open(path, mode="a")
852+
z.create_dataset(
853+
name="variant_ancestral_allele",
854+
data=ts.sites_ancestral_state,
855+
dtype="S1",
856+
chunks=True,
857+
)
858+
return z
859+
860+
def test_phased(self, tmp_path):
861+
# Check the default runs without error
862+
path = tmp_path / "data.vcz"
863+
self.simulate_vcz(path, n_variant=3, n_sample=3)
864+
vdata = tsinfer.VariantData(path, "variant_ancestral_allele")
865+
assert vdata.ploidy == 2
857866

858867
def test_bad_zarr_spec(self):
859-
ds = zarr.group()
860-
ds["call_genotype"] = zarr.array(np.zeros(10, dtype=np.int8))
868+
z = zarr.group()
869+
z["call_genotype"] = zarr.array(np.zeros(10, dtype=np.int8))
861870
with pytest.raises(
862871
ValueError, match="Expecting a VCF Zarr object with 3D call_genotype array"
863872
):
864-
tsinfer.VariantData(ds, np.zeros(10, dtype="<U1"))
873+
tsinfer.VariantData(z, np.zeros(10, dtype="<U1"))
865874

866875
def test_missing_phase(self, tmp_path):
867-
path = tmp_path / "data.zarr"
868-
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
869-
tsutil.save_dataset(ds, path)
876+
path = tmp_path / "data.vcz"
877+
z = self.simulate_vcz(path, n_variant=3, n_sample=3)
878+
del z["call_genotype_phased"]
870879
with pytest.raises(
871880
ValueError, match="The call_genotype_phased array is missing"
872881
):
873882
tsinfer.VariantData(path, "variant_ancestral_allele")
874883

875-
def test_phased(self, tmp_path):
876-
path = tmp_path / "data.zarr"
877-
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3)
878-
ds["call_genotype_phased"] = (
879-
ds["call_genotype"].dims,
880-
np.ones(ds["call_genotype"].shape, dtype=bool),
881-
)
882-
tsutil.save_dataset(ds, path)
883-
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))
884-
885884
def test_ploidy1_missing_phase(self, tmp_path):
886885
path = tmp_path / "data.zarr"
887886
# Ploidy==1 is always ok
888-
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
889-
tsutil.save_dataset(ds, path)
890-
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))
887+
z = self.simulate_vcz(path, n_variant=3, n_sample=3, n_ploidy=1)
888+
del z["call_genotype_phased"]
889+
tsinfer.VariantData(path, "variant_ancestral_allele")
891890

892891
def test_ploidy1_unphased(self, tmp_path):
893892
path = tmp_path / "data.zarr"
894-
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
895-
ds["call_genotype_phased"] = (
896-
ds["call_genotype"].dims,
897-
np.zeros(ds["call_genotype"].shape, dtype=bool),
898-
)
899-
tsutil.save_dataset(ds, path)
900-
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))
893+
z = self.simulate_vcz(path, n_variant=3, n_sample=3, n_ploidy=1)
894+
z["call_genotype_phased"][:] = False
895+
tsinfer.VariantData(path, "variant_ancestral_allele")
901896

902897
def test_duplicate_positions(self, tmp_path):
903898
path = tmp_path / "data.zarr"
904-
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
905-
ds["variant_position"][2] = ds["variant_position"][1]
906-
tsutil.save_dataset(ds, path)
899+
z = self.simulate_vcz(path, n_variant=3, n_sample=3)
900+
z["variant_position"][2] = z["variant_position"][1]
907901
with pytest.raises(ValueError, match="duplicate or out-of-order values"):
908902
tsinfer.VariantData(path, "variant_ancestral_allele")
909903

910904
def test_bad_order_positions(self, tmp_path):
911905
path = tmp_path / "data.zarr"
912-
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
913-
ds["variant_position"][0] = ds["variant_position"][2] - 0.5
914-
tsutil.save_dataset(ds, path)
906+
z = self.simulate_vcz(path, n_variant=3, n_sample=3)
907+
z["variant_position"][0] = z["variant_position"][2] - 0.5
915908
with pytest.raises(ValueError, match="duplicate or out-of-order values"):
916909
tsinfer.VariantData(path, "variant_ancestral_allele")
917910

918911
def test_bad_ancestral_state(self, tmp_path):
919912
path = tmp_path / "data.zarr"
920-
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
921-
ancestral_state = ds["variant_allele"][:, 0].values.astype(str)
922-
ancestral_state[1] = ""
923-
tsutil.save_dataset(ds, path)
913+
z = self.simulate_vcz(path, n_variant=3, n_sample=3)
914+
z["variant_ancestral_allele"][1] = ""
924915
with pytest.raises(ValueError, match="cannot contain empty strings"):
925-
tsinfer.VariantData(path, ancestral_state)
916+
tsinfer.VariantData(path, "variant_ancestral_allele")
926917

927918
def test_ancestral_state_len_not_same_as_mask(self, tmp_path):
928919
path = tmp_path / "data.zarr"
929-
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
930-
tsutil.save_dataset(ds, path)
931-
ancestral_state = ds["variant_allele"][:, 0].values.astype(str)
932-
site_mask = np.zeros(ds.sizes["variants"], dtype=bool)
920+
z = self.simulate_vcz(path, n_variant=3, n_sample=3)
921+
site_mask = np.zeros(z["variant_ancestral_allele"].shape[0], dtype=bool)
933922
site_mask[0] = True
934923
with pytest.raises(
935924
ValueError,
936925
match="Ancestral state array must be the same length as the number of"
937926
" selected sites",
938927
):
939-
tsinfer.VariantData(path, ancestral_state, site_mask=site_mask)
928+
tsinfer.VariantData(
929+
path, z["variant_ancestral_allele"][:], site_mask=site_mask
930+
)
940931

941932
def test_empty_alleles_not_at_end(self, tmp_path):
942933
path = tmp_path / "data.zarr"
943-
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
944-
ds["variant_allele"] = (
945-
ds["variant_allele"].dims,
946-
np.array([["A", "", "C"], ["A", "C", ""], ["A", "C", ""]], dtype="S1"),
947-
)
948-
tsutil.save_dataset(ds, path)
934+
z = self.simulate_vcz(path, n_variant=3, n_sample=3)
935+
z["variant_allele"][0] = ["", "T"]
949936
with pytest.raises(
950937
ValueError, match='Bad alleles: fill value "" in middle of list'
951938
):
952-
tsinfer.VariantData(path, ds["variant_allele"][:, 0].values.astype(str))
939+
tsinfer.VariantData(path, "variant_ancestral_allele")
953940

954941
def test_unique_alleles(self, tmp_path):
955942
path = tmp_path / "data.zarr"
956-
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
957-
ds["variant_allele"] = (
958-
ds["variant_allele"].dims,
959-
np.array([["A", "C", "T"], ["A", "C", ""], ["A", "A", ""]], dtype="S1"),
960-
)
961-
tsutil.save_dataset(ds, path)
943+
z = self.simulate_vcz(path, n_variant=3, n_sample=3)
944+
z["variant_allele"][2] = ["T", "T"]
962945
with pytest.raises(
963946
ValueError, match="Duplicate allele values provided at site 2"
964947
):
965-
tsinfer.VariantData(path, np.array(["A", "A", "A"], dtype="S1"))
948+
tsinfer.VariantData(path, "variant_ancestral_allele")
966949

967950
def test_unimplemented_from_tree_sequence(self):
968951
# NB we should reimplement something like this functionality.
@@ -972,28 +955,23 @@ def test_unimplemented_from_tree_sequence(self):
972955

973956
def test_all_masked(self, tmp_path):
974957
path = tmp_path / "data.zarr"
975-
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
976-
tsutil.save_dataset(ds, path)
958+
self.simulate_vcz(path, n_variant=3, n_sample=3)
977959
with pytest.raises(ValueError, match="All sites have been masked out"):
978960
tsinfer.VariantData(
979-
path, ds["variant_allele"][:, 0].astype(str), site_mask=np.ones(3, bool)
961+
path, "variant_ancestral_allele", site_mask=np.ones(3, bool)
980962
)
981963

982964
def test_missing_sites_time(self, tmp_path):
983965
path = tmp_path / "data.zarr"
984-
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
985-
tsutil.save_dataset(ds, path)
966+
self.simulate_vcz(path, n_variant=3, n_sample=3)
986967
with pytest.raises(
987968
ValueError, match="The sites time array XX was not found in the dataset"
988969
):
989-
tsinfer.VariantData(
990-
path, ds["variant_allele"][:, 0].astype(str), sites_time="XX"
991-
)
970+
tsinfer.VariantData(path, "variant_ancestral_allele", sites_time="XX")
992971

993972
def test_wrong_individuals_array_length(self, tmp_path):
994973
path = tmp_path / "data.zarr"
995-
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=5, phased=True)
996-
tsutil.save_dataset(ds, path)
974+
self.simulate_vcz(path, n_variant=3, n_sample=5)
997975

998976
# Create arrays with wrong length (too short)
999977
wrong_length_time = np.arange(3, dtype=np.float32)
@@ -1007,7 +985,7 @@ def test_wrong_individuals_array_length(self, tmp_path):
1007985
):
1008986
tsinfer.VariantData(
1009987
path,
1010-
ds["variant_allele"][:, 0].values.astype(str),
988+
"variant_ancestral_allele",
1011989
individuals_time=wrong_length_time,
1012990
)
1013991

@@ -1016,7 +994,7 @@ def test_wrong_individuals_array_length(self, tmp_path):
1016994
):
1017995
tsinfer.VariantData(
1018996
path,
1019-
ds["variant_allele"][:, 0].values.astype(str),
997+
"variant_ancestral_allele",
1020998
individuals_location=wrong_length_location,
1021999
)
10221000

@@ -1025,7 +1003,7 @@ def test_wrong_individuals_array_length(self, tmp_path):
10251003
):
10261004
tsinfer.VariantData(
10271005
path,
1028-
ds["variant_allele"][:, 0].values.astype(str),
1006+
"variant_ancestral_allele",
10291007
individuals_population=wrong_length_population,
10301008
)
10311009

@@ -1034,7 +1012,7 @@ def test_wrong_individuals_array_length(self, tmp_path):
10341012
):
10351013
tsinfer.VariantData(
10361014
path,
1037-
ds["variant_allele"][:, 0].values.astype(str),
1015+
"variant_ancestral_allele",
10381016
individuals_flags=wrong_length_flags,
10391017
)
10401018

0 commit comments

Comments
 (0)