@@ -48,7 +48,7 @@ def simulate_pedigree(
4848 num_generations = 3 ,
4949 sequence_length = 1 ,
5050 random_seed = 42 ,
51- internal_sample_gen = (False , False , False ),
51+ internal_sample_gen = (None , None , None ),
5252) -> tskit .TableCollection :
5353 """
5454 Simulates pedigree.
@@ -63,11 +63,16 @@ def simulate_pedigree(
6363 sequence_length: The sequence_length of the output tables.
6464 random_seed: Random seed.
6565 """
66+ # Fill-in internal_sample_gen with None if shorter than number of generations
67+ if len (internal_sample_gen ) < num_generations :
68+ tmp = internal_sample_gen
69+ internal_sample_gen = np .repeat (None ,num_generations )
70+ internal_sample_gen [0 :len (tmp )] = tmp
6671 rng = np .random .RandomState (random_seed )
6772 builder = msprime .PedigreeBuilder ()
6873
6974 time = num_generations - 1
70- curr_gen = [builder .add_individual (time = time ) for _ in range (num_founders )]
75+ curr_gen = [builder .add_individual (time = time , is_sample = internal_sample_gen [ 0 ] ) for _ in range (num_founders )]
7176 for generation in range (1 , num_generations ):
7277 num_pairs = len (curr_gen ) // 2
7378 if num_pairs == 0 and num_children_prob [0 ] != 1 :
@@ -81,7 +86,7 @@ def simulate_pedigree(
8186 num_children = rng .choice (len (num_children_prob ), p = num_children_prob )
8287 for _ in range (num_children ):
8388 parents = np .sort (parents ).astype (np .int32 )
84- ind_id = builder .add_individual (time = time , parents = parents , is_sample = internal_sample_gen [generation - 1 ])
89+ ind_id = builder .add_individual (time = time , parents = parents , is_sample = internal_sample_gen [generation ])
8590 curr_gen .append (ind_id )
8691 return builder .finalise (sequence_length )
8792
0 commit comments