Skip to content

Commit 7ac0bb6

Browse files
author
Julian
committed
Add multi process functionality
1 parent 46e7950 commit 7ac0bb6

File tree

2 files changed

+76
-104
lines changed

2 files changed

+76
-104
lines changed

fracsim/input_layer/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def parse_arguments():
6767
'-t','--threads',
6868
type=int,
6969
default=1,
70-
help='计算线程数,默认值1'
70+
help='并行任务数(生成素描阶段使用多进程加速),默认值1'
7171
)
7272
parser.add_argument(
7373
'-m','--min-similarity',

fracsim/process_layer/kmer_sketch.py

Lines changed: 75 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
"""k-mer计算,FracMinHash素描生成模块"""
22

33

4-
from typing import Set
4+
import sys
55
from .models import SketchData, GenomeData
66
from ..utils.hash import HashFunction
77
from typing import Set, Generator
8-
from concurrent.futures import ThreadPoolExecutor, as_completed
9-
from ..utils.hash import HashFunction
8+
from concurrent.futures import ProcessPoolExecutor, as_completed
109

1110

1211
class KmerGenerator:
@@ -22,18 +21,7 @@ def __init__(self, k: int, seed: int = 42, threads: int = 1):
2221
threads: 线程数
2322
"""
2423
self.k = k
25-
self.threads = threads
2624
self.hash_func = HashFunction(seed)
27-
# 创建线程池(仅在多线程时)
28-
if threads > 1:
29-
self.executor = ThreadPoolExecutor(max_workers=threads)
30-
else:
31-
self.executor = None
32-
33-
def __del__(self):
34-
# 尝试关闭线程池(非阻塞)
35-
if hasattr(self, 'executor') and self.executor:
36-
self.executor.shutdown(wait=False)
3725

3826
def generate_kmers(self, sequence: str, canonical: bool = True) -> Generator[str, None, None]:
3927
"""
@@ -90,87 +78,16 @@ def get_kmer_hashes(self, sequence: str, max_hash: int, canonical: bool = True)
9078
"""
9179
hashes = set()
9280

93-
if self.threads > 1 and len(sequence) > 100000 and self.executor is not None: # 避免小序列的并行开销
94-
# 并行处理长序列
95-
hashes = self._parallel_kmer_hash(sequence, max_hash, canonical)
96-
else:
97-
# 串行处理
98-
for kmer in self.generate_kmers(sequence, canonical):
99-
hash_val = self.hash_func.get_hash(kmer)
100-
if hash_val < max_hash: # 筛选
101-
hashes.add(hash_val)
102-
103-
return hashes
104-
105-
106-
def _parallel_kmer_hash(self, sequence: str, max_hash: int, canonical: bool) -> Set[int]:
107-
"""
108-
并行计算k-mer哈希,使用实例线程池
109-
110-
Args:
111-
sequence: 序列字符串
112-
canonical: 是否使用正则形式
113-
114-
Returns:
115-
Set[int]: 哈希值集合
116-
"""
117-
hashes = set()
118-
seq_len = len(sequence)
119-
chunk_size = seq_len // self.threads
120-
121-
futures = []
122-
for i in range(self.threads):
123-
start = i * chunk_size
124-
# 最后一个块直接到末尾,其他块扩展k-1保证边界k-mer完整
125-
end = start + chunk_size + self.k if i < self.threads - 1 else seq_len
126-
if start < seq_len - self.k + 1:
127-
future = self.executor.submit(
128-
self._process_chunk,
129-
sequence[start:end],
130-
max_hash,
131-
start,
132-
canonical
133-
)
134-
futures.append(future)
135-
136-
for future in as_completed(futures):
137-
hashes.update(future.result())
138-
139-
return hashes
140-
141-
142-
def _process_chunk(self, chunk: str, max_hash: int, offset: int, canonical: bool) -> Set[int]:
143-
"""
144-
处理序列片段
14581

146-
Args:
147-
chunk: 序列片段
148-
offset: 起始位置偏移
149-
canonical: 是否使用正则形式
150-
151-
Returns:
152-
Set[int]: 哈希值集合
153-
"""
154-
hashes = set()
155-
AMBIGUOUS_BASES = set('RYSWKMBDHVryswkmbdhv')
156-
157-
for i in range(len(chunk) - self.k + 1):
158-
kmer = chunk[i:i + self.k].upper()
159-
160-
if 'N' in kmer or any(c in AMBIGUOUS_BASES for c in kmer):
161-
continue
162-
163-
if canonical:
164-
rev_comp = self._reverse_complement(kmer)
165-
kmer = min(kmer, rev_comp)
166-
82+
# 串行处理
83+
for kmer in self.generate_kmers(sequence, canonical):
16784
hash_val = self.hash_func.get_hash(kmer)
168-
if hash_val < max_hash:
85+
if hash_val < max_hash: # 筛选
16986
hashes.add(hash_val)
170-
17187

17288
return hashes
17389

90+
17491
def _reverse_complement(self, seq: str) -> str:
17592
"""
17693
计算序列的反向互补
@@ -185,12 +102,33 @@ def _reverse_complement(self, seq: str) -> str:
185102
return ''.join(complement.get(base, base) for base in reversed(seq))
186103

187104

105+
# 辅助函数:供多进程调用,计算单条序列的哈希集合
106+
def _process_sequence_for_sketch(sequence: str, k: int, max_hash: int, seed: int, canonical: bool) -> Set[int]:
107+
"""
108+
子进程任务:计算单条序列的FracMinHash哈希集合
109+
110+
Args:
111+
sequence: 序列字符串
112+
k: k-mer长度
113+
max_hash: 最大哈希阈值
114+
seed: 哈希种子
115+
canonical: 是否使用正则形式
116+
117+
Returns:
118+
Set[int]: 哈希值集合
119+
"""
120+
# 每个子进程独立创建KmerGenerator,避免传递不可序列化对象
121+
kg = KmerGenerator(k, seed, threads=1)
122+
return kg.get_kmer_hashes(sequence, max_hash, canonical)
123+
124+
125+
188126
# ---------------------------------------------------
189127
#----------------------------------------------------
190128
#----------------------------------------------------
191129

192130
class FracMinHashSketch:
193-
"""FracMinHash素描生成器类"""
131+
"""FracMinHash素描生成器类(支持多进程并行)"""
194132

195133
def __init__(self, k: int, scaled: float, seed: int = 42, threads: int = 1):
196134
"""
@@ -220,20 +158,21 @@ def create_sketch(self, genome_data: GenomeData) -> SketchData:
220158
SketchData: 素描数据
221159
"""
222160
max_hash = self._calculate_max_hash()
223-
all_hashes = set()
224-
total_kmers = 0
225-
226-
# 遍历所有序列,合并哈希集合
227-
for sequence in genome_data.sequences:
228-
# 获取当前序列的过滤后哈希集合
229-
seq_hashes = self.kmer_gen.get_kmer_hashes(sequence, max_hash)
230-
all_hashes.update(seq_hashes) # 合并哈希集合
231-
sketch_size = len(all_hashes)
232-
233-
# 累加总窗口数
234-
total_kmers += len(sequence) - self.k + 1
161+
sequences = genome_data.sequences
235162

163+
# 计算总k-mer数(串行,仅用长度)
164+
total_kmers = 0
165+
for seq in sequences:
166+
if len(seq) >= self.k:
167+
total_kmers += len(seq) - self.k + 1
236168

169+
# 判断是否启用多进程
170+
if self.threads > 1 and len(sequences) > 1:
171+
all_hashes = self._parallel_process(sequences, max_hash)
172+
else:
173+
all_hashes = self._serial_process(sequences, max_hash)
174+
175+
237176
# 创建素描数据
238177
sketch = SketchData(
239178
genome_id=genome_data.seq_id,
@@ -242,13 +181,46 @@ def create_sketch(self, genome_data: GenomeData) -> SketchData:
242181
seed=self.seed,
243182
hashes=all_hashes,
244183
total_kmers=total_kmers,
245-
sketch_size=sketch_size
184+
sketch_size=len(all_hashes)
246185

247186
)
248187

249188
return sketch
250189

190+
191+
def _serial_process(self, sequences: list, max_hash: int) -> Set[int]:
192+
"""串行处理所有序列"""
193+
all_hashes = set()
194+
for seq in sequences:
195+
seq_hashes = self.kmer_gen.get_kmer_hashes(seq, max_hash, canonical=True)
196+
all_hashes.update(seq_hashes)
197+
return all_hashes
251198

199+
def _parallel_process(self, sequences: list, max_hash: int) -> Set[int]:
200+
"""使用进程池并行处理所有序列"""
201+
all_hashes = set()
202+
with ProcessPoolExecutor(max_workers=self.threads) as executor:
203+
futures = []
204+
for seq in sequences:
205+
# 跳过长度小于k的序列(不会产生k-mer)
206+
if len(seq) < self.k:
207+
continue
208+
future = executor.submit(
209+
_process_sequence_for_sketch,
210+
seq, self.k, max_hash, self.seed, True
211+
)
212+
futures.append(future)
213+
214+
for future in as_completed(futures):
215+
try:
216+
seq_hashes = future.result()
217+
all_hashes.update(seq_hashes)
218+
except Exception as e:
219+
# 实际项目中可改用日志模块
220+
print(f"Error in subprocess: {e}", file=sys.stderr)
221+
return all_hashes
222+
223+
252224
def _calculate_max_hash(self) -> int:
253225
"""
254226
计算最大哈希阈值

0 commit comments

Comments
 (0)