@@ -1347,12 +1347,19 @@ def test_bad_arg_types(self, arg):
1347
1347
1348
1348
1349
1349
class TestGeneticRelatednessMatrix :
1350
- def check (self , ts , mode , * , windows = None , span_normalise = True ):
1350
+ def check (self , ts , mode , * , sample_sets = None , windows = None , span_normalise = True ):
1351
1351
G1 = stats_api_genetic_relatedness_matrix (
1352
- ts , mode = mode , windows = windows , span_normalise = span_normalise
1352
+ ts ,
1353
+ mode = mode ,
1354
+ sample_sets = sample_sets ,
1355
+ windows = windows ,
1356
+ span_normalise = span_normalise ,
1353
1357
)
1354
1358
G2 = ts .genetic_relatedness_matrix (
1355
- mode = mode , windows = windows , span_normalise = span_normalise
1359
+ mode = mode ,
1360
+ sample_sets = sample_sets ,
1361
+ windows = windows ,
1362
+ span_normalise = span_normalise ,
1356
1363
)
1357
1364
np .testing .assert_array_almost_equal (G1 , G2 )
1358
1365
@@ -1368,6 +1375,33 @@ def test_single_tree(self, mode):
1368
1375
ts = tsutil .insert_branch_sites (ts )
1369
1376
self .check (ts , mode )
1370
1377
1378
+ @pytest .mark .parametrize ("mode" , DIVMAT_MODES )
1379
+ def test_single_tree_sample_sets (self , mode ):
1380
+ # 2.00┊ 6 ┊
1381
+ # ┊ ┏━┻━┓ ┊
1382
+ # 1.00┊ 4 5 ┊
1383
+ # ┊ ┏┻┓ ┏┻┓ ┊
1384
+ # 0.00┊ 0 1 2 3 ┊
1385
+ # 0 1
1386
+ ts = tskit .Tree .generate_balanced (4 ).tree_sequence
1387
+ ts = tsutil .insert_branch_sites (ts )
1388
+ with pytest .raises (ValueError , match = "2888" ):
1389
+ self .check (ts , mode , sample_sets = [[0 , 1 ], [2 , 3 ]])
1390
+
1391
+ @pytest .mark .parametrize ("mode" , DIVMAT_MODES )
1392
+ def test_single_tree_single_samples (self , mode ):
1393
+ # 2.00┊ 6 ┊
1394
+ # ┊ ┏━┻━┓ ┊
1395
+ # 1.00┊ 4 5 ┊
1396
+ # ┊ ┏┻┓ ┏┻┓ ┊
1397
+ # 0.00┊ 0 1 2 3 ┊
1398
+ # 0 1
1399
+ ts = tskit .Tree .generate_balanced (4 ).tree_sequence
1400
+ ts = tsutil .insert_branch_sites (ts )
1401
+ self .check (ts , mode , sample_sets = [[0 ], [1 ]])
1402
+ self .check (ts , mode , sample_sets = [[0 ], [2 ]])
1403
+ self .check (ts , mode , sample_sets = [[0 ], [1 ], [2 ]])
1404
+
1371
1405
@pytest .mark .parametrize ("mode" , DIVMAT_MODES )
1372
1406
def test_single_tree_windows (self , mode ):
1373
1407
# 2.00┊ 6 ┊
@@ -1390,3 +1424,12 @@ def test_suite_defaults(self, ts, mode):
1390
1424
@pytest .mark .parametrize ("span_normalise" , [True , False ])
1391
1425
def test_suite_span_normalise (self , ts , mode , span_normalise ):
1392
1426
self .check (ts , mode = mode , span_normalise = span_normalise )
1427
+
1428
+ @pytest .mark .skip ("fix sample sets #2888" )
1429
+ @pytest .mark .parametrize ("ts" , get_example_tree_sequences ())
1430
+ @pytest .mark .parametrize ("mode" , DIVMAT_MODES )
1431
+ @pytest .mark .parametrize ("num_sets" , [2 ]) # [[2, 3, 4, 5])
1432
+ def test_suite_sample_sets (self , ts , mode , num_sets ):
1433
+ if ts .num_samples >= num_sets :
1434
+ sample_sets = np .array_split (ts .samples (), num_sets )
1435
+ self .check (ts , sample_sets = sample_sets , mode = mode )
0 commit comments