
Commit 238ff6b

Merge pull request #560 from jeromekelleher/more-tidying
Move some stuff around
2 parents 0a24976 + fcd4ae5

13 files changed: 174 additions, 29189 deletions

sc2ts/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,8 @@
 from .core import __version__
 
+
+from .dataset import decode_alignment, Dataset
+
 from .stats import *
 
 # FIXME

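In practice this hunk just re-exports the two names at package level, so downstream code can import them directly from the package. A minimal sketch, assuming the package layout shown in this commit:

# decode_alignment and Dataset are re-exported by sc2ts/__init__.py,
# so they can be imported from the top-level package.
from sc2ts import decode_alignment, Dataset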
sc2ts/cli.py

Lines changed: 4 additions & 2 deletions
@@ -23,6 +23,8 @@
 
 import sc2ts
 from . import core
+from . import data_import
+from . import jit
 
 logger = logging.getLogger(__name__)
 
@@ -141,7 +143,7 @@ def import_alignments(dataset, fastas, initialise, progress, verbose):
 
     f_bar = tqdm.tqdm(sorted(fastas), desc="Files", disable=not progress, position=0)
     for fasta_path in f_bar:
-        reader = core.FastaReader(fasta_path, add_zero_base=False)
+        reader = data_import.FastaReader(fasta_path, add_zero_base=False)
         logger.info(f"Reading {len(reader)} alignments from {fasta_path}")
         alignments = {}
         a_bar = tqdm.tqdm(
@@ -152,7 +154,7 @@ def import_alignments(dataset, fastas, initialise, progress, verbose):
             position=1,
         )
        for k, v in a_bar:
-            alignments[k] = sc2ts.encode_alignment(v)
+            alignments[k] = jit.encode_alignment(v)
         sc2ts.Dataset.append_alignments(dataset, alignments)
 

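For context, the two changed call sites in import_alignments now come from the new data_import and jit modules rather than from core and the top-level sc2ts namespace. A minimal sketch of the resulting call pattern, assuming the moved functions keep the signatures shown in this diff (FastaReader is a Mapping in the code removed from core.py below, so .items() yields (name, alignment) pairs):

import sc2ts
from sc2ts import data_import, jit  # modules used by this commit

def import_one_fasta(dataset, fasta_path):
    # Read raw alignments; add_zero_base=False matches the CLI call above.
    reader = data_import.FastaReader(fasta_path, add_zero_base=False)
    # Encode each alignment into the integer representation used by sc2ts.
    alignments = {name: jit.encode_alignment(h) for name, h in reader.items()}
    sc2ts.Dataset.append_alignments(dataset, alignments)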
sc2ts/core.py

Lines changed: 7 additions & 169 deletions
@@ -1,15 +1,6 @@
 import dataclasses
-import json
-import pathlib
-import collections.abc
-import csv
 
 import tskit
-import numba
-import pyfaidx
-import numpy as np
-
-from . import jit
 
 __version__ = "undefined"
 try:
@@ -26,6 +17,13 @@
 REFERENCE_GENBANK = "MN908947"
 REFERENCE_SEQUENCE_LENGTH = 29904
 
+# We omit N here as it's mapped to -1. Make "-" the 5th allele
+# as this is a valid allele for us.
+# NOTE!! This string is also used in the jit module where it's
+# hard-coded into a numba function, so if this ever changes
+# it needs to be updated there also!
+IUPAC_ALLELES = "ACGT-RYSWKMBDHV."
+
 NODE_IS_MUTATION_OVERLAP = 1 << 21
 NODE_IS_REVERSION_PUSH = 1 << 22
 NODE_IS_RECOMBINANT = 1 << 23
@@ -97,163 +95,3 @@ def decode_flags(f):
 
 def flags_summary(f):
     return "".join([v.short if (v.value & f) > 0 else "_" for v in flag_values])
-
-
-class FastaReader(collections.abc.Mapping):
-    def __init__(self, path, add_zero_base=True):
-        self.reader = pyfaidx.Fasta(str(path))
-        self._keys = list(self.reader.keys())
-        self.add_zero_base = add_zero_base
-
-    def __getitem__(self, key):
-        x = self.reader[key]
-        h = np.array(x).astype(str)
-        h = np.char.upper(h)
-        if self.add_zero_base:
-            return np.append(["X"], h)
-        return h
-
-    def __iter__(self):
-        return iter(self._keys)
-
-    def __len__(self):
-        return len(self._keys)
-
-
-data_path = pathlib.Path(__file__).parent / "data"
-
-
-def get_problematic_regions():
-    """
-    These regions have been reported to have highly recurrent or unusual
-    patterns of deletions.
-
-    https://github.com/jeromekelleher/sc2ts/issues/231#issuecomment-2401405355
-
-    Region: NTD domain
-    Coords: [21602-22472)
-    Multiple highly recurrent deleted regions in NTD domain in Spike
-    https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7971772/
-
-    Region: ORF8
-    https://virological.org/t/repeated-loss-of-orf8-expression-in-circulating-sars-cov-2-lineages/931/1
-
-    The 1-based (half-open) coordinates were taken from the UCSC Genome Browser.
-    """
-    orf8 = get_gene_coordinates()["ORF8"]
-    return np.concatenate(
-        [
-            np.arange(21602, 22472, dtype=np.int64),  # NTD domain in S
-            np.arange(*orf8, dtype=np.int64),
-        ]
-    )
-
-
-def get_flank_coordinates():
-    """
-    Return the coordinates at either end of the genome for masking out.
-    """
-    genes = get_gene_coordinates()
-    start = genes["ORF1ab"][0]
-    end = genes["ORF10"][1]
-    return np.concatenate(
-        (np.arange(1, start), np.arange(end, REFERENCE_SEQUENCE_LENGTH))
-    )
-
-
-def get_masked_sites(ts):
-    """
-    Return the set of sites not used in the sequence.
-    """
-    unused = np.ones(int(ts.sequence_length), dtype=bool)
-    unused[ts.sites_position.astype(int)] = False
-    unused[0] = False
-    return np.where(unused)[0]
-
-
-@dataclasses.dataclass
-class CovLineage:
-    name: str
-    earliest_date: str
-    latest_date: str
-    description: str
-
-
-def get_cov_lineages_data():
-    with open(data_path / "lineages.json") as f:
-        data = json.load(f)
-    ret = {}
-    for record in data:
-        lineage = CovLineage(
-            record["Lineage"],
-            record["Earliest date"],
-            record["Latest date"],
-            record["Description"],
-        )
-        assert lineage.name not in ret
-        ret[lineage.name] = lineage
-    return ret
-
-
-__cached_reference = None
-
-
-def get_reference_sequence(as_array=False):
-    global __cached_reference
-    if __cached_reference is None:
-        reader = pyfaidx.Fasta(str(data_path / "reference.fasta"))
-        __cached_reference = reader[REFERENCE_GENBANK]
-    if as_array:
-        h = np.array(__cached_reference).astype(str)
-        return np.append(["X"], h)
-    else:
-        return "X" + str(__cached_reference)
-
-
-__cached_genes = None
-
-
-def get_gene_coordinates():
-    """
-    Returns a map of gene name to interval, (start, stop). These are
-    half-open, left-inclusive, right-exclusive.
-    """
-    global __cached_genes
-    if __cached_genes is None:
-        d = {}
-        with open(data_path / "annotation.csv") as f:
-            reader = csv.DictReader(f, delimiter=",")
-            for row in reader:
-                d[row["gene"]] = (int(row["start"]), int(row["end"]))
-        __cached_genes = d
-    return __cached_genes
-
-
-# We omit N here as it's mapped to -1. Make "-" the 5th allele
-# as this is a valid allele for us.
-IUPAC_ALLELES = "ACGT-RYSWKMBDHV."
-
-
-# FIXME make cache optional
-@numba.njit(cache=True)
-def encode_alignment(h):
-    # Just so numba knows this is a constant string
-    alleles = "ACGT-RYSWKMBDHV."
-    n = h.shape[0]
-    a = np.full(n, -1, dtype=np.int8)
-    for j in range(n):
-        if h[j] == "N":
-            a[j] = -1
-        else:
-            for k, c in enumerate(alleles):
-                if c == h[j]:
-                    break
-            else:
-                raise ValueError(f"Allele {h[j]} not recognised")
-            a[j] = k
-    return a
-
-
-def decode_alignment(a):
-    alleles = np.array(tuple(IUPAC_ALLELES + "N"), dtype=str)
-    return alleles[a]

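As a quick illustration of the allele encoding the moved comment describes (N maps to -1, every other character to its index in IUPAC_ALLELES), here is a minimal round-trip sketch, assuming jit.encode_alignment and the re-exported decode_alignment keep the behaviour of the code deleted from core.py above:

import numpy as np
from sc2ts import jit, decode_alignment

h = np.array(list("ACGTN-R"))          # alignment characters as a length-7 array
a = jit.encode_alignment(h)            # indices into "ACGT-RYSWKMBDHV.", N -> -1
print(a)                               # expected: [ 0  1  2  3 -1  4  5]
print("".join(decode_alignment(a)))    # expected: ACGTN-R (index -1 decodes back to N)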
0 commit comments
