11"""k-mer计算,FracMinHash素描生成模块"""
22
33
4- from typing import Set
4+ import sys
55from .models import SketchData , GenomeData
66from ..utils .hash import HashFunction
77from typing import Set , Generator
8- from concurrent .futures import ThreadPoolExecutor , as_completed
9- from ..utils .hash import HashFunction
8+ from concurrent .futures import ProcessPoolExecutor , as_completed
109
1110
1211class KmerGenerator :
@@ -22,18 +21,7 @@ def __init__(self, k: int, seed: int = 42, threads: int = 1):
2221 threads: 线程数
2322 """
2423 self .k = k
25- self .threads = threads
2624 self .hash_func = HashFunction (seed )
27- # 创建线程池(仅在多线程时)
28- if threads > 1 :
29- self .executor = ThreadPoolExecutor (max_workers = threads )
30- else :
31- self .executor = None
32-
33- def __del__ (self ):
34- # 尝试关闭线程池(非阻塞)
35- if hasattr (self , 'executor' ) and self .executor :
36- self .executor .shutdown (wait = False )
3725
3826 def generate_kmers (self , sequence : str , canonical : bool = True ) -> Generator [str , None , None ]:
3927 """
@@ -90,87 +78,16 @@ def get_kmer_hashes(self, sequence: str, max_hash: int, canonical: bool = True)
9078 """
9179 hashes = set ()
9280
93- if self .threads > 1 and len (sequence ) > 100000 and self .executor is not None : # 避免小序列的并行开销
94- # 并行处理长序列
95- hashes = self ._parallel_kmer_hash (sequence , max_hash , canonical )
96- else :
97- # 串行处理
98- for kmer in self .generate_kmers (sequence , canonical ):
99- hash_val = self .hash_func .get_hash (kmer )
100- if hash_val < max_hash : # 筛选
101- hashes .add (hash_val )
102-
103- return hashes
104-
105-
106- def _parallel_kmer_hash (self , sequence : str , max_hash : int , canonical : bool ) -> Set [int ]:
107- """
108- 并行计算k-mer哈希,使用实例线程池
109-
110- Args:
111- sequence: 序列字符串
112- canonical: 是否使用正则形式
113-
114- Returns:
115- Set[int]: 哈希值集合
116- """
117- hashes = set ()
118- seq_len = len (sequence )
119- chunk_size = seq_len // self .threads
120-
121- futures = []
122- for i in range (self .threads ):
123- start = i * chunk_size
124- # 最后一个块直接到末尾,其他块扩展k-1保证边界k-mer完整
125- end = start + chunk_size + self .k if i < self .threads - 1 else seq_len
126- if start < seq_len - self .k + 1 :
127- future = self .executor .submit (
128- self ._process_chunk ,
129- sequence [start :end ],
130- max_hash ,
131- start ,
132- canonical
133- )
134- futures .append (future )
135-
136- for future in as_completed (futures ):
137- hashes .update (future .result ())
138-
139- return hashes
140-
141-
142- def _process_chunk (self , chunk : str , max_hash : int , offset : int , canonical : bool ) -> Set [int ]:
143- """
144- 处理序列片段
14581
146- Args:
147- chunk: 序列片段
148- offset: 起始位置偏移
149- canonical: 是否使用正则形式
150-
151- Returns:
152- Set[int]: 哈希值集合
153- """
154- hashes = set ()
155- AMBIGUOUS_BASES = set ('RYSWKMBDHVryswkmbdhv' )
156-
157- for i in range (len (chunk ) - self .k + 1 ):
158- kmer = chunk [i :i + self .k ].upper ()
159-
160- if 'N' in kmer or any (c in AMBIGUOUS_BASES for c in kmer ):
161- continue
162-
163- if canonical :
164- rev_comp = self ._reverse_complement (kmer )
165- kmer = min (kmer , rev_comp )
166-
82+ # 串行处理
83+ for kmer in self .generate_kmers (sequence , canonical ):
16784 hash_val = self .hash_func .get_hash (kmer )
168- if hash_val < max_hash :
85+ if hash_val < max_hash : # 筛选
16986 hashes .add (hash_val )
170-
17187
17288 return hashes
17389
90+
17491 def _reverse_complement (self , seq : str ) -> str :
17592 """
17693 计算序列的反向互补
@@ -185,12 +102,33 @@ def _reverse_complement(self, seq: str) -> str:
185102 return '' .join (complement .get (base , base ) for base in reversed (seq ))
186103
187104
# Module-level helper so it is picklable by multiprocessing workers.
def _process_sequence_for_sketch(sequence: str, k: int, max_hash: int, seed: int, canonical: bool) -> Set[int]:
    """
    Worker-process task: compute the FracMinHash hash set of one sequence.

    Args:
        sequence: nucleotide sequence string
        k: k-mer length
        max_hash: maximum hash threshold (FracMinHash cutoff)
        seed: hash seed
        canonical: whether to use the canonical (strand-normalized) k-mer form

    Returns:
        Set[int]: hash values below ``max_hash``
    """
    # Each subprocess builds its own KmerGenerator so that no
    # non-picklable state has to cross the process boundary.
    generator = KmerGenerator(k, seed, threads=1)
    return generator.get_kmer_hashes(sequence, max_hash, canonical)


