diff --git a/sotodlib/preprocess/pcore.py b/sotodlib/preprocess/pcore.py
index ce9501927..20513b8f9 100644
--- a/sotodlib/preprocess/pcore.py
+++ b/sotodlib/preprocess/pcore.py
@@ -7,6 +7,7 @@
 from so3g.proj import Ranges, RangesMatrix
 from scipy.sparse import csr_array
 from matplotlib import pyplot as plt
+import tracemalloc
 
 class _Preprocess(object):
     """The base class for Preprocessing modules which defines the required
@@ -408,7 +409,7 @@ def extend(self, index, other):
     def __setitem__(self, index, item):
         super().__setitem__(index, self._check_item(item))
 
-    def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False):
+    def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False, run_tracemalloc=False):
         """ The main workhorse function for the pipeline class. This function
         takes an AxisManager TOD and successively runs the pipeline of preprocessing
@@ -469,15 +470,28 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False):
             proc_aman.restrict('dets', det_list)
             full = proc_aman.copy()
             run_calc = False
-
+
+        # Only pay the tracing overhead when it was actually requested.
+        if run_tracemalloc:
+            tracemalloc.start()
+
+        snapshots_process = []
+        snapshots_calc = []
+
         success = 'end'
         for step, process in enumerate(self):
             if sim and process.skip_on_sim:
                 continue
             self.logger.debug(f"Running {process.name}")
             process.process(aman, proc_aman, sim)
+            if run_tracemalloc:
+                snapshot = tracemalloc.take_snapshot()
+                snapshots_process.append((process.name, snapshot))
             if run_calc:
                 process.calc_and_save(aman, proc_aman)
+                if run_tracemalloc:
+                    snapshot = tracemalloc.take_snapshot()
+                    snapshots_calc.append((process.name, snapshot))
             process.plot(aman, proc_aman, filename=os.path.join(self.plot_dir, '{ctime}/{obsid}', f'{step+1}_{{name}}.png'))
             update_full_aman( proc_aman, full, self.wrap_valid)
             if update_plot:
@@ -491,8 +505,11 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False):
             if aman.dets.count == 0:
                 success = process.name
                 break
-
-        return full, success
+
+        if run_tracemalloc:
+            tracemalloc.stop()
+            return full, success, snapshots_process, snapshots_calc
+        return full, success
 
 
 class _FracFlaggedMixIn(object):
diff --git a/sotodlib/preprocess/preprocess_util.py b/sotodlib/preprocess/preprocess_util.py
index 0f3e1c193..687c676d1 100644
--- a/sotodlib/preprocess/preprocess_util.py
+++ b/sotodlib/preprocess/preprocess_util.py
@@ -16,6 +16,47 @@
 from . import _Preprocess, Pipeline, processes
 
+import cProfile
+
+
+def profile_function(func, profile_path, *args, **kwargs):
+    """Runs cProfile on the input function and writes the profile out to
+    ``profile_path``.
+
+    Arguments
+    ---------
+    func : function
+        The function to be called.
+    profile_path : str or None
+        The path to the output profile file. If None, ``func`` is called
+        directly without profiling.
+    *args : tuple
+        Positional arguments passed to ``func``.
+    **kwargs : dict
+        Keyword arguments passed to ``func``.
+
+    Returns
+    -------
+    local_vars : Any or None
+        Either the outputs of the function, or None if the profiling or the
+        function call fails.
+    """
+    local_vars = None
+
+    def wrapper_func():
+        nonlocal local_vars
+        local_vars = func(*args, **kwargs)
+
+    if profile_path is None:
+        wrapper_func()
+        return local_vars
+    try:
+        cProfile.runctx('wrapper_func()', globals(), locals(), filename=profile_path)
+        return local_vars
+    except Exception:
+        return None
+
+
 class ArchivePolicy:
     """Storage policy assistance.  Helps to determine the HDF5
     filename and dataset name for a result.
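For reference, a minimal sketch of how `profile_function` can be exercised and its dump inspected offline; `my_step` and the output path are hypothetical, and the `pp_util` alias matches how the site_pipeline scripts import this module:

    import pstats
    from sotodlib.preprocess import preprocess_util as pp_util

    def my_step(n, scale=2):
        # Stand-in for an expensive pipeline stage.
        return sum(i * scale for i in range(n))

    # With a path, the call is profiled and the stats written to disk;
    # with profile_path=None it reduces to a plain function call.
    result = pp_util.profile_function(my_step, '/tmp/my_step.prof', 100000, scale=3)

    pstats.Stats('/tmp/my_step.prof').sort_stats('cumulative').print_stats(10)
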
@@ -805,7 +846,7 @@ def cleanup_obs(obs_id, policy_dir, errlog, configs, context=None,
 
 def preproc_or_load_group(obs_id, configs_init, dets, configs_proc=None,
-                          logger=None, overwrite=False):
+                          logger=None, overwrite=False, run_tracemalloc=False):
     """ This function is expected to receive a single obs_id, and dets dictionary.
     The dets dictionary must match the grouping specified in the preprocess
@@ -985,7 +1026,11 @@ def preproc_or_load_group(obs_id, configs_init, dets, configs_proc=None,
             aman = context_init.get_obs(obs_id, dets=dets)
             tags = np.array(context_init.obsdb.get(aman.obs_info.obs_id, tags=True)['tags'])
             aman.wrap('tags', tags)
-            proc_aman, success = pipe_init.run(aman)
+            if run_tracemalloc:
+                proc_aman, success, snapshots_process, snapshots_calc = pipe_init.run(aman, run_tracemalloc=run_tracemalloc)
+            else:
+                proc_aman, success = pipe_init.run(aman)
+                snapshots_process, snapshots_calc = None, None
             aman.wrap('preprocess', proc_aman)
         except Exception as e:
             error = f'Failed to run initial pipeline: {obs_id} {dets}'
@@ -1001,7 +1046,7 @@ def preproc_or_load_group(obs_id, configs_init, dets, configs_proc=None,
             proc_aman.save(outputs_init['temp_file'], outputs_init['db_data']['dataset'], overwrite)
         if configs_proc is None:
-            return error, outputs_init, [obs_id, dets], aman
+            return error, outputs_init, [obs_id, dets], aman, snapshots_process, snapshots_calc
         else:
             try:
                 outputs_proc = save_group(obs_id, configs_proc, dets, context_proc, subdir='temp_proc')
diff --git a/sotodlib/site_pipeline/multilayer_preprocess_tod.py b/sotodlib/site_pipeline/multilayer_preprocess_tod.py
index 5af3f6b57..29cc593db 100644
--- a/sotodlib/site_pipeline/multilayer_preprocess_tod.py
+++ b/sotodlib/site_pipeline/multilayer_preprocess_tod.py
@@ -7,6 +7,7 @@
 import argparse
 import traceback
 from typing import Optional
+import pstats
 from sotodlib.utils.procs_pool import get_exec_env
 import h5py
 import copy
@@ -269,6 +270,12 @@ def get_parser(parser=None):
         type=int,
         default=4
     )
+    parser.add_argument(
+        '--profile',
+        dest='run_profiling',
+        help="Run cProfile on each obs_id and write the profiles to disk.",
+        action='store_true'
+    )
     parser.add_argument(
         '--raise-error',
         help="Raise an error upon completion if any obsids or groups fail.",
@@ -291,6 +298,7 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
           planet_obs: bool = False,
           verbosity: Optional[int] = None,
           nproc: Optional[int] = 4,
+          run_profiling: Optional[bool] = False,
           raise_error: Optional[bool] = False):
 
     logger = pp_util.init_logger("preprocess", verbosity=verbosity)
@@ -341,8 +349,17 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
 
     n_fail = 0
 
+    if run_profiling:
+        profile_dir = os.path.join(os.path.dirname(configs['archive']['policy']['filename']), 'prof')
+        os.makedirs(profile_dir, exist_ok=True)
+    else:
+        profile_dir = None
+
     # run write_block obs-ids in parallel at once then write all to the sqlite db.
-    futures = [executor.submit(multilayer_preprocess_tod, obs_id=r[0]['obs_id'],
+    futures = [executor.submit(pp_util.profile_function,
+                               func=multilayer_preprocess_tod,
+                               profile_path=os.path.join(profile_dir, f'{r[0]["obs_id"]}.prof') if profile_dir is not None else None,
+                               obs_id=r[0]['obs_id'],
                                group_list=r[1], verbosity=verbosity,
                                configs_init=configs_init,
                                configs_proc=configs_proc,
@@ -380,6 +397,26 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
             else:
                 pp_util.cleanup_mandb(err, db_datasets_proc, configs_proc, logger, overwrite)
 
+    if run_profiling:
+        combined_profile_path = os.path.join(profile_dir, 'combined_profile.prof')
+        if os.path.exists(combined_profile_path):
+            combined_stats = pstats.Stats(combined_profile_path)
+        else:
+            combined_stats = None
+        for r in run_list:
+            profile_file = os.path.join(profile_dir, f'{r[0]["obs_id"]}.prof')
+            if os.path.exists(profile_file):
+                try:
+                    stats = pstats.Stats(profile_file)
+                    if combined_stats is None:
+                        combined_stats = stats
+                    else:
+                        combined_stats.add(stats)
+                except Exception as e:
+                    logger.error(f"cannot get stats for {r[0]['obs_id']}: {e}")
+        if combined_stats is not None:
+            combined_stats.dump_stats(combined_profile_path)
+
     if raise_error and n_fail > 0:
         raise RuntimeError(f"multilayer_preprocess_tod: {n_fail}/{len(run_list)} obs_ids failed")
 
@@ -396,6 +433,7 @@ def main(configs_init: str,
          planet_obs: bool = False,
          verbosity: Optional[int] = None,
          nproc: Optional[int] = 4,
+         run_profiling: Optional[bool] = False,
          raise_error: Optional[bool] = False):
 
     rank, executor, as_completed_callable = get_exec_env(nproc)
@@ -414,6 +452,7 @@ def main(configs_init: str,
           planet_obs=planet_obs,
           verbosity=verbosity,
           nproc=nproc,
+          run_profiling=run_profiling,
           raise_error=raise_error)
 
diff --git a/sotodlib/site_pipeline/preprocess_tod.py b/sotodlib/site_pipeline/preprocess_tod.py
index a8e260483..d5b29c748 100644
--- a/sotodlib/site_pipeline/preprocess_tod.py
+++ b/sotodlib/site_pipeline/preprocess_tod.py
@@ -7,6 +7,8 @@
 import argparse
 import traceback
 from typing import Optional
+import pstats
+import tracemalloc
 from sotodlib.utils.procs_pool import get_exec_env
 import h5py
 import copy
@@ -19,6 +21,7 @@
 
 logger = sp_util.init_logger("preprocess")
 
+
 def dummy_preproc(obs_id, group_list, logger, configs, overwrite,
                   run_parallel):
     """
@@ -55,7 +58,8 @@ def preprocess_tod(obs_id,
                    verbosity=0,
                    group_list=None,
                    overwrite=False,
-                   run_parallel=False):
+                   run_parallel=False,
+                   run_tracemalloc=False):
     """Meant to be run as part of a batched script, this function calls the
     preprocessing pipeline a specific Observation ID and saves the results
     in the ManifestDb specified in the configs.
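For reference, the per-obs_id `.prof` files written by the workers can be merged and read back offline with the same `pstats` calls the combined-stats block above uses; the paths here are hypothetical:

    import glob
    import pstats

    combined = None
    for path in sorted(glob.glob('output/prof/*.prof')):
        if combined is None:
            combined = pstats.Stats(path)
        else:
            # Stats.add accepts dump filenames as well as Stats instances.
            combined.add(path)
    if combined is not None:
        combined.dump_stats('output/prof/combined_profile.prof')
        combined.sort_stats('tottime').print_stats(20)
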
@@ -139,10 +143,35 @@ def preprocess_tod(obs_id,
             logger.info(f"Beginning run for {obs_id}:{group}")
             dets = {gb:gg for gb, gg in zip(group_by, group)}
             try:
+                if run_tracemalloc:
+                    tracemalloc.start()
+
                 aman = context.get_obs(obs_id, dets=dets)
+                if run_tracemalloc:
+                    init_snapshot = ('aman', tracemalloc.take_snapshot())
+                    tracemalloc.stop()
                 tags = np.array(context.obsdb.get(aman.obs_info.obs_id, tags=True)['tags'])
                 aman.wrap('tags', tags)
-                proc_aman, success = pipe.run(aman)
+                if run_tracemalloc:
+                    proc_aman, success, snapshots_process, snapshots_calc = pipe.run(aman, run_tracemalloc=run_tracemalloc)
+                    snapshots_process = [init_snapshot] + snapshots_process
+                    snapshots_calc = [init_snapshot] + snapshots_calc
+
+                    dest_dataset = obs_id
+                    for gb, g in zip(group_by, group):
+                        if gb == 'detset':
+                            dest_dataset += "_" + g
+                        else:
+                            dest_dataset += "_" + gb + "_" + str(g)
+                    trace_dir = os.path.join(os.path.dirname(configs['archive']['policy']['filename']), "trace")
+                    # Workers may run before _main has created the directory.
+                    os.makedirs(trace_dir, exist_ok=True)
+                    for i, snap in enumerate(snapshots_process):
+                        snap[1].dump(os.path.join(trace_dir, f"{dest_dataset}_snapshot_process_{snap[0]}_{i}.pkl"))
+                    for i, snap in enumerate(snapshots_calc):
+                        snap[1].dump(os.path.join(trace_dir, f"{dest_dataset}_snapshot_calc_{snap[0]}_{i}.pkl"))
+                else:
+                    proc_aman, success = pipe.run(aman)
 
                 if make_lmsi:
                     new_plots = os.path.join(configs["plot_dir"],
@@ -304,6 +333,18 @@ def get_parser(parser=None):
         type=int,
         default=4
     )
+    parser.add_argument(
+        '--profile',
+        dest='run_profiling',
+        help="Run cProfile on each obs_id and write the profiles to disk.",
+        action='store_true'
+    )
+    parser.add_argument(
+        '--tracemalloc',
+        dest='run_tracemalloc',
+        help="Run tracemalloc and dump per-step memory snapshots.",
+        action='store_true'
+    )
     parser.add_argument(
         '--raise-error',
         help="Raise an error upon completion if any obsids or groups fail.",
@@ -326,6 +367,8 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
          planet_obs: bool = False,
          verbosity: Optional[int] = None,
          nproc: Optional[int] = 4,
+         run_profiling: Optional[bool] = False,
+         run_tracemalloc: Optional[bool] = False,
          raise_error: Optional[bool] = False):
 
     configs, context = pp_util.get_preprocess_context(configs)
@@ -370,11 +413,26 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
 
     n_fail = 0
 
+    if run_profiling:
+        profile_dir = os.path.join(os.path.dirname(configs['archive']['policy']['filename']), 'prof')
+        os.makedirs(profile_dir, exist_ok=True)
+    else:
+        profile_dir = None
+
+    if run_tracemalloc:
+        trace_dir = os.path.join(os.path.dirname(configs['archive']['policy']['filename']), 'trace')
+        os.makedirs(trace_dir, exist_ok=True)
+
     # Run write_block obs-ids in parallel at once then write all to the sqlite db.
-    futures = [executor.submit(preprocess_tod, obs_id=r[0]['obs_id'],
+    futures = [executor.submit(pp_util.profile_function,
+                               func=preprocess_tod,
+                               profile_path=os.path.join(profile_dir, f'{r[0]["obs_id"]}.prof') if profile_dir is not None else None,
+                               obs_id=r[0]['obs_id'],
                                group_list=r[1], verbosity=verbosity,
                                configs=configs,
-                               overwrite=overwrite, run_parallel=True) for r in run_list]
+                               overwrite=overwrite,
+                               run_parallel=True,
+                               run_tracemalloc=run_tracemalloc) for r in run_list]
     for future in as_completed_callable(futures):
         logger.info('New future as_completed result')
         try:
@@ -400,6 +458,26 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
             else:
                 pp_util.cleanup_mandb(err, db_datasets, configs, logger)
 
+    if run_profiling:
+        combined_profile_path = os.path.join(profile_dir, 'combined_profile.prof')
+        if os.path.exists(combined_profile_path):
+            combined_stats = pstats.Stats(combined_profile_path)
+        else:
+            combined_stats = None
+        for r in run_list:
+            profile_file = os.path.join(profile_dir, f'{r[0]["obs_id"]}.prof')
+            if os.path.exists(profile_file):
+                try:
+                    stats = pstats.Stats(profile_file)
+                    if combined_stats is None:
+                        combined_stats = stats
+                    else:
+                        combined_stats.add(stats)
+                except Exception as e:
+                    logger.error(f"cannot get stats for {r[0]['obs_id']}: {e}")
+        if combined_stats is not None:
+            combined_stats.dump_stats(combined_profile_path)
+
     if raise_error and n_fail > 0:
         raise RuntimeError(f"preprocess_tod: {n_fail}/{len(run_list)} obs_ids failed")
 
@@ -414,6 +492,8 @@ def main(configs: str,
          planet_obs: bool = False,
          verbosity: Optional[int] = None,
          nproc: Optional[int] = 4,
+         run_profiling: Optional[bool] = False,
+         run_tracemalloc: Optional[bool] = False,
          raise_error: Optional[bool] = False):
 
     rank, executor, as_completed_callable = get_exec_env(nproc)
@@ -431,6 +511,8 @@ def main(configs: str,
           planet_obs=planet_obs,
           verbosity=verbosity,
           nproc=nproc,
+          run_profiling=run_profiling,
+          run_tracemalloc=run_tracemalloc,
           raise_error=raise_error)
 
 if __name__ == '__main__':
diff --git a/sotodlib/utils/cprofile.py b/sotodlib/utils/cprofile.py
new file mode 100644
index 000000000..b8957a92d
--- /dev/null
+++ b/sotodlib/utils/cprofile.py
@@ -0,0 +1,19 @@
+import cProfile, pstats, io
+
+def cprofile(name):
+    """Decorator to run cProfile on a function and print its stats.
+    """
+    # Reference: https://stackoverflow.com/questions/5375624/a-decorator-that-profiles-a-method-call-and-logs-the-profiling-result
+    def cprofile_func(func):
+        def wrapper(*args, **kwargs):
+            prof = cProfile.Profile()
+            retval = prof.runcall(func, *args, **kwargs)
+            s = io.StringIO()
+            sortby = 'cumulative'  # time spent by function and called subfunctions
+            ps = pstats.Stats(prof, stream=s).sort_stats(sortby)
+            ps.print_stats(20)  # print the 20 most expensive calls
+            print(f"{name} {func.__name__}: {s.getvalue()}")
+            return retval
+
+        return wrapper
+    return cprofile_func
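For reference, the snapshot `.pkl` files that preprocess_tod dumps above can be reloaded and diffed offline with the standard tracemalloc API; the filenames here are hypothetical instances of the `{dest_dataset}_snapshot_process_{name}_{i}.pkl` pattern:

    import tracemalloc

    first = tracemalloc.Snapshot.load('trace/obs_snapshot_process_aman_0.pkl')
    last = tracemalloc.Snapshot.load('trace/obs_snapshot_process_psd_5.pkl')

    # Top memory growth between the two pipeline steps, grouped by source line.
    for stat in last.compare_to(first, 'lineno')[:10]:
        print(stat)

The `cprofile` decorator in sotodlib/utils/cprofile.py is applied as `@cprofile('some-tag')` to any function whose per-call profile should be printed to stdout.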