Skip to content

Commit 2c00ec1

Browse files
Merge pull request #663 from jeromekelleher/support-one-sample
Allow a subset with one sample
2 parents 5e27d1b + aa012d7 commit 2c00ec1

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

tests/test_formats.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#
2-
# Copyright (C) 2018-2020 University of Oxford
2+
# Copyright (C) 2018-2022 University of Oxford
33
#
44
# This file is part of tsinfer.
55
#
@@ -569,7 +569,6 @@ def test_str(self):
569569

570570
def test_eq(self):
571571
ts = tsutil.get_example_ts(5, random_seed=3)
572-
print(ts.num_sites)
573572
input_file = formats.SampleData(sequence_length=ts.sequence_length)
574573
self.verify_data_round_trip(ts, input_file)
575574
assert input_file == input_file
@@ -663,9 +662,6 @@ def test_insufficient_samples(self):
663662
with pytest.raises(ValueError):
664663
sample_data.add_site(position=0, alleles=["0", "1"], genotypes=[])
665664
sample_data = formats.SampleData(sequence_length=10)
666-
with pytest.raises(ValueError):
667-
sample_data.add_site(position=0, alleles=["0", "1"], genotypes=[0])
668-
sample_data = formats.SampleData(sequence_length=10)
669665
sample_data.add_individual(ploidy=3)
670666
with pytest.raises(ValueError):
671667
sample_data.add_site(position=0, alleles=["0", "1"], genotypes=[0])
@@ -1420,6 +1416,19 @@ def verify_subset_data(self, source, individuals, sites):
14201416
j += 1
14211417
assert j == len(sites)
14221418

1419+
def test_one_sample(self):
1420+
ts = tsutil.get_example_ts(10)
1421+
sd1 = formats.SampleData.from_tree_sequence(ts)
1422+
G1 = ts.genotype_matrix()
1423+
# Because this is a haploid tree sequence we can use the
1424+
# individual and sample IDs interchangeably.
1425+
cols = [3]
1426+
rows = np.arange(ts.num_sites)
1427+
subset = sd1.subset(individuals=cols, sites=rows)
1428+
G2 = np.array([v.genotypes for v in subset.variants()])
1429+
assert np.array_equal(G1[rows][:, cols], G2)
1430+
self.verify_subset_data(sd1, cols, rows)
1431+
14231432
def test_simple_case(self):
14241433
ts = tsutil.get_example_ts(10)
14251434
sd1 = formats.SampleData.from_tree_sequence(ts)

tsinfer/formats.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,8 +1372,8 @@ def subset(self, individuals=None, sites=None, **kwargs):
13721372
ploidy=len(individual.samples),
13731373
)
13741374
sample_selection = np.array(sample_selection, dtype=int)
1375-
if len(sample_selection) < 2:
1376-
raise ValueError("Must have at least two samples")
1375+
if len(sample_selection) < 1:
1376+
raise ValueError("Must have at least one sample")
13771377
for variant in self.variants():
13781378
if variant.site.id in sites:
13791379
subset.add_site(
@@ -1576,8 +1576,8 @@ def encode_metadata(metadata, schema):
15761576
return self
15771577

15781578
def _alloc_site_writer(self):
1579-
if self.num_samples < 2:
1580-
raise ValueError("Must have at least 2 samples")
1579+
if self.num_samples < 1:
1580+
raise ValueError("Must have at least 1 sample")
15811581
self.sites_genotypes.resize(0, self.num_samples)
15821582
arrays = {
15831583
"position": self.sites_position,

0 commit comments

Comments
 (0)