@@ -48,6 +48,7 @@ def simulate_pedigree(
4848 num_generations = 3 ,
4949 sequence_length = 1 ,
5050 random_seed = 42 ,
51+ sample_gen = None ,
5152) -> tskit .TableCollection :
5253 """
5354 Simulates pedigree.
@@ -61,12 +62,19 @@ def simulate_pedigree(
6162 num_generations: Number of generations to attempt to simulate
6263 sequence_length: The sequence_length of the output tables.
6364 random_seed: Random seed.
65+ sample_gen: Generations at which all individuals are samples. Defaults
66+ to the first generation (backwards in time).
6467 """
6568 rng = np .random .RandomState (random_seed )
6669 builder = msprime .PedigreeBuilder ()
6770
6871 time = num_generations - 1
69- curr_gen = [builder .add_individual (time = time ) for _ in range (num_founders )]
72+ if sample_gen is None :
73+ sample_gen = [0 ]
74+ curr_gen = [
75+ builder .add_individual (time = time , is_sample = time in sample_gen )
76+ for _ in range (num_founders )
77+ ]
7078 for generation in range (1 , num_generations ):
7179 num_pairs = len (curr_gen ) // 2
7280 if num_pairs == 0 and num_children_prob [0 ] != 1 :
@@ -80,7 +88,9 @@ def simulate_pedigree(
8088 num_children = rng .choice (len (num_children_prob ), p = num_children_prob )
8189 for _ in range (num_children ):
8290 parents = np .sort (parents ).astype (np .int32 )
83- ind_id = builder .add_individual (time = time , parents = parents )
91+ ind_id = builder .add_individual (
92+ time = time , parents = parents , is_sample = time in sample_gen
93+ )
8494 curr_gen .append (ind_id )
8595 return builder .finalise (sequence_length )
8696
@@ -545,6 +555,18 @@ def test_shallow(self, num_founders, recombination_rate):
545555 )
546556 self .verify (tables , recombination_rate )
547557
558+ @pytest .mark .parametrize ("num_founders" , [2 , 3 , 5 ])
559+ @pytest .mark .parametrize ("recombination_rate" , [0 , 0.01 ])
560+ def test_shallow_internal (self , num_founders , recombination_rate ):
561+ tables = simulate_pedigree (
562+ num_founders = num_founders ,
563+ num_children_prob = [0 , 0 , 1 ],
564+ num_generations = 2 ,
565+ sequence_length = 100 ,
566+ sample_gen = [0 , 1 ],
567+ )
568+ self .verify (tables , recombination_rate )
569+
548570 @pytest .mark .parametrize ("num_founders" , [2 , 3 , 10 , 20 ])
549571 @pytest .mark .parametrize ("recombination_rate" , [0 , 0.01 ])
550572 def test_deep (self , num_founders , recombination_rate ):
0 commit comments