Skip to content

Commit 3364349

Browse files
committed
Create a VariantData.from_arrays method
This makes an in-memory vdata object that can be used for testing. Fixes #924
1 parent 02621a9 commit 3364349

File tree

2 files changed

+356
-1
lines changed

2 files changed

+356
-1
lines changed

tests/test_variantdata.py

Lines changed: 201 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1167,3 +1167,204 @@ def test_with_variant_data(self, tmp_path):
11671167
else:
11681168
allele_idx = -1
11691169
assert vdata.sites_ancestral_allele[i] == allele_idx
1170+
1171+
1172+
class TestFromArrays:
    """
    Tests for ``tsinfer.VariantData.from_arrays``, which builds an in-memory
    VariantData instance directly from array data.
    """

    def demo_data(self):
        """
        Return ``[pos, G, alleles, ancestral]`` for 5 sites and 3 diploid samples.

        Each tuple below describes one site; ``zip(*...)`` transposes them into
        four per-field lists. New lists are built on every call, so individual
        tests can modify the returned values without affecting each other.
        """
        sites = [
            (3, [[0, 1], [0, 0], [0, 0]], ["A", "T", ""], "A"),
            (10, [[0, 1], [1, 1], [0, 0]], ["C", "A", ""], "C"),
            (13, [[0, 1], [1, 0], [0, 0]], ["G", "C", ""], "C"),
            (19, [[0, 0], [0, 1], [1, 0]], ["A", "C", ""], "A"),
            (20, [[0, 1], [2, 0], [0, 0]], ["T", "G", "C"], "T"),
        ]
        return [list(data) for data in zip(*sites)]

    def test_simple_from_arrays(self):
        pos, G, alleles, ancestral = self.demo_data()
        vdata = tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)
        assert vdata.num_individuals == 3
        assert vdata.num_sites == 5
        inf_ts = tsinfer.infer(vdata)
        # 3 diploid individuals give 6 sample nodes
        assert inf_ts.num_samples == 6
        assert inf_ts.num_individuals == 3
        assert inf_ts.num_sites == 5
        assert np.all(inf_ts.sites_position == pos)

    def test_named_from_arrays(self):
        # When we pass sample_id names, they should be stored in the
        # individuals metadata
        pos, G, alleles, ancestral = self.demo_data()
        sample_id = ["sample1", "sample2", "sample3"]
        vdata = tsinfer.VariantData.from_arrays(
            G, pos, alleles, ancestral, sample_id=sample_id
        )
        assert vdata.num_individuals == 3
        inf_ts = tsinfer.infer(vdata)
        assert inf_ts.num_individuals == 3
        for name, ind in zip(sample_id, inf_ts.individuals()):
            assert ind.metadata["variant_data_sample_id"] == name

    def test_bad_variant_matrix(self):
        pos, G, alleles, ancestral = self.demo_data()
        G = np.array(G)
        # One dimension too many (4D)
        with pytest.raises(ValueError, match="must be a 3D array"):
            tsinfer.VariantData.from_arrays([G], pos, alleles, ancestral)
        # One dimension too few (2D)
        with pytest.raises(ValueError, match="must be a 3D array"):
            tsinfer.VariantData.from_arrays(G[:, :, 0], pos, alleles, ancestral)

    def test_empty(self):
        # Test with ploidy=1 but no sites
        pos, G, alleles, ancestral = [], np.empty((0, 0, 1)), np.empty((0, 0)), []
        with pytest.raises(ValueError, match="No sites exist"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_zero_ploidy(self):
        # [[[]]] has shape (1, 1, 0), i.e. a final (ploidy) dimension of length 0
        pos, G, alleles, ancestral = [], [[[]]], np.empty((0, 0)), []
        with pytest.raises(ValueError, match="Ploidy must be greater than zero"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_from_arrays_ancestral_missing_warning(self):
        pos, G, alleles, ancestral = self.demo_data()
        # "-" is not one of the alleles listed for the first site
        ancestral[0] = "-"
        with pytest.warns(UserWarning, match=r"ancestral allele.+not found[\s\S]+'-'"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_sequence_length(self):
        pos, G, alleles, ancestral = self.demo_data()
        vdata = tsinfer.VariantData.from_arrays(
            G, pos, alleles, ancestral, sequence_length=50
        )
        assert vdata.sequence_length == 50

    def test_bad_sequence_length(self):
        pos, G, alleles, ancestral = self.demo_data()
        # The largest position in the demo data is 20, so 10 is too short
        with pytest.raises(ValueError, match="`sequence_length` cannot be less"):
            tsinfer.VariantData.from_arrays(
                G, pos, alleles, ancestral, sequence_length=10
            )

    @pytest.mark.parametrize("pos", [[[3, 10, 13, 19, 20]], [3, 10, 13, 19]])
    def test_bad_position(self, pos):
        # The first case is 2D; the second is 1D but one element too short.
        # `pos` is passed as-is (not wrapped in another list) so that the
        # wrong-length 1D case is tested as a 1D array.
        _, G, alleles, ancestral = self.demo_data()
        with pytest.raises(ValueError, match="`variant_position` must be a 1D array"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_unordered_position(self):
        pos, G, alleles, ancestral = self.demo_data()
        pos[-1] = 5  # out of order
        with pytest.raises(ValueError, match="out-of-order values"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_bad_dim_alleles(self):
        pos, G, alleles, ancestral = self.demo_data()
        # Wrapping in a list makes the allele array 3D
        with pytest.raises(ValueError, match="`variant_allele` must be a 2D array"):
            tsinfer.VariantData.from_arrays(G, pos, [alleles], ancestral)

    def test_bad_alleles(self):
        pos, G, alleles, ancestral = self.demo_data()
        alleles = np.array(alleles)
        # Drop the first row so there are fewer allele rows than variants
        with pytest.raises(ValueError, match="same number of rows as variants"):
            tsinfer.VariantData.from_arrays(G, pos, alleles[1:, :], ancestral)

    def test_bad_num_alleles(self):
        pos, G, alleles, ancestral = self.demo_data()
        alleles = np.array(alleles)
        # Drop a column: the genotypes contain index 2, which needs 3 columns
        with pytest.raises(ValueError, match="same number of columns"):
            tsinfer.VariantData.from_arrays(G, pos, alleles[:, 1:], ancestral)

    def test_bad_ancestral_state_length(self):
        pos, G, alleles, ancestral = self.demo_data()
        ancestral = np.array(ancestral)
        # Wrong number of dimensions (2D)
        with pytest.raises(ValueError, match="`ancestral_state` must be a 1D array"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, [ancestral])
        # Right number of dimensions but one element too short
        with pytest.raises(ValueError, match="`ancestral_state` must be a 1D array"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral[1:])

    @pytest.mark.parametrize("sid", [["A"], []])
    def test_bad_sample_id(self, sid):
        # Both cases have fewer IDs than the 3 samples in the demo data
        pos, G, alleles, ancestral = self.demo_data()
        with pytest.raises(ValueError, match="`sample_id` must be a 1D array"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral, sample_id=sid)

    def test_sample_mask(self):
        pos, G, alleles, ancestral = self.demo_data()
        G = np.array(G)
        mask = np.array([False, False, True])
        keep = np.logical_not(mask)
        alleles = np.array(alleles)
        vdata = tsinfer.VariantData.from_arrays(
            G, pos, alleles, ancestral, sample_mask=mask
        )
        assert vdata.num_individuals == 2
        inf_ts = tsinfer.infer(vdata)
        assert inf_ts.num_individuals == 2
        for v, p, allele_arr in zip(inf_ts.variants(), pos, alleles):
            # Allele indexes for the kept samples only, flattened over ploidy
            expected_idx = G[v.site.id, keep, :].flatten()
            assert v.site.position == p
            assert np.array_equal(v.states(), allele_arr[expected_idx])

    def test_site_mask(self):
        pos, G, alleles, ancestral = self.demo_data()
        G = np.array(G)
        mask = np.array([False, False, True, False, False])
        keep = np.logical_not(mask)
        pos = np.array(pos)
        alleles = np.array(alleles)
        ancestral = np.array(ancestral)
        # Only `ancestral_state` is reduced to the unmasked sites; the other
        # arrays are passed at full length
        vdata = tsinfer.VariantData.from_arrays(
            G, pos, alleles, ancestral[keep], site_mask=mask
        )
        assert vdata.num_individuals == 3
        inf_ts = tsinfer.infer(vdata)
        # Map site IDs in the inferred tree sequence back to rows of G
        used_sites = np.where(keep)[0]
        for v, p, allele_arr in zip(inf_ts.variants(), pos[keep], alleles[keep]):
            expected_idx = G[used_sites[v.site.id], :, :].flatten()
            assert v.site.position == p
            assert np.array_equal(v.states(), allele_arr[expected_idx])

    def test_bad_site_mask_length(self):
        pos, G, alleles, ancestral = self.demo_data()
        mask = np.array([False, True, False])  # wrong length
        with pytest.raises(ValueError, match="length as the total number of variants"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral, site_mask=mask)

    def test_bad_sample_mask_length(self):
        pos, G, alleles, ancestral = self.demo_data()
        mask = np.array([False, True, True, False, True])  # wrong length
        with pytest.raises(ValueError, match="length as the total number of samples"):
            tsinfer.VariantData.from_arrays(
                G, pos, alleles, ancestral, sample_mask=mask
            )

    def test_bad_ancestral_state_masked(self):
        pos, G, alleles, ancestral = self.demo_data()
        mask = np.array([False, False, True, False, False])
        with pytest.raises(ValueError, match="`ancestral_state` must be a 1D array"):
            # Need to provide ancestral states of the same length as *unmasked* sites
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral, site_mask=mask)

    def test_round_trip_ts(self):
        ts = msprime.sim_ancestry(10, sequence_length=1000, random_seed=123)
        ts = msprime.sim_mutations(ts, rate=1e-2, random_seed=123)
        samples = ts.individuals_nodes
        G = []
        alleles = []
        for v in ts.variants():
            # Indexing the flat genotype vector by the individuals' node IDs
            # yields one row of genotypes per individual
            G.append(v.genotypes[samples])
            alleles.append(v.alleles + ("",) * (4 - len(v.alleles)))  # pad to 4 alleles

        vdata = tsinfer.VariantData.from_arrays(
            G,
            ts.sites_position,
            alleles,
            np.array(ts.sites_ancestral_state, dtype="U1"),
        )
        inf_ts = tsinfer.infer(vdata)
        for v1, v2 in zip(inf_ts.variants(), ts.variants()):
            assert np.array_equal(v1.states(), v2.states())

tsinfer/formats.py

Lines changed: 155 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2424,7 +2424,6 @@ def __init__(
24242424
individuals_flags=None,
24252425
sequence_length=None,
24262426
):
2427-
self._sequence_length = sequence_length
24282427
self._contig_index = None
24292428
self._contig_id = None
24302429
try:
@@ -2494,6 +2493,9 @@ def process_array(
24942493
self.individuals_select = ~sample_mask.astype(bool)
24952494
self._num_sites = np.sum(self.sites_select)
24962495

2496+
if len(self.sites_select) == 0:
2497+
raise ValueError("No sites exist")
2498+
24972499
if np.sum(self.sites_select) == 0:
24982500
raise ValueError(
24992501
"All sites have been masked out, at least one value "
@@ -2647,6 +2649,158 @@ def process_array(
26472649
logger.info(
26482650
f"Number of individuals after applying mask: {self.num_individuals}"
26492651
)
2652+
if sequence_length is not None:
2653+
if sequence_length <= self.sites_position[-1]:
2654+
raise ValueError(
2655+
"`sequence_length` cannot be less than or equal to the maximum "
2656+
"unmasked variant position"
2657+
)
2658+
self._sequence_length = sequence_length
2659+
2660+
@classmethod
def from_arrays(
    cls,
    variant_matrix_phased,
    variant_position,
    variant_allele,
    ancestral_state,
    *,
    sample_id=None,
    site_mask=None,
    sample_mask=None,
    **kwargs,
):
    """
    Create a basic in-memory VariantData instance directly from array data.
    Mainly useful for small test datasets. Larger datasets, or ones that require
    metadata to be included, should use e.g. bio2zarr
    to create a zarr datastore containing the required data and call
    VariantData(path_to_zarr)

    .. note::
        If a ``site_mask`` or ``sample_mask`` is provided, this does not
        require changing the size of the ``variant_position``, ``variant_allele``
        or ``sample_id`` arrays, which must always match the dimensions of the
        ``variant_matrix_phased`` array. However, if sites are masked out, this
        *does* require changing the ``ancestral_state`` array to match the number
        of unmasked sites (and similarly for other arrays passed as kwargs).

    :param array variant_matrix_phased: a 3D array of variants x samples x ploidy,
        giving an index into the allele array for each corresponding variant.
        Values must be coercible into 8-bit (np.int8) integers. Data for all
        samples is assumed to be phased. This corresponds to the
        ``call_genotype`` array in the VCF Zarr specification (e.g. missing data
        can be coded as -1).
    :param array variant_position: a 1D array of variant positions, of the same
        length as the first dimension of ``variant_matrix_phased``.
    :param array variant_allele: a 2D string array of variants x max_num_alleles
        at a site. The length of the first dimension must match the first
        dimension of the ``variant_matrix_phased`` array. The second dimension
        must be longer than the maximum allele index in the unmasked part of the
        ``variant_matrix_phased`` array (i.e. each allele list for a variant must
        be the same length; this can be ensured by padding the list with ``""``).
    :param array ancestral_state: A numpy array of strings specifying
        the ancestral states (alleles) used in inference. For unknown ancestral
        alleles, any character which is not in the allele list can be used.
        This must be the same length as the number of *unmasked* variants in
        ``variant_matrix_phased`` (see note above).
    :param array sample_id: a 1D string array of sample names, of length of the
        total number of n-ploid samples in ``variant_matrix_phased`` (i.e. the
        second dimension of ``variant_matrix_phased``).
        If None, each individual n-ploid sample will be allocated an ID
        corresponding to its index along the second dimension of the full
        ``variant_matrix_phased`` array, before any ``sample_mask`` is applied
        (i.e. "0", "1", "2", .. etc.)
    :param array site_mask: A 1D array of booleans, of the same length as the
        first dimension of ``variant_matrix_phased``, specifying which
        sites to mask out (exclude) from the dataset.
    :param array sample_mask: A 1D array of booleans, of the same length as the
        second dimension of ``variant_matrix_phased``, specifying which
        samples to mask out (exclude) from the dataset.
    :param \\**kwargs: Further arguments passed to the VariantData constructor.
        In particular you may wish to specify `sequence_length`. Arrays for
        ``sites_time``, ``individuals_time`` etc. can also be provided.
    :return: The instance created by calling the class constructor on an
        in-memory zarr group holding the supplied arrays.
    :raises ValueError: If any of the arrays has the wrong number of dimensions
        or a size that is inconsistent with ``variant_matrix_phased``.
    """
    call_genotype = np.array(variant_matrix_phased, dtype=np.int8)
    if call_genotype.ndim != 3:
        raise ValueError("`variant_matrix_phased` must be a 3D array")

    num_variants, num_samples, ploidy = call_genotype.shape
    if ploidy == 0:
        raise ValueError("Ploidy must be greater than zero")
    variant_position = np.asarray(variant_position, dtype=np.float64)
    if variant_position.shape != (num_variants,):
        raise ValueError(
            "`variant_position` must be a 1D array of the same length as "
            "the number of variants in variant_matrix_phased"
        )

    variant_allele = np.asarray(variant_allele, dtype="U")  # make unicode for zarr
    if variant_allele.ndim != 2 or variant_allele.shape[0] != num_variants:
        raise ValueError(
            "`variant_allele` must be a 2D array with the same number of rows as "
            "variants in `variant_matrix_phased`"
        )

    if sample_id is None:
        sample_id = np.arange(num_samples).astype(str)
    sample_id = np.asarray(sample_id, dtype="U")
    if sample_id.shape != (num_samples,):
        raise ValueError(
            "`sample_id` must be a 1D array of the same length as the total "
            "number of samples in `variant_matrix_phased`"
        )

    if site_mask is None:
        site_keep = np.ones(num_variants, dtype=bool)
    else:
        # Coerce first, so that plain lists are accepted and a wrongly shaped
        # mask raises the ValueError below rather than an AttributeError
        site_mask = np.asarray(site_mask, dtype=bool)
        if site_mask.shape != (num_variants,):
            raise ValueError(
                "`site_mask` must be a 1D array of the same length as the total "
                "number of variants in `variant_matrix_phased`"
            )
        site_keep = np.logical_not(site_mask)  # turn into an inclusion mask
        num_variants = int(np.sum(site_keep))

    if sample_mask is None:
        sample_keep = np.ones(num_samples, dtype=bool)
    else:
        sample_mask = np.asarray(sample_mask, dtype=bool)
        if sample_mask.shape != (num_samples,):
            raise ValueError(
                "`sample_mask` must be a 1D array of the same length as the "
                "total number of samples in `variant_matrix_phased`"
            )
        sample_keep = np.logical_not(sample_mask)  # turn into an inclusion mask
        num_samples = int(np.sum(sample_keep))

    # Further tests must take into account the masked num_samples & num_variants

    # Apply the two masks one axis at a time. Putting both boolean arrays in a
    # single index expression would broadcast them against each other (pairing
    # up elements) and fail when the number of kept sites and samples differ.
    unmasked_genotypes = call_genotype[site_keep][:, sample_keep]
    # Convert to a Python int before adding one, to avoid int8 overflow at 127
    max_alleles = int(np.max(unmasked_genotypes, initial=-1)) + 1
    if variant_allele.shape[1] < max_alleles:
        raise ValueError(
            "`variant_allele` must have the same number of columns as the maximum "
            "value in the unmasked variant_matrix_phased plus one"
        )

    ancestral_state = np.asarray(ancestral_state, dtype="U")
    if ancestral_state.shape != (num_variants,):
        raise ValueError(
            "`ancestral_state` must be a 1D array of the same length as the "
            "number of unmasked variants in `variant_matrix_phased`"
        )

    # The full (unmasked) arrays are stored; the masks are passed on to the
    # constructor below.
    store = zarr.storage.MemoryStore()
    root = zarr.group(store=store, overwrite=True)
    root.create_dataset("variant_position", data=variant_position)
    root.create_dataset("call_genotype", data=call_genotype)
    root.create_dataset(  # Assume all phased
        "call_genotype_phased", data=np.ones(call_genotype.shape[:2], dtype=bool)
    )
    root.create_dataset("variant_allele", data=variant_allele)
    root.create_dataset("sample_id", data=sample_id)
    return cls(
        root,
        ancestral_state,
        site_mask=site_mask,
        sample_mask=sample_mask,
        **kwargs,
    )
26502804

26512805
@functools.cached_property
26522806
def format_name(self):

0 commit comments

Comments
 (0)