Skip to content

Commit 0c86a4f

Browse files
Merge pull request #2326 from abureau/internal-samples
Test_pedigree.py with internal samples
2 parents af5bd81 + 78255b6 commit 0c86a4f

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

tests/test_pedigree.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def simulate_pedigree(
4848
num_generations=3,
4949
sequence_length=1,
5050
random_seed=42,
51+
sample_gen=None,
5152
) -> tskit.TableCollection:
5253
"""
5354
Simulates pedigree.
@@ -61,12 +62,19 @@ def simulate_pedigree(
6162
num_generations: Number of generations to attempt to simulate
6263
sequence_length: The sequence_length of the output tables.
6364
random_seed: Random seed.
65+
sample_gen: Generations at which all individuals are samples. Defaults
66+
to the first generation (backwards in time).
6467
"""
6568
rng = np.random.RandomState(random_seed)
6669
builder = msprime.PedigreeBuilder()
6770

6871
time = num_generations - 1
69-
curr_gen = [builder.add_individual(time=time) for _ in range(num_founders)]
72+
if sample_gen is None:
73+
sample_gen = [0]
74+
curr_gen = [
75+
builder.add_individual(time=time, is_sample=time in sample_gen)
76+
for _ in range(num_founders)
77+
]
7078
for generation in range(1, num_generations):
7179
num_pairs = len(curr_gen) // 2
7280
if num_pairs == 0 and num_children_prob[0] != 1:
@@ -80,7 +88,9 @@ def simulate_pedigree(
8088
num_children = rng.choice(len(num_children_prob), p=num_children_prob)
8189
for _ in range(num_children):
8290
parents = np.sort(parents).astype(np.int32)
83-
ind_id = builder.add_individual(time=time, parents=parents)
91+
ind_id = builder.add_individual(
92+
time=time, parents=parents, is_sample=time in sample_gen
93+
)
8494
curr_gen.append(ind_id)
8595
return builder.finalise(sequence_length)
8696

@@ -545,6 +555,18 @@ def test_shallow(self, num_founders, recombination_rate):
545555
)
546556
self.verify(tables, recombination_rate)
547557

558+
@pytest.mark.parametrize("num_founders", [2, 3, 5])
559+
@pytest.mark.parametrize("recombination_rate", [0, 0.01])
560+
def test_shallow_internal(self, num_founders, recombination_rate):
561+
tables = simulate_pedigree(
562+
num_founders=num_founders,
563+
num_children_prob=[0, 0, 1],
564+
num_generations=2,
565+
sequence_length=100,
566+
sample_gen=[0, 1],
567+
)
568+
self.verify(tables, recombination_rate)
569+
548570
@pytest.mark.parametrize("num_founders", [2, 3, 10, 20])
549571
@pytest.mark.parametrize("recombination_rate", [0, 0.01])
550572
def test_deep(self, num_founders, recombination_rate):

0 commit comments

Comments
 (0)