Skip to content

Commit baad735

Browse files
committed
reduce peak memory usage
1 parent 5212b3c commit baad735

File tree

1 file changed

+40
-34
lines changed

1 file changed

+40
-34
lines changed

src/melon/melon.py

Lines changed: 40 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -152,36 +152,38 @@ def parse_minimap(self):
152152
'''
153153
Parse minimap2's output and record alignments.
154154
'''
155+
accession2lineage = {}
156+
with open(os.path.join(self.db, 'metadata.tsv')) as f:
157+
next(f)
158+
for line in f:
159+
ls = line.rstrip().split('\t')
160+
accession2lineage[ls[0]] = ';'.join(ls[1:])
161+
155162
qcoords = defaultdict(set)
156163
for hit in self.hits:
157164
qcoords[hit[0]].add(tuple(hit[-2:]))
165+
scores = defaultdict(lambda: defaultdict(lambda: {'AS': 0, 'DE': 0, 'ID': 0}))
158166

159-
with open(get_filename(self.file, self.output, '.minimap.tmp')) as f:
167+
with open(f'{self.outfile}.minimap.tmp') as f:
160168
for line in f:
161169
ls = line.rstrip().split('\t')
162170
qstart, qend, qseqid, sseqid = int(ls[2]), int(ls[3]), ls[0], ls[5]
171+
lineage = accession2lineage[sseqid.rsplit('_', 1)[0]]
163172

164-
AS = int(ls[14].split('AS:i:')[-1])
165-
DE = 1 - float((ls[19] if ls[16] in {'tp:A:S', 'tp:A:i'} else ls[20]).split('de:f:')[-1]) # gap-compressed identity
166-
ID = int(ls[9]) / int(ls[10]) # gap-uncompressed identity
173+
AS_MAX, AS = scores[qseqid][lineage].get('AS', 0), int(ls[14].split('AS:i:')[-1])
174+
DE_MAX, DE = scores[qseqid][lineage].get('DE', 0), 1 - float((ls[19] if ls[16] in {'tp:A:S', 'tp:A:i'} else ls[20]).split('de:f:')[-1])
175+
ID_MAX, ID = scores[qseqid][lineage].get('ID', 0), int(ls[9]) / int(ls[10])
167176

168177
## filter out non-overlapping alignments
169-
if any(compute_overlap((qstart, qend, *qcoord))>0 for qcoord in qcoords[qseqid]):
170-
self.alignments.append([qseqid, sseqid, AS, DE, ID])
171-
172-
def postprocess(self, max_iteration=1000, epsilon=1e-10):
173-
'''
174-
Post-processing and label reassignment using EM.
175-
'''
176-
accession2lineage = {}
177-
with open(os.path.join(self.db, 'metadata.tsv')) as f:
178-
next(f)
179-
for line in f:
180-
ls = line.rstrip().split('\t')
181-
accession2lineage[ls[0]] = ';'.join(ls[1:])
182-
183-
## alignment filtering based on AS, DE, and ID, keep only the first per qseqid and lineage, remove all inferior ones
184-
data = []
178+
if AS > AS_MAX or DE > DE_MAX or ID > ID_MAX:
179+
if any(compute_overlap((qstart, qend, *qcoord))>0 for qcoord in qcoords[qseqid]):
180+
scores[qseqid][lineage]['AS'] = max(AS_MAX, AS)
181+
scores[qseqid][lineage]['DE'] = max(DE_MAX, DE)
182+
scores[qseqid][lineage]['ID'] = max(ID_MAX, ID)
183+
self.alignments.append([qseqid, sseqid, AS, DE, ID, lineage])
184+
185+
## filter out low-score alignments
186+
alignments = []
185187
duplicates = set()
186188
max_scores = defaultdict(lambda: {'AS': 0, 'DE': 0, 'ID': 0})
187189

@@ -190,27 +192,31 @@ def postprocess(self, max_iteration=1000, epsilon=1e-10):
190192
max_scores[alignment[0]]['DE'] = max(max_scores[alignment[0]]['DE'], alignment[3])
191193
max_scores[alignment[0]]['ID'] = max(max_scores[alignment[0]]['ID'], alignment[4])
192194

193-
for row in sorted(self.alignments, key=lambda alignment: (alignment[0], alignment[2], alignment[3], alignment[4]), reverse=True):
195+
for alignment in sorted(self.alignments, key=lambda alignment: (alignment[0], alignment[2], alignment[3], alignment[4]), reverse=True):
194196
if (
195-
max(row[2] / 0.9975, row[2] + 25) > max_scores[row[0]]['AS'] or
196-
row[3] / 0.9995 > max_scores[row[0]]['DE'] or
197-
row[4] / 0.9995 > max_scores[row[0]]['ID']
197+
max(alignment[2] / 0.9975, alignment[2] + 25) > max_scores[alignment[0]]['AS'] or
198+
alignment[3] / 0.9995 > max_scores[alignment[0]]['DE'] or
199+
alignment[4] / 0.9995 > max_scores[alignment[0]]['ID']
198200
):
199-
species = accession2lineage[row[1].rsplit('_', 1)[0]]
200-
if (row[0], species) not in duplicates:
201-
data.append(row + [species])
202-
duplicates.add((row[0], species))
201+
if (alignment[0], alignment[-1]) not in duplicates:
202+
alignments.append(alignment)
203+
duplicates.add((alignment[0], alignment[-1]))
203204

204-
## save pairwise gap-uncompressed/gap-compressed identity for ANI calculation
205-
self.identities = {(row[0], row[-1]): (row[3], row[4]) for row in data}
205+
## save pairwise gap-compressed/gap-uncompressed identity for ANI calculation
206+
self.alignments = alignments
207+
self.identities = {(alignment[0], alignment[-1]): (alignment[3], alignment[4]) for alignment in self.alignments}
206208

209+
def postprocess(self, max_iteration=1000, epsilon=1e-10):
210+
'''
211+
Post-processing and label reassignment using EM.
212+
'''
207213
## create a matrix then fill
208-
qseqids, lineages = np.unique([row[0] for row in data]), np.unique([row[-1] for row in data])
214+
qseqids, lineages = np.unique([alignment[0] for alignment in self.alignments]), np.unique([alignment[-1] for alignment in self.alignments])
209215
qseqid2index = {qseqid: index for index, qseqid in enumerate(qseqids)}
210216
lineage2index = {lineage: index for index, lineage in enumerate(lineages)}
211217

212-
rows = [qseqid2index[row[0]] for row in data]
213-
cols = [lineage2index[row[-1]] for row in data]
218+
rows = [qseqid2index[alignment[0]] for alignment in self.alignments]
219+
cols = [lineage2index[alignment[-1]] for alignment in self.alignments]
214220
matrix = csr_matrix((np.ones(len(rows)), (rows, cols)), shape=(len(qseqids), len(lineages)), dtype=int)
215221

216222
## run EM using the count matrix as input
@@ -287,7 +293,7 @@ def run(self, debug=False, db_kraken=None, skip_profile=False, skip_clean=False,
287293
'''
288294
if db_kraken is not None:
289295
logger.info('Filtering reads ...')
290-
if not debug: self.run_kraken(db_krake=db_kraken)
296+
if not debug: self.run_kraken(db_kraken=db_kraken)
291297
self.parse_kraken()
292298
logger.info(f'... removed {len(self.nset)} putatively non-prokaryotic reads.')
293299

0 commit comments

Comments (0)