@@ -1128,3 +1128,204 @@ def test_with_variant_data(self, tmp_path):
11281128 else :
11291129 allele_idx = - 1
11301130 assert vdata .sites_ancestral_allele [i ] == allele_idx
1131+
1132+
class TestFromArrays:
    """Tests for building a ``tsinfer.VariantData`` via ``from_arrays``."""

    def demo_data(self):
        """
        Return ``[pos, G, alleles, ancestral]`` for a small example dataset.

        Each tuple below describes one site as (position, genotypes, alleles,
        ancestral allele). ``zip(*sites)`` transposes these into four parallel
        lists with one entry per site: 5 sites, 3 individuals, 2 genomes each.
        """
        sites = [
            (3, [[0, 1], [0, 0], [0, 0]], ["A", "T", ""], "A"),
            (10, [[0, 1], [1, 1], [0, 0]], ["C", "A", ""], "C"),
            (13, [[0, 1], [1, 0], [0, 0]], ["G", "C", ""], "C"),
            (19, [[0, 0], [0, 1], [1, 0]], ["A", "C", ""], "A"),
            (20, [[0, 1], [2, 0], [0, 0]], ["T", "G", "C"], "T"),
        ]
        return [list(column) for column in zip(*sites)]

    def test_simple_from_arrays(self):
        """The demo data can be loaded and run through ``tsinfer.infer``."""
        pos, G, alleles, ancestral = self.demo_data()
        vdata = tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)
        assert vdata.num_individuals == 3
        assert vdata.num_sites == 5
        inf_ts = tsinfer.infer(vdata)
        assert inf_ts.num_samples == 6
        assert inf_ts.num_individuals == 3
        assert inf_ts.num_sites == 5
        assert np.all(inf_ts.sites_position == pos)

    def test_named_from_arrays(self):
        """Names passed as ``sample_id`` end up in the individuals' metadata."""
        pos, G, alleles, ancestral = self.demo_data()
        sample_id = ["sample1", "sample2", "sample3"]
        vdata = tsinfer.VariantData.from_arrays(
            G, pos, alleles, ancestral, sample_id=sample_id
        )
        assert vdata.num_individuals == 3
        inf_ts = tsinfer.infer(vdata)
        assert inf_ts.num_individuals == 3
        for name, ind in zip(sample_id, inf_ts.individuals()):
            assert ind.metadata["variant_data_sample_id"] == name

    def test_bad_variant_matrix(self):
        """A genotype matrix that is not 3D (here 4D and 2D) is rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        G = np.array(G)
        with pytest.raises(ValueError, match="must be a 3D array"):
            tsinfer.VariantData.from_arrays([G], pos, alleles, ancestral)
        with pytest.raises(ValueError, match="must be a 3D array"):
            tsinfer.VariantData.from_arrays(G[:, :, 0], pos, alleles, ancestral)

    def test_empty(self):
        """A genotype array with ploidy=1 but no sites is rejected."""
        pos, G, alleles, ancestral = [], np.empty((0, 0, 1)), np.empty((0, 0)), []
        with pytest.raises(ValueError, match="No sites exist"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_zero_ploidy(self):
        """A genotype array whose last dimension has length zero is rejected."""
        pos, G, alleles, ancestral = [], [[[]]], np.empty((0, 0)), []
        with pytest.raises(ValueError, match="Ploidy must be greater than zero"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_from_arrays_ancestral_missing_warning(self):
        """An ancestral state absent from the site's alleles triggers a warning."""
        pos, G, alleles, ancestral = self.demo_data()
        ancestral[0] = "-"
        with pytest.warns(UserWarning, match=r"ancestral allele.+not found[\s\S]+'-'"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_sequence_length(self):
        """An explicit ``sequence_length`` is stored on the VariantData."""
        pos, G, alleles, ancestral = self.demo_data()
        vdata = tsinfer.VariantData.from_arrays(
            G, pos, alleles, ancestral, sequence_length=50
        )
        assert vdata.sequence_length == 50

    def test_bad_sequence_length(self):
        """A ``sequence_length`` of 10, below the last position (20), is rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        with pytest.raises(ValueError, match="`sequence_length` cannot be less"):
            tsinfer.VariantData.from_arrays(
                G, pos, alleles, ancestral, sequence_length=10
            )

    @pytest.mark.parametrize("pos", [[[3, 10, 13, 19, 20]], [3, 10, 13, 19]])
    def test_bad_position(self, pos):
        """Positions that are 2D, or 1D with the wrong length, are rejected."""
        _, G, alleles, ancestral = self.demo_data()
        with pytest.raises(ValueError, match="`variant_position` must be a 1D array"):
            # Pass `pos` as parametrized: wrapping it in another list would turn
            # the wrong-length 1D case into just another bad-dimension case.
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_unordered_position(self):
        """Positions that are not in sorted order are rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        pos[-1] = 5  # out of order
        with pytest.raises(ValueError, match="out-of-order values"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral)

    def test_bad_dim_alleles(self):
        """An allele array with an extra (third) dimension is rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        with pytest.raises(ValueError, match="`variant_allele` must be a 2D array"):
            tsinfer.VariantData.from_arrays(G, pos, [alleles], ancestral)

    def test_bad_alleles(self):
        """An allele array with one row fewer than the genotypes is rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        alleles = np.array(alleles)
        with pytest.raises(ValueError, match="same number of rows as variants"):
            tsinfer.VariantData.from_arrays(G, pos, alleles[1:, :], ancestral)

    def test_bad_num_alleles(self):
        """An allele array with its first column removed is rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        alleles = np.array(alleles)
        with pytest.raises(ValueError, match="same number of columns"):
            tsinfer.VariantData.from_arrays(G, pos, alleles[:, 1:], ancestral)

    def test_bad_ancestral_state_length(self):
        """Ancestral states that are 2D, or 1D but too short, are rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        ancestral = np.array(ancestral)
        with pytest.raises(ValueError, match="`ancestral_state` must be a 1D array"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, [ancestral])
        with pytest.raises(ValueError, match="`ancestral_state` must be a 1D array"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral[1:])

    @pytest.mark.parametrize("sid", [["A"], []])
    def test_bad_sample_id(self, sid):
        """A ``sample_id`` list with too few entries (one or none) is rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        with pytest.raises(ValueError, match="`sample_id` must be a 1D array"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral, sample_id=sid)

    def test_sample_mask(self):
        """Masking the third individual leaves only the first two in the result."""
        pos, G, alleles, ancestral = self.demo_data()
        G = np.array(G)
        mask = np.array([False, False, True])
        keep = np.logical_not(mask)
        alleles = np.array(alleles)
        vdata = tsinfer.VariantData.from_arrays(
            G, pos, alleles, ancestral, sample_mask=mask
        )
        assert vdata.num_individuals == 2
        inf_ts = tsinfer.infer(vdata)
        assert inf_ts.num_individuals == 2
        # zip() stops at the shortest input, so check the site count explicitly
        assert inf_ts.num_sites == len(pos)
        for v, p, allele_arr in zip(inf_ts.variants(), pos, alleles):
            expected_idx = G[v.site.id, keep, :].flatten()
            assert v.site.position == p
            assert np.array_equal(v.states(), allele_arr[expected_idx])

    def test_site_mask(self):
        """Masking the third site leaves the other four sites in the result."""
        pos, G, alleles, ancestral = self.demo_data()
        G = np.array(G)
        mask = np.array([False, False, True, False, False])
        keep = np.logical_not(mask)
        pos = np.array(pos)
        alleles = np.array(alleles)
        ancestral = np.array(ancestral)
        # Ancestral states are supplied for the unmasked sites only
        vdata = tsinfer.VariantData.from_arrays(
            G, pos, alleles, ancestral[keep], site_mask=mask
        )
        assert vdata.num_individuals == 3
        inf_ts = tsinfer.infer(vdata)
        used_sites = np.where(keep)[0]
        # zip() stops at the shortest input, so check the site count explicitly
        assert inf_ts.num_sites == len(used_sites)
        for v, p, allele_arr in zip(inf_ts.variants(), pos[keep], alleles[keep]):
            # Map the site id in the inferred ts back to a row of the full G
            expected_idx = G[used_sites[v.site.id], :, :].flatten()
            assert v.site.position == p
            assert np.array_equal(v.states(), allele_arr[expected_idx])

    def test_bad_site_mask_length(self):
        """A site mask of length 3 (there are 5 sites) is rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        mask = np.array([False, True, False])  # wrong length
        with pytest.raises(ValueError, match="length as the total number of variants"):
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral, site_mask=mask)

    def test_bad_sample_mask_length(self):
        """A sample mask of length 5 (there are 3 individuals) is rejected."""
        pos, G, alleles, ancestral = self.demo_data()
        mask = np.array([False, True, True, False, True])  # wrong length
        with pytest.raises(ValueError, match="length as the total number of samples"):
            tsinfer.VariantData.from_arrays(
                G, pos, alleles, ancestral, sample_mask=mask
            )

    def test_bad_ancestral_state_masked(self):
        """Full-length ancestral states are rejected when a site is masked."""
        pos, G, alleles, ancestral = self.demo_data()
        mask = np.array([False, False, True, False, False])
        with pytest.raises(ValueError, match="`ancestral_state` must be a 1D array"):
            # Need to provide ancestral states of the same length as *unmasked* sites
            tsinfer.VariantData.from_arrays(G, pos, alleles, ancestral, site_mask=mask)

    def test_round_trip_ts(self):
        """Variants of a simulated ts survive from_arrays followed by infer."""
        ts = msprime.sim_ancestry(10, sequence_length=1000, random_seed=123)
        ts = msprime.sim_mutations(ts, rate=1e-2, random_seed=123)
        samples = ts.individuals_nodes
        G = []
        alleles = []
        for v in ts.variants():
            G.append(v.genotypes[samples])
            alleles.append(v.alleles + ("",) * (4 - len(v.alleles)))  # pad to 4 alleles

        vdata = tsinfer.VariantData.from_arrays(
            G,
            ts.sites_position,
            alleles,
            np.array(ts.sites_ancestral_state, dtype="U1"),
        )
        inf_ts = tsinfer.infer(vdata)
        # zip() stops at the shortest input, so check the site count explicitly
        assert inf_ts.num_sites == ts.num_sites
        for v1, v2 in zip(inf_ts.variants(), ts.variants()):
            assert np.array_equal(v1.states(), v2.states())
0 commit comments