188126# ---------------------------------------------------
189127#----------------------------------------------------
190128#----------------------------------------------------
191129
192130class FracMinHashSketch :
193- """FracMinHash素描生成器类"""
131+ """FracMinHash素描生成器类(支持多进程并行) """
194132
195133 def __init__ (self , k : int , scaled : float , seed : int = 42 , threads : int = 1 ):
196134 """
@@ -220,20 +158,21 @@ def create_sketch(self, genome_data: GenomeData) -> SketchData:
220158 SketchData: 素描数据
221159 """
222160 max_hash = self ._calculate_max_hash ()
223- all_hashes = set ()
224- total_kmers = 0
225-
226- # 遍历所有序列,合并哈希集合
227- for sequence in genome_data .sequences :
228- # 获取当前序列的过滤后哈希集合
229- seq_hashes = self .kmer_gen .get_kmer_hashes (sequence , max_hash )
230- all_hashes .update (seq_hashes ) # 合并哈希集合
231- sketch_size = len (all_hashes )
232-
233- # 累加总窗口数
234- total_kmers += len (sequence ) - self .k + 1
161+ sequences = genome_data .sequences
235162
163+ # 计算总k-mer数(串行,仅用长度)
164+ total_kmers = 0
165+ for seq in sequences :
166+ if len (seq ) >= self .k :
167+ total_kmers += len (seq ) - self .k + 1
236168
169+ # 判断是否启用多进程
170+ if self .threads > 1 and len (sequences ) > 1 :
171+ all_hashes = self ._parallel_process (sequences , max_hash )
172+ else :
173+ all_hashes = self ._serial_process (sequences , max_hash )
174+
175+
237176 # 创建素描数据
238177 sketch = SketchData (
239178 genome_id = genome_data .seq_id ,
@@ -242,13 +181,46 @@ def create_sketch(self, genome_data: GenomeData) -> SketchData:
242181 seed = self .seed ,
243182 hashes = all_hashes ,
244183 total_kmers = total_kmers ,
245- sketch_size = sketch_size
184+ sketch_size = len ( all_hashes )
246185
247186 )
248187
249188 return sketch
250189
190+
191+ def _serial_process (self , sequences : list , max_hash : int ) -> Set [int ]:
192+ """串行处理所有序列"""
193+ all_hashes = set ()
194+ for seq in sequences :
195+ seq_hashes = self .kmer_gen .get_kmer_hashes (seq , max_hash , canonical = True )
196+ all_hashes .update (seq_hashes )
197+ return all_hashes
251198
199+ def _parallel_process (self , sequences : list , max_hash : int ) -> Set [int ]:
200+ """使用进程池并行处理所有序列"""
201+ all_hashes = set ()
202+ with ProcessPoolExecutor (max_workers = self .threads ) as executor :
203+ futures = []
204+ for seq in sequences :
205+ # 跳过长度小于k的序列(不会产生k-mer)
206+ if len (seq ) < self .k :
207+ continue
208+ future = executor .submit (
209+ _process_sequence_for_sketch ,
210+ seq , self .k , max_hash , self .seed , True
211+ )
212+ futures .append (future )
213+
214+ for future in as_completed (futures ):
215+ try :
216+ seq_hashes = future .result ()
217+ all_hashes .update (seq_hashes )
218+ except Exception as e :
219+ # 实际项目中可改用日志模块
220+ print (f"Error in subprocess: { e } " , file = sys .stderr )
221+ return all_hashes
222+
223+
252224 def _calculate_max_hash (self ) -> int :
253225 """
254226 计算最大哈希阈值