2424import sc2ts
2525from . import core
2626from . import data_import
27+ from . import tree_ops
2728from . import jit
29+ from . import validation
30+ from . import inference as si # sc2ts inference
2831
2932logger = logging .getLogger (__name__ )
3033
@@ -186,7 +189,7 @@ def import_metadata(dataset, metadata, field_descriptions, viridian, verbose):
186189 df_in = pd .read_csv (metadata , sep = "\t " , dtype = dtype )
187190 index_field = "Run"
188191 if viridian :
189- df_in = sc2ts .massage_viridian_metadata (df_in )
192+ df_in = data_import .massage_viridian_metadata (df_in )
190193 df = df_in .set_index (index_field )
191194 d = {}
192195 if field_descriptions is not None :
@@ -232,7 +235,7 @@ def info_matches(match_db, all_matches, verbose):
232235 Information about matches in the MatchDB
233236 """
234237 setup_logging (verbose )
235- with sc2ts .MatchDb (match_db ) as db :
238+ with si .MatchDb (match_db ) as db :
236239 if all_matches :
237240 list_all_matches (db )
238241 else :
@@ -261,7 +264,7 @@ def info_dataset(dataset, verbose, zarr_details):
261264def _run_extend (out_path , verbose , log_file , ** params ):
262265 date = params ["date" ]
263266 setup_logging (verbose , log_file , date = date )
264- ts = sc2ts .extend (show_progress = True , ** params )
267+ ts = si .extend (show_progress = True , ** params )
265268 ts .dump (out_path )
266269 resource_usage = summarise_usage (ts )
267270 logger .info (resource_usage )
@@ -317,15 +320,15 @@ def infer(config_file, start, stop, force):
317320 f"Do you want to overwrite MatchDB at { match_db } " ,
318321 abort = True ,
319322 )
320- init_ts = sc2ts .initial_ts (exclude_sites )
321- sc2ts .MatchDb .initialise (match_db )
323+ init_ts = si .initial_ts (exclude_sites )
324+ si .MatchDb .initialise (match_db )
322325 base_ts = results_dir / f"{ run_id } _init.ts"
323326 init_ts .dump (base_ts )
324327 start = "2000"
325328 else :
326329 base_ts = find_previous_date_path (start , ts_file_pattern )
327330 print (f"Starting from { base_ts } " )
328- with sc2ts .MatchDb (match_db ) as mdb :
331+ with si .MatchDb (match_db ) as mdb :
329332 newer_matches = mdb .count_newer (start )
330333 if newer_matches > 0 :
331334 if not force :
@@ -430,9 +433,9 @@ def validate(
430433 dataset , date_field = date_field , chunk_cache_size = chunk_cache_size
431434 )
432435 if genotypes :
433- sc2ts .validate_genotypes (ts , ds , deletions_as_missing , show_progress = True )
436+ validation .validate_genotypes (ts , ds , deletions_as_missing , show_progress = True )
434437 if metadata :
435- sc2ts .validate_metadata (ts , ds , skip_fields = set (skip ), show_progress = True )
438+ validation .validate_metadata (ts , ds , skip_fields = set (skip ), show_progress = True )
436439
437440
438441@click .command ()
@@ -481,7 +484,7 @@ def run_hmm(
481484 """
482485 setup_logging (verbose , log_file )
483486
484- runs = sc2ts .run_hmm (
487+ runs = si .run_hmm (
485488 dataset ,
486489 ts_path ,
487490 strains = strains ,
@@ -517,14 +520,14 @@ def postprocess(
517520 setup_logging (verbose , log_file )
518521 ts = tszip .load (ts_in )
519522 if match_db is not None :
520- with sc2ts .MatchDb (match_db ) as db :
521- ts = sc2ts .append_exact_matches (ts , db , show_progress = progress )
523+ with si .MatchDb (match_db ) as db :
524+ ts = si .append_exact_matches (ts , db , show_progress = progress )
522525
523- ts = sc2ts .push_up_unary_recombinant_mutations (ts )
526+ ts = si .push_up_unary_recombinant_mutations (ts )
524527 # See if we can remove some of the reversions in a straightforward way.
525- mutations_is_reversion = sc2ts .find_reversions (ts )
528+ mutations_is_reversion = si .find_reversions (ts )
526529 mutations_before = ts .num_mutations
527- ts = sc2ts .push_up_reversions (
530+ ts = tree_ops .push_up_reversions (
528531 ts , ts .mutations_node [mutations_is_reversion ], date = None
529532 )
530533 ts .dump (ts_out )
@@ -569,9 +572,9 @@ def minimise_metadata(
569572 field_mapping = dict (field_mapping )
570573 setup_logging (verbose , log_file )
571574 ts = tszip .load (ts_in )
572- ts = sc2ts .minimise_metadata (ts , field_mapping , show_progress = progress )
575+ ts = si .minimise_metadata (ts , field_mapping , show_progress = progress )
573576 if drop_vestigial_root :
574- ts = sc2ts .drop_vestigial_root_edge (ts )
577+ ts = tree_ops .drop_vestigial_root_edge (ts )
575578 ts .dump (ts_out )
576579
577580
@@ -602,7 +605,7 @@ def map_parsimony(
602605 ts = tszip .load (ts_in )
603606 if sites is not None :
604607 sites = np .loadtxt (sites , dtype = int )
605- result = sc2ts .map_parsimony (ts , ds , sites , show_progress = progress )
608+ result = si .map_parsimony (ts , ds , sites , show_progress = progress )
606609 if report is not None :
607610 result .report .to_csv (report )
608611 result .tree_sequence .dump (ts_out )
@@ -630,7 +633,7 @@ def apply_node_parsimony(
630633 setup_logging (verbose , log_file )
631634 ts = tszip .load (ts_in )
632635
633- result = sc2ts .apply_node_parsimony_heuristics (ts , show_progress = progress )
636+ result = si .apply_node_parsimony_heuristics (ts , show_progress = progress )
634637 if report is not None :
635638 result .report .to_csv (report )
636639 result .tree_sequence .dump (ts_out )
@@ -667,7 +670,7 @@ def rematch_recombinant(
667670
668671 base_ts = tszip .load (base_ts )
669672 recomb_ts = tszip .load (recomb_ts )
670- result = sc2ts .rematch_recombinant (
673+ result = si .rematch_recombinant (
671674 base_ts , recomb_ts , node_id , num_mismatches = num_mismatches
672675 )
673676 print (json .dumps (result .asdict ()))
@@ -687,7 +690,7 @@ def rematch_recombinant_lbs(ts, node_id, num_mismatches, verbose, log_file):
687690 setup_logging (verbose , log_file )
688691
689692 ts = tszip .load (ts )
690- result = sc2ts .rematch_recombinant_lbs (ts , node_id , num_mismatches = num_mismatches )
693+ result = si .rematch_recombinant_lbs (ts , node_id , num_mismatches = num_mismatches )
691694 print (json .dumps (result .asdict ()))
692695
693696
@@ -709,7 +712,7 @@ def rewire_lbs(ts_in, rematch_data, ts_out, verbose, log_file):
709712 records = []
710713 with open (rematch_data ) as f :
711714 for d in json .load (f ):
712- records .append (sc2ts .RematchRecombinantsLbsResult .fromdict (d ))
715+ records .append (si .RematchRecombinantsLbsResult .fromdict (d ))
713716
714717 recombs_to_rewire = []
715718 rewire_existing = 0
@@ -729,8 +732,8 @@ def rewire_lbs(ts_in, rematch_data, ts_out, verbose, log_file):
729732 f"(existing={ rewire_existing } lbs={ rewire_lbs } )"
730733 )
731734
732- ts = sc2ts .push_up_unary_recombinant_mutations (ts )
733- ts = sc2ts .rewire_long_branch_splits (ts , recombs_to_rewire )
735+ ts = si .push_up_unary_recombinant_mutations (ts )
736+ ts = si .rewire_long_branch_splits (ts , recombs_to_rewire )
734737 ts .dump (ts_out )
735738
736739
0 commit comments