@@ -1167,3 +1167,204 @@ def test_with_variant_data(self, tmp_path):
11671167 else :
11681168 allele_idx = - 1
11691169 assert vdata .sites_ancestral_allele [i ] == allele_idx
1170+
1171+
class TestFromArrays:
    """Tests for constructing ``VariantData`` via ``VariantData.from_arrays``."""

    def demo_data(self):
        """
        Return a small dataset of five sites for three diploid individuals.

        The result is the list ``[pos, G, alleles, ancestral]``, where each
        element is a plain Python list holding one entry per site:

        - ``pos``: the site positions,
        - ``G``: the genotypes, indexed as ``G[site][individual][ploid]``,
        - ``alleles``: the allele strings, padded with ``""`` to three entries,
        - ``ancestral``: the ancestral state strings.
        """
        # One row per site: (position, genotypes, alleles, ancestral state).
        rows = [
            (3, [[0, 1], [0, 0], [0, 0]], ["A", "T", ""], "A"),
            (10, [[0, 1], [1, 1], [0, 0]], ["C", "A", ""], "C"),
            (13, [[0, 1], [1, 0], [0, 0]], ["G", "C", ""], "C"),
            (19, [[0, 0], [0, 1], [1, 0]], ["A", "C", ""], "A"),
            (20, [[0, 1], [2, 0], [0, 0]], ["T", "G", "C"], "T"),
        ]
        # Transpose the per-site rows into one list per field.
        return [list(column) for column in zip(*rows)]

    def test_simple_from_arrays(self):
        """The demo arrays load and can be inferred without losing sites."""
        pos, G, alleles, ancestral = self.demo_data()
        vdata = tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)
        assert vdata.num_individuals == 3
        assert vdata.num_sites == 5
        inf_ts = tsinfer.infer(vdata)
        assert inf_ts.num_samples == 6
        assert inf_ts.num_individuals == 3
        assert inf_ts.num_sites == 5
        assert np.all(inf_ts.sites_position == pos)

    def test_named_from_arrays(self):
        """Names passed as ``sample_id`` end up in the individuals' metadata."""
        pos, G, alleles, ancestral = self.demo_data()
        sample_id = ["sample1", "sample2", "sample3"]
        vdata = tsinfer.VariantData.from_arrays(
            G, pos, alleles, ancestral, sample_id=sample_id
        )
        assert vdata.num_individuals == 3
        inf_ts = tsinfer.infer(vdata)
        assert inf_ts.num_individuals == 3
        for name, ind in zip(sample_id, inf_ts.individuals()):
            assert ind.metadata["variant_data_sample_id"] == name

    def test_bad_variant_matrix(self):
        """Genotype arrays that are not 3D (here 4D and 2D) are rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        G = np.array(G)
        with pytest.raises(ValueError, match="must be a 3D array"):
            tsinfer.VariantData.from_arrays([G], pos, alleles, ancestral)
        with pytest.raises(ValueError, match="must be a 3D array"):
            tsinfer.VariantData.from_arrays(G[:, :, 0], pos, alleles, ancestral)

    def test_empty(self):
        """A genotype array with ploidy 1 but no sites is rejected."""
        pos, G, alleles, ancestral = [], np.empty((0, 0, 1)), np.empty((0, 0)), []
        with pytest.raises(ValueError, match="No sites exist"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_zero_ploidy(self):
        """A genotype array whose last (ploidy) dimension is empty is rejected."""
        pos, G, alleles, ancestral = [], [[[]]], np.empty((0, 0)), []
        with pytest.raises(ValueError, match="Ploidy must be greater than zero"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_from_arrays_ancestral_missing_warning(self):
        """An ancestral state absent from the site's alleles triggers a warning."""
        pos, G, alleles, ancestral = self.demo_data()
        ancestral[0] = "-"
        with pytest.warns(UserWarning, match=r"ancestral allele.+not found[\s\S]+'-'"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_sequence_length(self):
        """An explicit ``sequence_length`` is reported back by the instance."""
        pos, G, alleles, ancestral = self.demo_data()
        vdata = tsinfer.VariantData.from_arrays(
            G, pos, alleles, ancestral, sequence_length=50
        )
        assert vdata.sequence_length == 50

    def test_bad_sequence_length(self):
        """A ``sequence_length`` smaller than the largest position is rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        with pytest.raises(ValueError, match="`sequence_length` cannot be less"):
            tsinfer.VariantData.from_arrays(
                G, pos, alleles, ancestral, sequence_length=10
            )

    @pytest.mark.parametrize("pos", [[[3, 10, 13, 19, 20]], [3, 10, 13, 19]])
    def test_bad_position(self, pos):
        """Both a 2D position array and a too-short 1D array are rejected."""
        _, G, alleles, ancestral = self.demo_data()
        with pytest.raises(ValueError, match="`variant_position` must be a 1D array"):
            # Pass `pos` as-is: wrapping it in another list would turn the
            # too-short 1D case into a 2D array, so the length check would
            # never be exercised.
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_unordered_position(self):
        """Positions that are not in increasing order are rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        pos[-1] = 5  # out of order
        with pytest.raises(ValueError, match="out-of-order values"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_bad_dim_alleles(self):
        """An allele array that is not 2D (here 3D) is rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        with pytest.raises(ValueError, match="`variant_allele` must be a 2D array"):
            tsinfer.VariantData.from_arrays(G, pos, [alleles], ancestral)

    def test_bad_alleles(self):
        """An allele array with fewer rows than there are variants is rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        alleles = np.array(alleles)
        with pytest.raises(ValueError, match="same number of rows as variants"):
            tsinfer.VariantData.from_arrays(G, pos, alleles[1:, :], ancestral)

    def test_bad_num_alleles(self):
        """An allele array with a column removed is rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        alleles = np.array(alleles)
        with pytest.raises(ValueError, match="same number of columns"):
            tsinfer.VariantData.from_arrays(G, pos, alleles[:, 1:], ancestral)

    def test_bad_ancestral_state_length(self):
        """A 2D ancestral state array and a too-short 1D one are both rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        ancestral = np.array(ancestral)
        with pytest.raises(ValueError, match="`ancestral_state` must be a 1D array"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, [ancestral])
        with pytest.raises(ValueError, match="`ancestral_state` must be a 1D array"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral[1:])

    @pytest.mark.parametrize("sid", [["A"], []])
    def test_bad_sample_id(self, sid):
        """A ``sample_id`` list with too few entries (one or none) is rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        with pytest.raises(ValueError, match="`sample_id` must be a 1D array"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral, sample_id=sid)

    def test_sample_mask(self):
        """Masked-out individuals are dropped; the rest keep their genotypes."""
        pos, G, alleles, ancestral = self.demo_data()
        G = np.array(G)
        mask = np.array([False, False, True])
        keep = np.logical_not(mask)
        alleles = np.array(alleles)
        vdata = tsinfer.VariantData.from_arrays(
            G, pos, alleles, ancestral, sample_mask=mask
        )
        assert vdata.num_individuals == 2
        inf_ts = tsinfer.infer(vdata)
        assert inf_ts.num_individuals == 2
        # Guard against the zip below silently comparing fewer sites than given.
        assert inf_ts.num_sites == len(pos)
        for v, p, allele_arr in zip(inf_ts.variants(), pos, alleles):
            # Expected allele indexes of the unmasked individuals at this site.
            expected_idx = G[v.site.id, keep, :].flatten()
            assert v.site.position == p
            assert np.array_equal(v.states(), allele_arr[expected_idx])

    def test_site_mask(self):
        """Masked-out sites are dropped; the rest keep position and genotypes."""
        pos, G, alleles, ancestral = self.demo_data()
        G = np.array(G)
        mask = np.array([False, False, True, False, False])
        keep = np.logical_not(mask)
        pos = np.array(pos)
        alleles = np.array(alleles)
        ancestral = np.array(ancestral)
        # Ancestral states are supplied for the unmasked sites only.
        vdata = tsinfer.VariantData.from_arrays(
            G, pos, alleles, ancestral[keep], site_mask=mask
        )
        assert vdata.num_individuals == 3
        inf_ts = tsinfer.infer(vdata)
        # Guard against the zip below silently comparing fewer sites than kept.
        assert inf_ts.num_sites == int(keep.sum())
        # Maps a site id in the inferred tree sequence to its original row.
        used_sites = np.where(keep)[0]
        for v, p, allele_arr in zip(inf_ts.variants(), pos[keep], alleles[keep]):
            expected_idx = G[used_sites[v.site.id], :, :].flatten()
            assert v.site.position == p
            assert np.array_equal(v.states(), allele_arr[expected_idx])

    def test_bad_site_mask_length(self):
        """A site mask whose length differs from the number of variants fails."""
        pos, G, alleles, ancestral = self.demo_data()
        mask = np.array([False, True, False])  # wrong length
        with pytest.raises(ValueError, match="length as the total number of variants"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral, site_mask=mask)

    def test_bad_sample_mask_length(self):
        """A sample mask whose length differs from the number of samples fails."""
        pos, G, alleles, ancestral = self.demo_data()
        mask = np.array([False, True, True, False, True])  # wrong length
        with pytest.raises(ValueError, match="length as the total number of samples"):
            tsinfer.VariantData.from_arrays(
                G, pos, alleles, ancestral, sample_mask=mask
            )

    def test_bad_ancestral_state_masked(self):
        """With a site mask, full-length ancestral states are rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        mask = np.array([False, False, True, False, False])
        with pytest.raises(ValueError, match="`ancestral_state` must be a 1D array"):
            # Need to provide ancestral states of the same length as *unmasked* sites
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral, site_mask=mask)

    def test_round_trip_ts(self):
        """Arrays extracted from a simulated tree sequence infer the same states."""
        ts = msprime.sim_ancestry(10, sequence_length=1000, random_seed=123)
        ts = msprime.sim_mutations(ts, rate=1e-2, random_seed=123)
        samples = ts.individuals_nodes
        G = []
        alleles = []
        for v in ts.variants():
            # Index the genotypes by each individual's nodes.
            G.append(v.genotypes[samples])
            alleles.append(v.alleles + ("",) * (4 - len(v.alleles)))  # pad to 4 alleles

        vdata = tsinfer.VariantData.from_arrays(
            G,
            ts.sites_position,
            alleles,
            np.array(ts.sites_ancestral_state, dtype="U1"),
        )
        inf_ts = tsinfer.infer(vdata)
        # Guard against the zip below silently comparing fewer sites than simulated.
        assert inf_ts.num_sites == ts.num_sites
        for v1, v2 in zip(inf_ts.variants(), ts.variants()):
            assert np.array_equal(v1.states(), v2.states())
0 commit comments