@@ -48,6 +48,7 @@ def simulate_pedigree(
4848 num_generations = 3 ,
4949 sequence_length = 1 ,
5050 random_seed = 42 ,
51+ internal_sample_gen = (None ,None ,None ),
5152) -> tskit .TableCollection :
5253 """
5354 Simulates pedigree.
@@ -62,11 +63,16 @@ def simulate_pedigree(
6263 sequence_length: The sequence_length of the output tables.
6364 random_seed: Random seed.
6465 """
66+ # Fill-in internal_sample_gen with None if shorter than number of generations
67+ if len (internal_sample_gen ) < num_generations :
68+ tmp = internal_sample_gen
69+ internal_sample_gen = np .repeat (None ,num_generations )
70+ internal_sample_gen [0 :len (tmp )] = tmp
6571 rng = np .random .RandomState (random_seed )
6672 builder = msprime .PedigreeBuilder ()
6773
6874 time = num_generations - 1
69- curr_gen = [builder .add_individual (time = time ) for _ in range (num_founders )]
75+ curr_gen = [builder .add_individual (time = time , is_sample = internal_sample_gen [ 0 ] ) for _ in range (num_founders )]
7076 for generation in range (1 , num_generations ):
7177 num_pairs = len (curr_gen ) // 2
7278 if num_pairs == 0 and num_children_prob [0 ] != 1 :
@@ -80,7 +86,7 @@ def simulate_pedigree(
8086 num_children = rng .choice (len (num_children_prob ), p = num_children_prob )
8187 for _ in range (num_children ):
8288 parents = np .sort (parents ).astype (np .int32 )
83- ind_id = builder .add_individual (time = time , parents = parents )
89+ ind_id = builder .add_individual (time = time , parents = parents , is_sample = internal_sample_gen [ generation ] )
8490 curr_gen .append (ind_id )
8591 return builder .finalise (sequence_length )
8692
@@ -544,6 +550,18 @@ def test_shallow(self, num_founders, recombination_rate):
544550 sequence_length = 100 ,
545551 )
546552 self .verify (tables , recombination_rate )
553+
554+ @pytest .mark .parametrize ("num_founders" , [2 , 3 , 5 , 100 ])
555+ @pytest .mark .parametrize ("recombination_rate" , [0 , 0.01 ])
556+ def test_shallow_internal (self , num_founders , recombination_rate ):
557+ tables = simulate_pedigree (
558+ num_founders = num_founders ,
559+ num_children_prob = [0 , 0 , 1 ],
560+ num_generations = 2 ,
561+ sequence_length = 100 ,
562+ internal_sample_gen = [True , False ],
563+ )
564+ self .verify (tables , recombination_rate )
547565
548566 @pytest .mark .parametrize ("num_founders" , [2 , 3 , 10 , 20 ])
549567 @pytest .mark .parametrize ("recombination_rate" , [0 , 0.01 ])
0 commit comments