Skip to content

Commit 2de0e24

Browse files
committed
Create a VariantData.from_arrays method
This makes an in-memory vdata object that can be used for testing. Fixes #924
1 parent b436755 commit 2de0e24

File tree

2 files changed

+356
-1
lines changed

2 files changed

+356
-1
lines changed

tests/test_variantdata.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,3 +1128,204 @@ def test_with_variant_data(self, tmp_path):
11281128
else:
11291129
allele_idx = -1
11301130
assert vdata.sites_ancestral_allele[i] == allele_idx
1131+
1132+
1133+
class TestFromArrays:
    """
    Tests for ``tsinfer.VariantData.from_arrays``, which builds a VariantData
    instance directly from in-memory arrays.
    """

    def demo_data(self):
        """
        Return ``[pos, G, alleles, ancestral]`` for a dataset of 5 sites and
        3 diploid samples. Each element is a fresh per-site list, so tests may
        mutate the returned values freely.
        """
        # One row per site: (position, genotypes[sample][ploid], alleles, ancestral)
        per_site = [
            (3, [[0, 1], [0, 0], [0, 0]], ["A", "T", ""], "A"),
            (10, [[0, 1], [1, 1], [0, 0]], ["C", "A", ""], "C"),
            (13, [[0, 1], [1, 0], [0, 0]], ["G", "C", ""], "C"),
            (19, [[0, 0], [0, 1], [1, 0]], ["A", "C", ""], "A"),
            (20, [[0, 1], [2, 0], [0, 0]], ["T", "G", "C"], "T"),
        ]
        # Transpose the rows into one list per field
        return [list(column) for column in zip(*per_site)]

    def test_simple_from_arrays(self):
        pos, G, alleles, ancestral = self.demo_data()
        vdata = tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)
        assert vdata.num_individuals == 3
        assert vdata.num_sites == 5
        inf_ts = tsinfer.infer(vdata)
        assert inf_ts.num_samples == 6
        assert inf_ts.num_individuals == 3
        assert inf_ts.num_sites == 5
        assert np.all(inf_ts.sites_position == pos)

    def test_named_from_arrays(self):
        # When we pass sample_id names, they should be stored in the individuals metadata
        pos, G, alleles, ancestral = self.demo_data()
        sample_id = ["sample1", "sample2", "sample3"]
        vdata = tsinfer.VariantData.from_arrays(
            G, pos, alleles, ancestral, sample_id=sample_id
        )
        assert vdata.num_individuals == 3
        inf_ts = tsinfer.infer(vdata)
        assert inf_ts.num_individuals == 3
        for name, ind in zip(sample_id, inf_ts.individuals()):
            assert ind.metadata["variant_data_sample_id"] == name

    def test_bad_variant_matrix(self):
        pos, G, alleles, ancestral = self.demo_data()
        G = np.array(G)
        # Too many dimensions (4D)
        with pytest.raises(ValueError, match="must be a 3D array"):
            tsinfer.VariantData.from_arrays([G], pos, alleles, ancestral)
        # Too few dimensions (2D)
        with pytest.raises(ValueError, match="must be a 3D array"):
            tsinfer.VariantData.from_arrays(G[:, :, 0], pos, alleles, ancestral)

    def test_empty(self):
        # Test with ploidy=1 but no sites
        pos, G, alleles, ancestral = [], np.empty((0, 0, 1)), np.empty((0, 0)), []
        with pytest.raises(ValueError, match="No sites exist"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_zero_ploidy(self):
        # [[[]]] is a 3D array whose last (ploidy) dimension has length zero
        pos, G, alleles, ancestral = [], [[[]]], np.empty((0, 0)), []
        with pytest.raises(ValueError, match="Ploidy must be greater than zero"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_from_arrays_ancestral_missing_warning(self):
        pos, G, alleles, ancestral = self.demo_data()
        ancestral[0] = "-"  # not one of the alleles at the first site
        with pytest.warns(UserWarning, match=r"ancestral allele.+not found[\s\S]+'-'"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_sequence_length(self):
        pos, G, alleles, ancestral = self.demo_data()
        vdata = tsinfer.VariantData.from_arrays(
            G, pos, alleles, ancestral, sequence_length=50
        )
        assert vdata.sequence_length == 50

    def test_bad_sequence_length(self):
        pos, G, alleles, ancestral = self.demo_data()
        # The last site is at position 20, so a length of 10 is too short
        with pytest.raises(ValueError, match="`sequence_length` cannot be less"):
            tsinfer.VariantData.from_arrays(
                G, pos, alleles, ancestral, sequence_length=10
            )

    # Two distinct failure modes: a 2D position array, and a 1D array that is
    # one element short of the number of variants
    @pytest.mark.parametrize("pos", [[[3, 10, 13, 19, 20]], [3, 10, 13, 19]])
    def test_bad_position(self, pos):
        _, G, alleles, ancestral = self.demo_data()
        with pytest.raises(ValueError, match="`variant_position` must be a 1D array"):
            # Pass `pos` as given: wrapping it in a further list would make both
            # cases over-dimensioned and leave the wrong-length case untested
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_unordered_position(self):
        pos, G, alleles, ancestral = self.demo_data()
        pos[-1] = 5  # out of order
        with pytest.raises(ValueError, match="out-of-order values"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_bad_dim_alleles(self):
        pos, G, alleles, ancestral = self.demo_data()
        with pytest.raises(ValueError, match="`variant_allele` must be a 2D array"):
            tsinfer.VariantData.from_arrays(G, pos, [alleles], ancestral)

    def test_bad_alleles(self):
        pos, G, alleles, ancestral = self.demo_data()
        alleles = np.array(alleles)
        # Drop the first row, so there is one fewer allele list than variants
        with pytest.raises(ValueError, match="same number of rows as variants"):
            tsinfer.VariantData.from_arrays(G, pos, alleles[1:, :], ancestral)

    def test_bad_num_alleles(self):
        pos, G, alleles, ancestral = self.demo_data()
        alleles = np.array(alleles)
        # Drop the first column, leaving too few alleles for genotype index 2
        with pytest.raises(ValueError, match="same number of columns"):
            tsinfer.VariantData.from_arrays(G, pos, alleles[:, 1:], ancestral)

    def test_bad_ancestral_state_length(self):
        pos, G, alleles, ancestral = self.demo_data()
        ancestral = np.array(ancestral)
        # Wrong number of dimensions
        with pytest.raises(ValueError, match="`ancestral_state` must be a 1D array"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, [ancestral])
        # Right number of dimensions but wrong length
        with pytest.raises(ValueError, match="`ancestral_state` must be a 1D array"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral[1:])

    @pytest.mark.parametrize("sid", [["A"], []])
    def test_bad_sample_id(self, sid):
        pos, G, alleles, ancestral = self.demo_data()
        with pytest.raises(ValueError, match="`sample_id` must be a 1D array"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral, sample_id=sid)

    def test_sample_mask(self):
        pos, G, alleles, ancestral = self.demo_data()
        G = np.array(G)
        mask = np.array([False, False, True])
        keep = np.logical_not(mask)
        alleles = np.array(alleles)
        vdata = tsinfer.VariantData.from_arrays(
            G, pos, alleles, ancestral, sample_mask=mask
        )
        assert vdata.num_individuals == 2
        inf_ts = tsinfer.infer(vdata)
        assert inf_ts.num_individuals == 2
        for v, p, allele_arr in zip(inf_ts.variants(), pos, alleles):
            # Allele indexes of the retained samples only, flattened over ploidy
            expected_idx = G[v.site.id, keep, :].flatten()
            assert v.site.position == p
            assert np.array_equal(v.states(), allele_arr[expected_idx])

    def test_site_mask(self):
        pos, G, alleles, ancestral = self.demo_data()
        G = np.array(G)
        mask = np.array([False, False, True, False, False])
        keep = np.logical_not(mask)
        pos = np.array(pos)
        alleles = np.array(alleles)
        ancestral = np.array(ancestral)
        # Ancestral states are given for the unmasked sites only
        vdata = tsinfer.VariantData.from_arrays(
            G, pos, alleles, ancestral[keep], site_mask=mask
        )
        assert vdata.num_individuals == 3
        inf_ts = tsinfer.infer(vdata)
        # Map site ids in the inferred ts back to rows of the unmasked input
        used_sites = np.where(keep)[0]
        for v, p, allele_arr in zip(inf_ts.variants(), pos[keep], alleles[keep]):
            expected_idx = G[used_sites[v.site.id], :, :].flatten()
            assert v.site.position == p
            assert np.array_equal(v.states(), allele_arr[expected_idx])

    def test_bad_site_mask_length(self):
        pos, G, alleles, ancestral = self.demo_data()
        mask = np.array([False, True, False])  # wrong length
        with pytest.raises(ValueError, match="length as the total number of variants"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral, site_mask=mask)

    def test_bad_sample_mask_length(self):
        pos, G, alleles, ancestral = self.demo_data()
        mask = np.array([False, True, True, False, True])  # wrong length
        with pytest.raises(ValueError, match="length as the total number of samples"):
            tsinfer.VariantData.from_arrays(
                G, pos, alleles, ancestral, sample_mask=mask
            )

    def test_bad_ancestral_state_masked(self):
        pos, G, alleles, ancestral = self.demo_data()
        mask = np.array([False, False, True, False, False])
        with pytest.raises(ValueError, match="`ancestral_state` must be a 1D array"):
            # Need to provide ancestral states of the same length as *unmasked* sites
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral, site_mask=mask)

    def test_round_trip_ts(self):
        ts = msprime.sim_ancestry(10, sequence_length=1000, random_seed=123)
        ts = msprime.sim_mutations(ts, rate=1e-2, random_seed=123)
        samples = ts.individuals_nodes
        G = []
        alleles = []
        for v in ts.variants():
            G.append(v.genotypes[samples])
            alleles.append(v.alleles + ("",) * (4 - len(v.alleles)))  # pad to 4 alleles

        vdata = tsinfer.VariantData.from_arrays(
            G,
            ts.sites_position,
            alleles,
            np.array(ts.sites_ancestral_state, dtype="U1"),
        )
        inf_ts = tsinfer.infer(vdata)
        # zip() stops at the shorter iterator, so check no sites have been lost
        assert inf_ts.num_sites == ts.num_sites
        for v1, v2 in zip(inf_ts.variants(), ts.variants()):
            assert np.array_equal(v1.states(), v2.states())

tsinfer/formats.py

Lines changed: 155 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2424,7 +2424,6 @@ def __init__(
24242424
individuals_flags=None,
24252425
sequence_length=None,
24262426
):
2427-
self._sequence_length = sequence_length
24282427
self._contig_index = None
24292428
self._contig_id = None
24302429
try:
@@ -2494,6 +2493,9 @@ def process_array(
24942493
self.individuals_select = ~sample_mask.astype(bool)
24952494
self._num_sites = np.sum(self.sites_select)
24962495

2496+
if len(self.sites_select) == 0:
2497+
raise ValueError("No sites exist")
2498+
24972499
if np.sum(self.sites_select) == 0:
24982500
raise ValueError(
24992501
"All sites have been masked out, at least one value "
@@ -2647,6 +2649,158 @@ def process_array(
26472649
logger.info(
26482650
f"Number of individuals after applying mask: {self.num_individuals}"
26492651
)
2652+
if sequence_length is not None:
2653+
if sequence_length <= self.sites_position[-1]:
2654+
raise ValueError(
2655+
"`sequence_length` cannot be less than or equal to the maximum "
2656+
"unmasked variant position"
2657+
)
2658+
self._sequence_length = sequence_length
2659+
2660+
@classmethod
def from_arrays(
    cls,
    variant_matrix_phased,
    variant_position,
    variant_allele,
    ancestral_state,
    *,
    sample_id=None,
    site_mask=None,
    sample_mask=None,
    **kwargs,
):
    """
    Create a basic in-memory VariantData instance directly from array data.
    Mainly useful for small test datasets. Larger datasets, or ones that require
    metadata to be included, should use e.g. bio2zarr
    to create a zarr datastore containing the required data and call
    VariantData(path_to_zarr)

    .. note::
        If a ``site_mask`` or ``sample_mask`` is provided, this does not
        require changing the size of the ``variant_position``, ``variant_allele``
        or ``sample_id`` arrays, which must always match the dimensions of the
        ``variant_matrix_phased`` array. However, if sites are masked out, this
        *does* require changing the ``ancestral_state`` array to match the number of
        unmasked sites (and similarly for other arrays passed as kwargs).

    :param array variant_matrix_phased: a 3D array of variants x samples x ploidy,
        giving an index into the allele array for each corresponding variant. Values
        must be coercible into 8-bit (np.int8) integers. Data for all samples is
        assumed to be phased. This corresponds to the ``call_genotype`` array in the
        VCF Zarr specification (e.g. missing data can be coded as -1).
    :param array variant_position: a 1D array of variant positions, of the same
        length as the first dimension of ``variant_matrix_phased``.
    :param array variant_allele: a 2D string array of variants x max_num_alleles at a
        site. The length of the first dimension must match the first dimension of the
        ``variant_matrix_phased`` array. The second dimension must be longer than
        the maximum allele index in the unmasked part of the
        ``variant_matrix_phased`` array (i.e. each allele list for a variant must be
        the same length; this can be ensured by padding the list with ``""``).
    :param array ancestral_state: A numpy array of strings specifying
        the ancestral states (alleles) used in inference. For unknown ancestral
        alleles, any character which is not in the allele list can be used.
        This must be the same length as the number of *unmasked* variants in
        ``variant_matrix_phased`` (see note above).
    :param array sample_id: a 1D string array of sample names, of length of the
        total number of n-ploid samples in ``variant_matrix_phased`` (i.e. the
        second dimension of ``variant_matrix_phased``).
        If None, each individual n-ploid sample will be
        allocated an ID corresponding to its sample index in the *unmasked*
        ``variant_matrix_phased`` array (i.e. "0", "1", "2", .. etc.)
    :param array site_mask: A 1D array of booleans, of the same length as the
        first dimension of ``variant_matrix_phased``, in which True marks a site
        to mask out (exclude) from the dataset. Any array-like is accepted.
    :param array sample_mask: A 1D array of booleans, of the same length as the
        second dimension of ``variant_matrix_phased``, in which True marks a
        sample to mask out (exclude) from the dataset. Any array-like is accepted.
    :param \\**kwargs: Further arguments passed to the VariantData constructor.
        In particular you may wish to specify `sequence_length`. Arrays for
        ``sites_time``, ``individuals_time`` etc. can also be provided.
    :return: A new instance of this class backed by an in-memory zarr store.
    :raises ValueError: If any of the arrays has the wrong number of dimensions
        or a length that is inconsistent with ``variant_matrix_phased``.
    """
    call_genotype = np.array(variant_matrix_phased, dtype=np.int8)
    if call_genotype.ndim != 3:
        raise ValueError("`variant_matrix_phased` must be a 3D array")

    num_variants, num_samples, ploidy = call_genotype.shape
    if ploidy == 0:
        raise ValueError("Ploidy must be greater than zero")
    variant_position = np.asarray(variant_position, dtype=np.float64)
    if variant_position.shape != (num_variants,):
        raise ValueError(
            "`variant_position` must be a 1D array of the same length as "
            "the number of variants in variant_matrix_phased"
        )

    variant_allele = np.asarray(variant_allele, dtype="U")  # make unicode for zarr
    if variant_allele.ndim != 2 or variant_allele.shape[0] != num_variants:
        raise ValueError(
            "`variant_allele` must be a 2D array with the same number of rows as "
            "variants in `variant_matrix_phased`"
        )

    if sample_id is None:
        sample_id = np.arange(num_samples).astype(str)
    sample_id = np.asarray(sample_id, dtype="U")
    if sample_id.shape != (num_samples,):
        raise ValueError(
            "`sample_id` must be a 1D array of the same length as the total "
            "number of samples in `variant_matrix_phased`"
        )

    if site_mask is None:
        site_keep = slice(None)
    else:
        # Accept any array-like (e.g. a plain list), like the other arguments
        site_mask = np.asarray(site_mask, dtype=bool)
        if site_mask.shape != (num_variants,):
            raise ValueError(
                "`site_mask` must be a 1D array of the same length as the total "
                "number of variants in `variant_matrix_phased`"
            )
        site_keep = np.logical_not(site_mask)  # turn into an inclusion mask
        num_variants = np.sum(site_keep)

    if sample_mask is None:
        sample_keep = slice(None)
    else:
        # Accept any array-like (e.g. a plain list), like the other arguments
        sample_mask = np.asarray(sample_mask, dtype=bool)
        if sample_mask.shape != (num_samples,):
            raise ValueError(
                "`sample_mask` must be a 1D array of the same length as the "
                "total number of samples in `variant_matrix_phased`"
            )
        sample_keep = np.logical_not(sample_mask)  # turn into an inclusion mask
        num_samples = np.sum(sample_keep)

    # Further tests must take into account the masked num_samples & num_variants

    # Apply the two masks one axis at a time. Passing two boolean arrays in a
    # single index expression would pair up their True positions element-wise
    # (failing to broadcast, or selecting only a "diagonal") rather than taking
    # the grid of unmasked sites x unmasked samples.
    unmasked_genotypes = call_genotype[site_keep][:, sample_keep]
    # initial=-1 gives max_alleles == 0 when there are no unmasked genotypes
    max_alleles = np.max(unmasked_genotypes, initial=-1) + 1
    if variant_allele.shape[1] < max_alleles:
        raise ValueError(
            "`variant_allele` must have the same number of columns as the maximum "
            "value in the unmasked variant_matrix_phased plus one"
        )

    ancestral_state = np.asarray(ancestral_state, dtype="U")
    if ancestral_state.shape != (num_variants,):
        raise ValueError(
            "`ancestral_state` must be a 1D array of the same length as the "
            "number of unmasked variants in `variant_matrix_phased`"
        )

    # The store always holds the full (unmasked) arrays: the masks themselves
    # are handed on to the constructor below
    store = zarr.storage.MemoryStore()
    root = zarr.group(store=store, overwrite=True)
    root.create_dataset("variant_position", data=variant_position)
    root.create_dataset("call_genotype", data=call_genotype)
    root.create_dataset(  # Assume all phased
        "call_genotype_phased", data=np.ones(call_genotype.shape[:2], dtype=bool)
    )
    root.create_dataset("variant_allele", data=variant_allele)
    root.create_dataset("sample_id", data=sample_id)
    return cls(
        root,
        ancestral_state,
        site_mask=site_mask,
        sample_mask=sample_mask,
        **kwargs,
    )
26502804

26512805
@functools.cached_property
26522806
def format_name(self):

0 commit comments

Comments
 (0)