@@ -48,7 +48,7 @@ def simulate_pedigree(
4848 num_generations = 3 ,
4949 sequence_length = 1 ,
5050 random_seed = 42 ,
51- internal_sample_gen = ( None , None , None ) ,
51+ sample_gen = None ,
5252) -> tskit .TableCollection :
5353 """
5454 Simulates pedigree.
@@ -62,17 +62,19 @@ def simulate_pedigree(
6262 num_generations: Number of generations to attempt to simulate
6363 sequence_length: The sequence_length of the output tables.
6464 random_seed: Random seed.
65+ sample_gen: Generations at which all individuals are samples. Defaults
66+ to the first generation.
6567 """
66- # Fill-in internal_sample_gen with None if shorter than number of generations
67- if len (internal_sample_gen ) < num_generations :
68- tmp = internal_sample_gen
69- internal_sample_gen = np .repeat (None ,num_generations )
70- internal_sample_gen [0 :len (tmp )] = tmp
7168 rng = np .random .RandomState (random_seed )
7269 builder = msprime .PedigreeBuilder ()
7370
7471 time = num_generations - 1
75- curr_gen = [builder .add_individual (time = time ,is_sample = internal_sample_gen [0 ]) for _ in range (num_founders )]
72+ if sample_gen is None :
73+ sample_gen = [time ]
74+ curr_gen = [
75+ builder .add_individual (time = time , is_sample = time in sample_gen )
76+ for _ in range (num_founders )
77+ ]
7678 for generation in range (1 , num_generations ):
7779 num_pairs = len (curr_gen ) // 2
7880 if num_pairs == 0 and num_children_prob [0 ] != 1 :
@@ -86,7 +88,9 @@ def simulate_pedigree(
8688 num_children = rng .choice (len (num_children_prob ), p = num_children_prob )
8789 for _ in range (num_children ):
8890 parents = np .sort (parents ).astype (np .int32 )
89- ind_id = builder .add_individual (time = time , parents = parents , is_sample = internal_sample_gen [generation ])
91+ ind_id = builder .add_individual (
92+ time = time , parents = parents , is_sample = time in sample_gen
93+ )
9094 curr_gen .append (ind_id )
9195 return builder .finalise (sequence_length )
9296
@@ -550,16 +554,16 @@ def test_shallow(self, num_founders, recombination_rate):
550554 sequence_length = 100 ,
551555 )
552556 self .verify (tables , recombination_rate )
553-
554- @pytest .mark .parametrize ("num_founders" , [2 , 3 , 5 , 100 ])
557+
558+ @pytest .mark .parametrize ("num_founders" , [2 , 3 , 5 ])
555559 @pytest .mark .parametrize ("recombination_rate" , [0 , 0.01 ])
556560 def test_shallow_internal (self , num_founders , recombination_rate ):
557561 tables = simulate_pedigree (
558562 num_founders = num_founders ,
559563 num_children_prob = [0 , 0 , 1 ],
560564 num_generations = 2 ,
561- sequence_length = 100 ,
562- internal_sample_gen = [ True , False ],
565+ sequence_length = 100 ,
566+ sample_gen = [ 0 , 1 ],
563567 )
564568 self .verify (tables , recombination_rate )
565569
0 commit comments