Skip to content

Commit bc3ebdd

Browse files
Merge pull request #561 from jeromekelleher/more-package-moving
Finalise external API
2 parents 238ff6b + 79b1818 commit bc3ebdd

File tree

11 files changed

+369
-348
lines changed

11 files changed

+369
-348
lines changed

sc2ts/__init__.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,6 @@
11
from .core import __version__
22

3-
4-
from .dataset import decode_alignment, Dataset
5-
6-
from .stats import *
7-
8-
# FIXME
3+
# star imports are fine here as it's just a bunch of constants
94
from .core import *
10-
from .dataset import *
11-
12-
from .inference import *
13-
from .validation import *
14-
from .tree_ops import *
5+
from .dataset import mask_ambiguous, mask_flanking_deletions, decode_alignment, Dataset
6+
from .stats import node_data, mutation_data

sc2ts/cli.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
import sc2ts
2525
from . import core
2626
from . import data_import
27+
from . import tree_ops
2728
from . import jit
29+
from . import validation
30+
from . import inference as si # sc2ts inference
2831

2932
logger = logging.getLogger(__name__)
3033

@@ -186,7 +189,7 @@ def import_metadata(dataset, metadata, field_descriptions, viridian, verbose):
186189
df_in = pd.read_csv(metadata, sep="\t", dtype=dtype)
187190
index_field = "Run"
188191
if viridian:
189-
df_in = sc2ts.massage_viridian_metadata(df_in)
192+
df_in = data_import.massage_viridian_metadata(df_in)
190193
df = df_in.set_index(index_field)
191194
d = {}
192195
if field_descriptions is not None:
@@ -232,7 +235,7 @@ def info_matches(match_db, all_matches, verbose):
232235
Information about matches in the MatchDB
233236
"""
234237
setup_logging(verbose)
235-
with sc2ts.MatchDb(match_db) as db:
238+
with si.MatchDb(match_db) as db:
236239
if all_matches:
237240
list_all_matches(db)
238241
else:
@@ -261,7 +264,7 @@ def info_dataset(dataset, verbose, zarr_details):
261264
def _run_extend(out_path, verbose, log_file, **params):
262265
date = params["date"]
263266
setup_logging(verbose, log_file, date=date)
264-
ts = sc2ts.extend(show_progress=True, **params)
267+
ts = si.extend(show_progress=True, **params)
265268
ts.dump(out_path)
266269
resource_usage = summarise_usage(ts)
267270
logger.info(resource_usage)
@@ -317,15 +320,15 @@ def infer(config_file, start, stop, force):
317320
f"Do you want to overwrite MatchDB at {match_db}",
318321
abort=True,
319322
)
320-
init_ts = sc2ts.initial_ts(exclude_sites)
321-
sc2ts.MatchDb.initialise(match_db)
323+
init_ts = si.initial_ts(exclude_sites)
324+
si.MatchDb.initialise(match_db)
322325
base_ts = results_dir / f"{run_id}_init.ts"
323326
init_ts.dump(base_ts)
324327
start = "2000"
325328
else:
326329
base_ts = find_previous_date_path(start, ts_file_pattern)
327330
print(f"Starting from {base_ts}")
328-
with sc2ts.MatchDb(match_db) as mdb:
331+
with si.MatchDb(match_db) as mdb:
329332
newer_matches = mdb.count_newer(start)
330333
if newer_matches > 0:
331334
if not force:
@@ -430,9 +433,9 @@ def validate(
430433
dataset, date_field=date_field, chunk_cache_size=chunk_cache_size
431434
)
432435
if genotypes:
433-
sc2ts.validate_genotypes(ts, ds, deletions_as_missing, show_progress=True)
436+
validation.validate_genotypes(ts, ds, deletions_as_missing, show_progress=True)
434437
if metadata:
435-
sc2ts.validate_metadata(ts, ds, skip_fields=set(skip), show_progress=True)
438+
validation.validate_metadata(ts, ds, skip_fields=set(skip), show_progress=True)
436439

437440

438441
@click.command()
@@ -481,7 +484,7 @@ def run_hmm(
481484
"""
482485
setup_logging(verbose, log_file)
483486

