@@ -29,6 +29,12 @@ def test_missing_dependency():
2929 )
3030
3131
def tskit_model_mapping(ind_nodes, ind_names=None):
    """Build a ``tskit.VcfModelMapping`` for use in tests.

    Args:
        ind_nodes: Sized collection with one entry per individual. It is
            passed to ``tskit.VcfModelMapping`` unchanged; only its length
            is inspected here.
        ind_names: Optional names, one per entry of ``ind_nodes``. When
            ``None``, the names ``tsk0``, ``tsk1``, ... are generated.

    Returns:
        The ``tskit.VcfModelMapping`` built from ``ind_nodes`` and the
        (possibly generated) names.
    """
    if ind_names is None:
        # This has to be an f-string: the plain literal "tsk{j}" would give
        # every individual the same, un-interpolated name.
        ind_names = [f"tsk{j}" for j in range(len(ind_nodes))]
    return tskit.VcfModelMapping(ind_nodes, ind_names)
36+
37+
3238def add_mutations (ts ):
3339 # Add some mutation to the tree sequence. This guarantees that
3440 # we have variation at all sites > 0.
@@ -88,15 +94,6 @@ def insert_branch_sites(ts, m=1):
8894 return tables .tree_sequence ()
8995
9096
91- @pytest .fixture ()
92- def fx_ts_isolated_samples ():
93- tables = tskit .Tree .generate_balanced (2 , span = 10 ).tree_sequence .dump_tables ()
94- # This also tests sample nodes that are not a single block at
95- # the start of the nodes table.
96- tables .nodes .add_row (time = 0 , flags = tskit .NODE_IS_SAMPLE )
97- return insert_branch_sites (tables .tree_sequence ())
98-
99-
10097class TestSimpleTs :
10198 @pytest .fixture ()
10299 def conversion (self , tmp_path ):
@@ -193,17 +190,28 @@ class TestTskitFormat:
193190 """Unit tests for TskitFormat without using full conversion."""
194191
195192 @pytest .fixture ()
196- def fx_simple_ts (self , tmp_path ):
193+ def fx_simple_ts (self ):
197194 return simple_ts (add_individuals = True )
198195
199196 @pytest .fixture ()
200- def fx_ts_2_diploids (self , tmp_path ):
197+ def fx_ts_2_diploids (self ):
201198 ts = msprime .sim_ancestry (2 , sequence_length = 10 , random_seed = 42 )
202199 return add_mutations (ts )
203200
204201 @pytest .fixture ()
205- def fx_no_individuals_ts (self , tmp_path ):
206- return simple_ts (add_individuals = False )
def fx_ts_isolated_samples(self):
    """Return a balanced two-leaf tree sequence with one extra sample node.

    The extra node is appended to the node table with the sample flag set
    and no edges are added for it here. The result is then passed through
    ``insert_branch_sites``.
    """
    balanced_tree = tskit.Tree.generate_balanced(2, span=10)
    tables = balanced_tree.tree_sequence.dump_tables()
    # Appending the sample after the existing nodes also exercises the
    # case where sample nodes are not one contiguous block at the start
    # of the nodes table.
    tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
    ts_with_extra_sample = tables.tree_sequence()
    return insert_branch_sites(ts_with_extra_sample)
208+
def test_path_or_ts_input(self, tmp_path, fx_simple_ts):
    """TskitFormat accepts either a tree sequence object or a file path.

    The same tree sequence is supplied once in memory and once as a path
    to a dumped file; both resulting formats must hold equal tables.
    """
    from_object = tsk.TskitFormat(fx_simple_ts)

    dumped_path = tmp_path / "trees.ts"
    fx_simple_ts.dump(dumped_path)
    from_path = tsk.TskitFormat(dumped_path)

    tables_in_memory = from_object.ts.tables
    tables_in_memory.assert_equals(from_path.ts.tables)
