23 changes: 19 additions & 4 deletions sotodlib/preprocess/pcore.py
@@ -7,6 +7,7 @@
from so3g.proj import Ranges, RangesMatrix
from scipy.sparse import csr_array
from matplotlib import pyplot as plt
import tracemalloc

class _Preprocess(object):
"""The base class for Preprocessing modules which defines the required
@@ -408,7 +409,7 @@ def extend(self, index, other):
def __setitem__(self, index, item):
super().__setitem__(index, self._check_item(item))

def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False):
def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False, run_tracemalloc=False):
"""
The main workhorse function for the pipeline class. This function takes
an AxisManager TOD and successively runs the pipeline of preprocessing
@@ -469,15 +470,26 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False):
proc_aman.restrict('dets', det_list)
full = proc_aman.copy()
run_calc = False


if run_tracemalloc:
tracemalloc.start()

snapshots_process = []
snapshots_calc = []

success = 'end'
for step, process in enumerate(self):
if sim and process.skip_on_sim:
continue
self.logger.debug(f"Running {process.name}")
process.process(aman, proc_aman, sim)
if run_tracemalloc:
snapshot = tracemalloc.take_snapshot()
snapshots_process.append((process.name, snapshot))
if run_calc:
process.calc_and_save(aman, proc_aman)
if run_tracemalloc:
snapshot = tracemalloc.take_snapshot()
snapshots_calc.append((process.name, snapshot))
process.plot(aman, proc_aman, filename=os.path.join(self.plot_dir, '{ctime}/{obsid}', f'{step+1}_{{name}}.png'))
update_full_aman( proc_aman, full, self.wrap_valid)
if update_plot:
@@ -491,8 +503,11 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False):
if aman.dets.count == 0:
success = process.name
break

return full, success
tracemalloc.stop()
if run_tracemalloc:
return full, success, snapshots_process, snapshots_calc
else:
return full, success


class _FracFlaggedMixIn(object):
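For orientation, a minimal sketch (not part of the diff) of how the extended Pipeline.run return value might be consumed when run_tracemalloc=True; the names `pipe` and `aman` stand for an already-built Pipeline and a loaded AxisManager, and summarizing with statistics('lineno') is just one option:

    # Illustrative only: inspect the per-step tracemalloc snapshots from run().
    full, success, snaps_process, snaps_calc = pipe.run(aman, run_tracemalloc=True)
    for step_name, snapshot in snaps_process:
        # Largest allocation sites recorded right after this process step.
        for stat in snapshot.statistics('lineno')[:3]:
            print(step_name, stat)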
52 changes: 49 additions & 3 deletions sotodlib/preprocess/preprocess_util.py
@@ -16,6 +16,49 @@

from . import _Preprocess, Pipeline, processes

import cProfile



def profile_function(func, profile_path, *args, **kwargs):
"""Runs CProfile on the input function and writes the profile out to a
file using pstats.

Arguments
----------
func : function
The function to be called
profile_path : str
The path to the output profile file.
*args : tuple
Additional positional arguments.
**kwargs : dict
Additional keyword arguments.

Returns
-------
local_vars : Any or None
The return value of ``func``, or None if profiling or the function
call fails.
"""

local_vars = None

def wrapper_func():
nonlocal local_vars
local_vars = func(*args, **kwargs)

if profile_path is None:
wrapper_func()
return local_vars
else:
try:
cProfile.runctx('wrapper_func()', globals(), locals(), filename=profile_path)
return local_vars
except Exception as e:
return None
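
A hypothetical usage sketch (the observation ID, output path, and configs object below are placeholders, not values from this change): profile_function wraps any callable, writes the cProfile output to profile_path, and the file can then be read back with pstats:

    import pstats

    # Profile a single preprocessing call and report the ten most expensive calls.
    result = profile_function(preprocess_tod, '/tmp/obs_example.prof',
                              obs_id='obs_example', configs=configs)
    pstats.Stats('/tmp/obs_example.prof').sort_stats('cumulative').print_stats(10)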


class ArchivePolicy:
"""Storage policy assistance. Helps to determine the HDF5
filename and dataset name for a result.
@@ -805,7 +848,7 @@ def cleanup_obs(obs_id, policy_dir, errlog, configs, context=None,


def preproc_or_load_group(obs_id, configs_init, dets, configs_proc=None,
logger=None, overwrite=False):
logger=None, overwrite=False, run_tracemalloc=False):
"""
This function is expected to receive a single obs_id, and dets dictionary.
The dets dictionary must match the grouping specified in the preprocess
@@ -985,7 +1028,10 @@ def preproc_or_load_group(obs_id, configs_init, dets, configs_proc=None,
aman = context_init.get_obs(obs_id, dets=dets)
tags = np.array(context_init.obsdb.get(aman.obs_info.obs_id, tags=True)['tags'])
aman.wrap('tags', tags)
proc_aman, success = pipe_init.run(aman)
if run_tracemalloc:
proc_aman, success, snapshots_process, snapshots_calc = pipe_init.run(aman, run_tracemalloc=run_tracemalloc)
else:
proc_aman, success = pipe_init.run(aman)
snapshots_process, snapshots_calc = None, None
aman.wrap('preprocess', proc_aman)
except Exception as e:
error = f'Failed to run initial pipeline: {obs_id} {dets}'
@@ -1001,7 +1047,7 @@ def preproc_or_load_group(obs_id, configs_init, dets, configs_proc=None,
proc_aman.save(outputs_init['temp_file'], outputs_init['db_data']['dataset'], overwrite)

if configs_proc is None:
return error, outputs_init, [obs_id, dets], aman
return error, outputs_init, [obs_id, dets], aman, snapshots_process, snapshots_calc
else:
try:
outputs_proc = save_group(obs_id, configs_proc, dets, context_proc, subdir='temp_proc')
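Since the configs_proc=None branch now returns a six-element tuple, callers need to unpack the snapshot lists as well. A sketch under the assumptions run_tracemalloc=True and a successful run (variable names are illustrative):

    error, outputs_init, group_info, aman, snaps_process, snaps_calc = \
        pp_util.preproc_or_load_group(obs_id, configs_init, dets, run_tracemalloc=True)
    if error is None and len(snaps_process) > 1:
        # Memory growth between the first and last process steps, by source line.
        for entry in snaps_process[-1][1].compare_to(snaps_process[0][1], 'lineno')[:5]:
            print(entry)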
42 changes: 41 additions & 1 deletion sotodlib/site_pipeline/multilayer_preprocess_tod.py
@@ -7,6 +7,7 @@
import argparse
import traceback
from typing import Optional
import pstats
from sotodlib.utils.procs_pool import get_exec_env
import h5py
import copy
@@ -269,6 +270,12 @@ def get_parser(parser=None):
type=int,
default=4
)
parser.add_argument(
'--profile',
help="Run profiling.",
action='store_true',
default=False
)
parser.add_argument(
'--raise-error',
help="Raise an error upon completion if any obsids or groups fail.",
@@ -291,6 +298,7 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
planet_obs: bool = False,
verbosity: Optional[int] = None,
nproc: Optional[int] = 4,
run_profiling: Optional[bool] = False,
raise_error: Optional[bool] = False):

logger = pp_util.init_logger("preprocess", verbosity=verbosity)
@@ -341,8 +349,18 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],

n_fail = 0

if run_profiling:
profile_dir = os.path.join(os.path.dirname(configs['archive']['policy']['filename']), 'prof')
os.makedirs(profile_dir, exist_ok=True)
else:
profile_dir = None

# run write_block obs-ids in parallel at once then write all to the sqlite db.
futures = [executor.submit(multilayer_preprocess_tod, obs_id=r[0]['obs_id'],
futures = [executor.submit(pp_util.profile_function,
func=multilayer_preprocess_tod,
profile_path=os.path.join(profile_dir, f'{r[0]["obs_id"]}.prof') if profile_dir is not None else None,
obs_id=r[0]['obs_id'],
group_list=r[1], verbosity=verbosity,
configs_init=configs_init,
configs_proc=configs_proc,
@@ -380,6 +398,26 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
else:
pp_util.cleanup_mandb(err, db_datasets_proc, configs_proc, logger, overwrite)

if run_profiling:
combined_profile_dir = os.path.join(profile_dir, 'combined_profile.prof')
if os.path.exists(combined_profile_dir):
combined_stats = pstats.Stats(combined_profile_dir)
else:
combined_stats = None
for r in run_list:
profile_file = os.path.join(profile_dir, f'{r[0]["obs_id"]}.prof')
if os.path.exists(profile_file):
try:
stats = pstats.Stats(profile_file)
if combined_stats is None:
combined_stats = stats
else:
combined_stats.add(stats)
except Exception as e:
logger.error(f"cannot get stats for {r[0]['obs_id']}: {e}")
if combined_stats is not None:
combined_stats.dump_stats(combined_profile_dir)

if raise_error and n_fail > 0:
raise RuntimeError(f"multilayer_preprocess_tod: {n_fail}/{len(run_list)} obs_ids failed")

@@ -396,6 +434,7 @@ def main(configs_init: str,
planet_obs: bool = False,
verbosity: Optional[int] = None,
nproc: Optional[int] = 4,
run_profiling: Optional[bool] = False,
raise_error: Optional[bool] = False):

rank, executor, as_completed_callable = get_exec_env(nproc)
@@ -414,6 +453,7 @@ def main(configs_init: str,
planet_obs=planet_obs,
verbosity=verbosity,
nproc=nproc,
run_profiling=run_profiling,
raise_error=raise_error)


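To make the combined profile actionable afterwards, a short sketch of reading it back with pstats (the relative path mirrors the 'prof' directory built above; the real location depends on the archive policy in the configs):

    import pstats

    stats = pstats.Stats('prof/combined_profile.prof')
    stats.strip_dirs().sort_stats('cumulative').print_stats(20)  # top 20 by cumulative time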
90 changes: 86 additions & 4 deletions sotodlib/site_pipeline/preprocess_tod.py
@@ -7,6 +7,8 @@
import argparse
import traceback
from typing import Optional
import pstats
import tracemalloc
from sotodlib.utils.procs_pool import get_exec_env
import h5py
import copy
@@ -19,6 +21,7 @@

logger = sp_util.init_logger("preprocess")


def dummy_preproc(obs_id, group_list, logger,
configs, overwrite, run_parallel):
"""
@@ -55,7 +58,8 @@ def preprocess_tod(obs_id,
verbosity=0,
group_list=None,
overwrite=False,
run_parallel=False):
run_parallel=False,
run_tracemalloc=False):
"""Meant to be run as part of a batched script, this function calls the
preprocessing pipeline a specific Observation ID and saves the results in
the ManifestDb specified in the configs.
@@ -139,10 +143,33 @@
logger.info(f"Beginning run for {obs_id}:{group}")
dets = {gb:gg for gb, gg in zip(group_by, group)}
try:
if run_tracemalloc:
tracemalloc.start()

aman = context.get_obs(obs_id, dets=dets)
if run_tracemalloc:
init_snapshot = ('aman', tracemalloc.take_snapshot())
tracemalloc.stop()
tags = np.array(context.obsdb.get(aman.obs_info.obs_id, tags=True)['tags'])
aman.wrap('tags', tags)
proc_aman, success = pipe.run(aman)
if run_tracemalloc:
proc_aman, success, snapshots_process, snapshots_calc = pipe.run(aman, run_tracemalloc=run_tracemalloc)
snapshots_process = [init_snapshot] + snapshots_process
snapshots_calc = [init_snapshot] + snapshots_calc

dest_dataset = obs_id
for gb, g in zip(group_by, group):
if gb == 'detset':
dest_dataset += "_" + g
else:
dest_dataset += "_" + gb + "_" + str(g)
trace_dir = os.path.join(os.path.dirname(configs['archive']['policy']['filename']), "trace")
os.makedirs(trace_dir, exist_ok=True)
for i, snap in enumerate(snapshots_process):
snap[1].dump(os.path.join(trace_dir, f"{dest_dataset}_snapshot_process_{snap[0]}_{i}.pkl"))
for i, snap in enumerate(snapshots_calc):
snap[1].dump(os.path.join(trace_dir, f"{dest_dataset}_snapshot_calc_{snap[0]}_{i}.pkl"))
else:
proc_aman, success = pipe.run(aman, run_tracemalloc=run_tracemalloc)

if make_lmsi:
new_plots = os.path.join(configs["plot_dir"],
@@ -304,6 +331,18 @@ def get_parser(parser=None):
type=int,
default=4
)
parser.add_argument(
'--profile',
help="Run profiling.",
action='store_true',
default=False
)
parser.add_argument(
'--tracemalloc',
help="Run tracemalloc.",
action='store_true',
default=False
)
parser.add_argument(
'--raise-error',
help="Raise an error upon completion if any obsids or groups fail.",
@@ -326,6 +365,8 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
planet_obs: bool = False,
verbosity: Optional[int] = None,
nproc: Optional[int] = 4,
run_profiling: Optional[bool] = False,
run_tracemalloc: Optional[bool] = False,
raise_error: Optional[bool] = False):

configs, context = pp_util.get_preprocess_context(configs)
@@ -370,11 +411,28 @@

n_fail = 0

if run_profiling:
profile_dir = os.path.join(os.path.dirname(configs['archive']['policy']['filename']), 'prof')
os.makedirs(profile_dir, exist_ok=True)
else:
profile_dir = None

if run_tracemalloc:
trace_dir = os.path.join(os.path.dirname(configs['archive']['policy']['filename']), 'trace')
os.makedirs(trace_dir, exist_ok=True)

# Run write_block obs-ids in parallel at once then write all to the sqlite db.
futures = [executor.submit(preprocess_tod, obs_id=r[0]['obs_id'],
futures = [executor.submit(pp_util.profile_function,
func=preprocess_tod,
profile_path=os.path.join(profile_dir, f'{r[0]["obs_id"]}.prof') if profile_dir is not None else None,
obs_id=r[0]['obs_id'],
group_list=r[1], verbosity=verbosity,
configs=configs,
overwrite=overwrite, run_parallel=True) for r in run_list]
overwrite=overwrite,
run_parallel=True,
run_tracemalloc=run_tracemalloc) for r in run_list]
for future in as_completed_callable(futures):
logger.info('New future as_completed result')
try:
@@ -400,6 +458,26 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
else:
pp_util.cleanup_mandb(err, db_datasets, configs, logger)

if run_profiling:
combined_profile_dir = os.path.join(profile_dir, 'combined_profile.prof')
if os.path.exists(combined_profile_dir):
combined_stats = pstats.Stats(combined_profile_dir)
else:
combined_stats = None
for r in run_list:
profile_file = os.path.join(profile_dir, f'{r[0]["obs_id"]}.prof')
if os.path.exists(profile_file):
try:
stats = pstats.Stats(profile_file)
if combined_stats is None:
combined_stats = stats
else:
combined_stats.add(stats)
except Exception as e:
logger.error(f"cannot get stats for {r[0]['obs_id']}: {e}")
if combined_stats is not None:
combined_stats.dump_stats(combined_profile_dir)

if raise_error and n_fail > 0:
raise RuntimeError(f"preprocess_tod: {n_fail}/{len(run_list)} obs_ids failed")

@@ -414,6 +492,8 @@ def main(configs: str,
planet_obs: bool = False,
verbosity: Optional[int] = None,
nproc: Optional[int] = 4,
run_profiling: Optional[bool] = False,
run_tracemalloc: Optional[bool] = False,
raise_error: Optional[bool] = False):

rank, executor, as_completed_callable = get_exec_env(nproc)
@@ -431,6 +511,8 @@ def main(configs: str,
planet_obs=planet_obs,
verbosity=verbosity,
nproc=nproc,
run_profiling=run_profiling,
run_tracemalloc=run_tracemalloc,
raise_error=raise_error)

if __name__ == '__main__':
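Finally, a sketch of examining the dumped snapshot files offline; the filename below is a placeholder following the {dest_dataset}_snapshot_process_{name}_{i}.pkl pattern used above:

    import tracemalloc

    snap = tracemalloc.Snapshot.load('trace/obs_example_ws0_snapshot_process_detrend_0.pkl')
    # Aggregate allocations by file to see which modules dominate memory use.
    for stat in snap.statistics('filename')[:10]:
        print(stat)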