484-
runs = sc2ts.run_hmm(
487+
runs = si.run_hmm(
485488
dataset,
486489
ts_path,
487490
strains=strains,
@@ -517,14 +520,14 @@ def postprocess(
517520
setup_logging(verbose, log_file)
518521
ts = tszip.load(ts_in)
519522
if match_db is not None:
520-
with sc2ts.MatchDb(match_db) as db:
521-
ts = sc2ts.append_exact_matches(ts, db, show_progress=progress)
523+
with si.MatchDb(match_db) as db:
524+
ts = si.append_exact_matches(ts, db, show_progress=progress)
522525

523-
ts = sc2ts.push_up_unary_recombinant_mutations(ts)
526+
ts = si.push_up_unary_recombinant_mutations(ts)
524527
# See if we can remove some of the reversions in a straightforward way.
525-
mutations_is_reversion = sc2ts.find_reversions(ts)
528+
mutations_is_reversion = si.find_reversions(ts)
526529
mutations_before = ts.num_mutations
527-
ts = sc2ts.push_up_reversions(
530+
ts = tree_ops.push_up_reversions(
528531
ts, ts.mutations_node[mutations_is_reversion], date=None
529532
)
530533
ts.dump(ts_out)
@@ -569,9 +572,9 @@ def minimise_metadata(
569572
field_mapping = dict(field_mapping)
570573
setup_logging(verbose, log_file)
571574
ts = tszip.load(ts_in)
572-
ts = sc2ts.minimise_metadata(ts, field_mapping, show_progress=progress)
575+
ts = si.minimise_metadata(ts, field_mapping, show_progress=progress)
573576
if drop_vestigial_root:
574-
ts = sc2ts.drop_vestigial_root_edge(ts)
577+
ts = tree_ops.drop_vestigial_root_edge(ts)
575578
ts.dump(ts_out)
576579

577580

@@ -602,7 +605,7 @@ def map_parsimony(
602605
ts = tszip.load(ts_in)
603606
if sites is not None:
604607
sites = np.loadtxt(sites, dtype=int)
605-
result = sc2ts.map_parsimony(ts, ds, sites, show_progress=progress)
608+
result = si.map_parsimony(ts, ds, sites, show_progress=progress)
606609
if report is not None:
607610
result.report.to_csv(report)
608611
result.tree_sequence.dump(ts_out)
@@ -630,7 +633,7 @@ def apply_node_parsimony(
630633
setup_logging(verbose, log_file)
631634
ts = tszip.load(ts_in)
632635

633-
result = sc2ts.apply_node_parsimony_heuristics(ts, show_progress=progress)
636+
result = si.apply_node_parsimony_heuristics(ts, show_progress=progress)
634637
if report is not None:
635638
result.report.to_csv(report)
636639
result.tree_sequence.dump(ts_out)
@@ -667,7 +670,7 @@ def rematch_recombinant(
667670

668671
base_ts = tszip.load(base_ts)
669672
recomb_ts = tszip.load(recomb_ts)
670-
result = sc2ts.rematch_recombinant(
673+
result = si.rematch_recombinant(
671674
base_ts, recomb_ts, node_id, num_mismatches=num_mismatches
672675
)
673676
print(json.dumps(result.asdict()))
@@ -687,7 +690,7 @@ def rematch_recombinant_lbs(ts, node_id, num_mismatches, verbose, log_file):
687690
setup_logging(verbose, log_file)
688691

689692
ts = tszip.load(ts)
690-
result = sc2ts.rematch_recombinant_lbs(ts, node_id, num_mismatches=num_mismatches)
693+
result = si.rematch_recombinant_lbs(ts, node_id, num_mismatches=num_mismatches)
691694
print(json.dumps(result.asdict()))
692695

693696

@@ -709,7 +712,7 @@ def rewire_lbs(ts_in, rematch_data, ts_out, verbose, log_file):
709712
records = []
710713
with open(rematch_data) as f:
711714
for d in json.load(f):
712-
records.append(sc2ts.RematchRecombinantsLbsResult.fromdict(d))
715+
records.append(si.RematchRecombinantsLbsResult.fromdict(d))
713716

714717
recombs_to_rewire = []
715718
rewire_existing = 0
@@ -729,8 +732,8 @@ def rewire_lbs(ts_in, rematch_data, ts_out, verbose, log_file):
729732
f"(existing={rewire_existing} lbs={rewire_lbs})"
730733
)
731734

732-
ts = sc2ts.push_up_unary_recombinant_mutations(ts)
733-
ts = sc2ts.rewire_long_branch_splits(ts, recombs_to_rewire)
735+
ts = si.push_up_unary_recombinant_mutations(ts)
736+
ts = si.rewire_long_branch_splits(ts, recombs_to_rewire)
734737
ts.dump(ts_out)
735738

736739

sc2ts/data_import.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,43 @@ def get_flank_coordinates():
101101
return np.concatenate(
102102
(np.arange(1, start), np.arange(end, REFERENCE_SEQUENCE_LENGTH))
103103
)
104+
105+
106+
def massage_viridian_metadata(df):
107+
"""
108+
Takes a pandas dataframe indexed by sample ID and massages it
109+
so that the returned dataframe has consistent types:
110+
111+
- bool T/F columns encoded as booleans
112+
- integer columns encoded with -1 as N/A
113+
"""
114+
# print(df)
115+
bool_cols = [name for name in df if name.startswith("In")]
116+
N = df.shape[0]
117+
for name in bool_cols:
118+
data = df[name]
119+
assert set(data.unique()) <= set(["F", "T"])
120+
a = np.zeros(N, dtype=bool)
121+
a[data == "T"] = 1
122+
df[name] = a
123+
int_fields = [
124+
"Genbank_N",
125+
"Viridian_N",
126+
"Run_count",
127+
"Viridian_cons_len",
128+
"Viridian_cons_het",
129+
]
130+
for name in int_fields:
131+
try:
132+
data = df[name]
133+
except KeyError:
134+
continue
135+
if str(data.dtype) == "int64":
136+
continue
137+
a = np.zeros(N, dtype=int)
138+
missing = data == "."
139+
a[missing] = -1
140+
a[~missing] = np.array(data[~missing], dtype=int)
141+
df[name] = a
142+
return df
143+

sc2ts/dataset.py

Lines changed: 18 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -29,43 +29,24 @@ def decode_alignment(a):
2929
return alleles[a]
3030

3131

32-
def massage_viridian_metadata(df):
33-
"""
34-
Takes a pandas dataframe indexed by sample ID and massages it
35-
so that the returned dataframe has consistent types:
36-
37-
- bool T/F columns encoded as booleans
38-
- integer columns encoded with -1 as N/A
39-
"""
40-
# print(df)
41-
bool_cols = [name for name in df if name.startswith("In")]
42-
N = df.shape[0]
43-
for name in bool_cols:
44-
data = df[name]
45-
assert set(data.unique()) <= set(["F", "T"])
46-
a = np.zeros(N, dtype=bool)
47-
a[data == "T"] = 1
48-
df[name] = a
49-
int_fields = [
50-
"Genbank_N",
51-
"Viridian_N",
52-
"Run_count",
53-
"Viridian_cons_len",
54-
"Viridian_cons_het",
55-
]
56-
for name in int_fields:
57-
try:
58-
data = df[name]
59-
except KeyError:
60-
continue
61-
if str(data.dtype) == "int64":
62-
continue
63-
a = np.zeros(N, dtype=int)
64-
missing = data == "."
65-
a[missing] = -1
66-
a[~missing] = np.array(data[~missing], dtype=int)
67-
df[name] = a
68-
return df
32+
DELETION = core.IUPAC_ALLELES.index("-")
33+
34+
35+
def mask_ambiguous(a):
36+
a = a.copy()
37+
a[a > DELETION] = -1
38+
return a
39+
40+
41+
def mask_flanking_deletions(a):
42+
a = a.copy()
43+
non_dels = np.nonzero(a != DELETION)[0]
44+
if len(non_dels) == 0:
45+
a[:] = -1
46+
else:
47+
a[: non_dels[0]] = -1
48+
a[non_dels[-1] + 1 :] = -1
49+
return a
6950

7051

7152
def readahead_retrieve(array, blocks):

0 commit comments

Comments
 (0)