207215
208216 def test_small_position_dtype (self ):
209217 tables = tskit .TableCollection (sequence_length = 100 )
@@ -311,6 +319,23 @@ def test_iter_field(self, fx_simple_ts):
311319 with pytest .raises (ValueError , match = "Unknown field" ):
312320 list (format_obj .iter_field ("unknown_field" , None , 0 , 3 ))
313321
def test_zero_samples(self, fx_simple_ts):
    """A model mapping built from an empty node array is rejected."""
    empty_mapping = tskit_model_mapping(np.array([]))
    expect_error = pytest.raises(ValueError, match="at least one sample")
    with expect_error:
        tsk.TskitFormat(fx_simple_ts, model_mapping=empty_mapping)
326+
def test_no_valid_samples(self, fx_simple_ts):
    """A mapping whose node entries are all -1 is rejected."""
    mapping = fx_simple_ts.map_to_vcf_model()
    # Overwrite every node entry in place with -1.
    # NOTE(review): -1 is presumably tskit's null node ID — confirm.
    mapping.individuals_nodes[:] = -1
    expect_error = pytest.raises(ValueError, match="at least one valid sample")
    with expect_error:
        tsk.TskitFormat(fx_simple_ts, model_mapping=mapping)
332+
def test_model_size_mismatch(self, fx_simple_ts):
    """A mapping whose name list is replaced by a single name is rejected."""
    mapping = fx_simple_ts.map_to_vcf_model()
    # Replace the names with a one-element list; the test expects this to
    # no longer match the number of samples in the mapping.
    mapping.individuals_name = ["x"]
    expect_error = pytest.raises(ValueError, match="match number of samples")
    with expect_error:
        tsk.TskitFormat(fx_simple_ts, model_mapping=mapping)
