1010# but WITHOUT ANY WARRANTY; without even the implied warranty of
1111# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
1212# GNU General Public License for more details.
13+ import os
1314
1415# You should have received a copy of the GNU General Public License
1516# along with this program; if not, write to the Free Software
@@ -107,6 +108,7 @@ def init(
107108 self .flows_to_process_q = multiprocessing .Queue (maxsize = 5162220 )
108109 self .handle_setting_local_net_lock = multiprocessing .Lock ()
109110 self .is_first_msg = True
111+ # runs a separate server process behind the scenes.
110112 self .manager = multiprocessing .Manager ()
111113 self .localnet_cache = self .manager .dict ()
112114 # max parallel profiler workers to start when high throughput is detected
@@ -147,14 +149,22 @@ def get_input_type(self, line: dict, input_type: str) -> str:
147149 return input_type
148150
149151 def stop_profiler_workers (self ):
150- self .stop_profiler_workers_event .set ()
152+ self .stop_profiler_workers_event .set () # Signal workers to exit
153+ time .sleep (2 )
154+ # Try to join gracefully first
155+ for process in self .profiler_child_processes :
156+ try :
157+ process .join (timeout = 3 )
158+ except (OSError , ChildProcessError ):
159+ pass
160+
161+ # Terminate any processes that are still alive after the join timeout
151162 for process in self .profiler_child_processes :
152163 try :
153164 if process .is_alive ():
154165 process .terminate ()
155- process .join (timeout = 3 )
166+ process .join (timeout = 0.1 )
156167 except (OSError , ChildProcessError ):
157- # continue loop; don't abort shutdown
158168 pass
159169
160170 def mark_process_as_done_processing (self ):
@@ -199,7 +209,12 @@ def worker(
199209 ZeekTabs | ZeekJSON | Argus | Suricata | ZeekTabs | Nfdump
200210 ),
201211 ):
202- ProfilerWorker (
212+ worker_number = name .split ("_" )[- 1 ]
213+ self .print (
214+ f"Started Profiler Worker { green (worker_number )} [PID"
215+ f" { green (os .getpid ())} ]"
216+ )
217+ worker = ProfilerWorker (
203218 name = name ,
204219 logger = self .logger ,
205220 output_dir = self .output_dir ,
@@ -214,18 +229,20 @@ def worker(
214229 flows_to_process_q = self .flows_to_process_q ,
215230 input_handler = input_handler_obj ,
216231 bloom_filters = self .bloom_filters ,
217- ).start ()
232+ )
233+ worker .main ()
218234
219235 def start_profiler_worker (self , worker_id : int = None ):
220236 """starts A profiler worker for faster processing of the flows"""
221- worker_name = f"ProfilerWorker_ { worker_id } "
237+ worker_name = f"ProfilerWorker_Process_ { worker_id } "
222238 proc = multiprocessing .Process (
223239 target = self .worker ,
224240 args = (
225241 worker_name ,
226242 self .input_handler_cls ,
227243 ),
228244 name = worker_name ,
245+ daemon = True ,
229246 )
230247 utils .start_process (proc , self .db )
231248 self .profiler_child_processes .append (proc )
@@ -273,12 +290,17 @@ def shutdown_gracefully(self):
273290
274291 # wait for all flows to be processed by the profiler processes.
275292 self .stop_profiler_workers ()
276-
277293 # close the queues to avoid deadlocks.
278294 # this step SHOULD NEVER be done before closing the workers
279295 self .flows_to_process_q .close ()
296+ # By default if a process is not the creator of the queue then on
297+ # exit it will attempt to join the queue’s background thread. The
298+ # process can call cancel_join_thread() to make join_thread()
299+ # do nothing.
300+ self .flows_to_process_q .cancel_join_thread ()
280301 self .profiler_queue .close ()
281302
303+ self .manager .shutdown ()
282304 self .db .set_new_incoming_flows (False )
283305 self .print (
284306 f"Stopping. Total lines read: { self .rec_lines } " ,
@@ -317,6 +339,7 @@ def check_if_high_throughput_and_add_workers(self):
317339 Checks for input and profile flows/sec imbalance and adds more
318340 profiler workers if needed.
319341 """
342+
320343 if self .max_workers_started ():
321344 return
322345
@@ -354,7 +377,6 @@ def main(self):
354377 # we're using self.should_stop() here instead of while True to be
355378 # able to unit test this function:D
356379 while not self .should_stop ():
357-
358380 self .lines = sum (
359381 [worker .received_lines for worker in self .workers ]
360382 )
@@ -376,7 +398,6 @@ def main(self):
376398
377399 if self .is_first_msg :
378400 self .is_first_msg = False
379-
380401 self .input_handler_cls = self .get_handler_class (msg )
381402 if not self .input_handler_cls :
382403 self .print ("Unsupported input type, exiting." )
@@ -396,3 +417,4 @@ def main(self):
396417
397418 self .flows_to_process_q .put (msg , block = True , timeout = None )
398419 self .check_if_high_throughput_and_add_workers ()
420+ return
0 commit comments