338+
314339 @pytest .mark .parametrize (
315340 ("ind_nodes" , "expected_gts" ),
316341 [
@@ -347,10 +372,7 @@ def test_iter_field(self, fx_simple_ts):
347372 ],
348373 )
349374 def test_iter_alleles_and_genotypes (self , fx_simple_ts , ind_nodes , expected_gts ):
350- model_mapping = tskit .VcfModelMapping (
351- ind_nodes , ["tsk{j}" for j in range (len (ind_nodes ))]
352- )
353-
375+ model_mapping = tskit_model_mapping (ind_nodes )
354376 format_obj = tsk .TskitFormat (fx_simple_ts , model_mapping = model_mapping )
355377
356378 shape = (2 , 2 ) # (num_samples, max_ploidy)
@@ -375,9 +397,7 @@ def test_iter_alleles_and_genotypes(self, fx_simple_ts, ind_nodes, expected_gts)
375397 def test_iter_alleles_and_genotypes_missing_node (self , fx_ts_2_diploids ):
376398 # Test with node ID that doesn't exist in tree sequence (out of range)
377399 ind_nodes = np .array ([[10 , 11 ], [12 , 13 ]], dtype = np .int32 )
378- model_mapping = tskit .VcfModelMapping (
379- ind_nodes , ["tsk{j}" for j in range (len (ind_nodes ))]
380- )
400+ model_mapping = tskit_model_mapping (ind_nodes )
381401 format_obj = tsk .TskitFormat (fx_ts_2_diploids , model_mapping = model_mapping )
382402 shape = (2 , 2 )
383403 with pytest .raises (
@@ -387,9 +407,7 @@ def test_iter_alleles_and_genotypes_missing_node(self, fx_ts_2_diploids):
387407
388408 def test_isolated_as_missing (self , fx_ts_isolated_samples ):
389409 ind_nodes = np .array ([[0 ], [1 ], [3 ]])
390- model_mapping = tskit .VcfModelMapping (
391- ind_nodes , ["tsk{j}" for j in range (len (ind_nodes ))]
392- )
410+ model_mapping = tskit_model_mapping (ind_nodes )
393411
394412 format_obj_default = tsk .TskitFormat (
395413 fx_ts_isolated_samples ,
@@ -427,7 +445,7 @@ def test_isolated_as_missing(self, fx_ts_isolated_samples):
427445 expected_gt_missing = np .array ([[1 ], [0 ], [- 1 ]])
428446 nt .assert_array_equal (variant_data_missing .genotypes , expected_gt_missing )
429447
430- def test_genotype_dtype_i1 (self , tmp_path ):
448+ def test_genotype_dtype_i1 (self ):
431449 tables = tskit .TableCollection (sequence_length = 100 )
432450 for _ in range (4 ):
433451 tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
@@ -438,15 +456,13 @@ def test_genotype_dtype_i1(self, tmp_path):
438456 tables .mutations .add_row (site = site_id , node = 0 , derived_state = "T" )
439457 tables .sort ()
440458 tree_sequence = tables .tree_sequence ()
441- ts_path = tmp_path / "small_alleles.trees"
442- tree_sequence .dump (ts_path )
443459
444- format_obj = tsk .TskitFormat (ts_path )
460+ format_obj = tsk .TskitFormat (tree_sequence )
445461 schema = format_obj .generate_schema ()
446462 call_genotype_spec = next (s for s in schema .fields if s .name == "call_genotype" )
447463 assert call_genotype_spec .dtype == "i1"
448464
449- def test_genotype_dtype_i4 (self , tmp_path ):
465+ def test_genotype_dtype_i4 (self ):
450466 tables = tskit .TableCollection (sequence_length = 100 )
451467 for _ in range (4 ):
452468 tables .nodes .add_row (flags = tskit .NODE_IS_SAMPLE , time = 0 )
@@ -459,10 +475,8 @@ def test_genotype_dtype_i4(self, tmp_path):
459475
460476 tables .sort ()
461477 tree_sequence = tables .tree_sequence ()
462- ts_path = tmp_path / "large_alleles.trees"
463- tree_sequence .dump (ts_path )
464478
465- format_obj = tsk .TskitFormat (ts_path )
479+ format_obj = tsk .TskitFormat (tree_sequence )
466480 schema = format_obj .generate_schema ()
467481 call_genotype_spec = next (s for s in schema .fields if s .name == "call_genotype" )
468482 assert call_genotype_spec .dtype == "i4"
@@ -471,6 +485,7 @@ def test_genotype_dtype_i4(self, tmp_path):
471485@pytest .mark .parametrize (
472486 "ts" ,
473487 [
488+ # Standard individuals-with-a-given-ploidy situation
474489 add_mutations (
475490 msprime .sim_ancestry (4 , ploidy = 1 , sequence_length = 10 , random_seed = 42 )
476491 ),
@@ -480,20 +495,20 @@ def test_genotype_dtype_i4(self, tmp_path):
480495 add_mutations (
481496 msprime .sim_ancestry (3 , ploidy = 12 , sequence_length = 10 , random_seed = 142 )
482497 ),
498+ # No individuals, ploidy 1
499+ add_mutations (msprime .simulate (4 , length = 10 , random_seed = 412 )),
483500 ],
484501)
485502def test_against_tskit_vcf_output (ts , tmp_path ):
486503 vcf_path = tmp_path / "ts.vcf"
487- ts_path = tmp_path / "ts.trees"
488- ts .dump (ts_path )
489504 with open (vcf_path , "w" ) as f :
490505 ts .write_vcf (f )
491506
492507 tskit_zarr = tmp_path / "tskit.zarr"
493508 vcf_zarr = tmp_path / "vcf.zarr"
494- tsk .convert (ts_path , tskit_zarr )
509+ tsk .convert (ts , tskit_zarr , worker_processes = 0 )
495510
496- vcf .convert ([vcf_path ], vcf_zarr )
511+ vcf .convert ([vcf_path ], vcf_zarr , worker_processes = 0 )
497512 ds1 = sg .load_dataset (tskit_zarr )
498513 ds2 = (
499514 sg .load_dataset (vcf_zarr )
0 commit comments