Skip to content

Commit 7a125b3

Browse files
committed
TOAST mlmapmaker: Add option to write components of div
Add option to write individual components of `div`. Reduce peak memory by writing the per-pixel covariance matrix div one component at a time. So instead of MPI send/receive 3 x 3 x npix matrix, only send/receive npix matrix.
1 parent e2890a1 commit 7a125b3

File tree

2 files changed

+39
-14
lines changed

2 files changed

+39
-14
lines changed

sotodlib/toast/ops/mlmapmaker.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,12 @@ class MLMapmaker(Operator):
217217
help="Truncate TOD to an easily factorizable length to ensure efficient FFT.",
218218
)
219219

220-
write_div = Bool(True, help="Write out the noise weight map")
220+
write_div = Unicode(
221+
"all",
222+
allow_none=True,
223+
help="Components (must be 'T', 'QU', 'TQU', 'all', None, or '')"
224+
)
225+
221226
write_hits= Bool(True, help="Write out the hitcount map")
222227

223228
write_rhs = Bool(
@@ -231,7 +236,7 @@ class MLMapmaker(Operator):
231236
)
232237

233238
@traitlets.validate("comps")
234-
def _check_mode(self, proposal):
239+
def _check_comps(self, proposal):
235240
check = proposal["value"]
236241
if check not in ["T", "QU", "TQU"]:
237242
raise traitlets.TraitError("Invalid comps (must be 'T', 'QU' or 'TQU')")
@@ -259,7 +264,7 @@ def _check_det_flag_mask(self, proposal):
259264
return check
260265

261266
@traitlets.validate("dtype_map")
262-
def _check_det_flag_mask(self, proposal):
267+
def _check_dtype_map(self, proposal):
263268
check = proposal["value"]
264269
if check not in ["float", "float64"]:
265270
raise traitlets.TraitError(
@@ -306,6 +311,13 @@ def _check_nmat_type(self, proposal):
306311
raise traitlets.TraitError(msg)
307312
return check
308313

314+
@traitlets.validate("write_div")
315+
def _check_write_div(self, proposal):
316+
check = proposal["value"]
317+
if check not in ["T", "QU", "TQU", None, 'all', '']:
318+
raise traitlets.TraitError("Invalid write_div (must be 'T', 'QU', 'TQU', 'all', None or '')")
319+
return check
320+
309321
def __init__(self, **kwargs):
310322
self.shape = None
311323
self.wcs = None
@@ -513,20 +525,22 @@ def _init_mapmaker(
513525
fname = signal_map.write(prefix, "rhs", signal_map.rhs)
514526
log.info_rank(f"Wrote rhs to {fname}", comm=comm)
515527

516-
if self.write_div:
517-
fname = f"{prefix}sky_div.fits"
518-
if self.skip_existing and os.path.isfile(fname):
519-
log.info_rank(f"Skipping existing div in {fname}", comm=comm)
520-
else:
521-
# FIXME : only writing the TT variance to avoid integer overflow in communication
522-
fname = signal_map.write(prefix, "div", signal_map.div)
523-
# fname = signal_map.write(prefix, "div", signal_map.div[0, 0])
524-
log.info_rank(f"Wrote div to {fname}", comm=comm)
528+
if self.write_div is not None:
529+
# Write each covariance element seperately, to reduce peak memory.
530+
for i,ci in enumerate(self.comps):
531+
for j,cj in enumerate(self.comps):
532+
if ci in self.write_div and cj in self.write_div:
533+
fname = f"{prefix}sky_div{ci}{cj}.fits"
534+
if self.skip_existing and os.path.isfile(fname):
535+
log.info_rank(f"Skipping existing div{ci}{cj} in {fname}", comm=comm)
536+
else:
537+
fname = signal_map.write(prefix, f"div{ci}{cj}", signal_map.div[i, j])
538+
log.info_rank(f"Wrote div{ci}{cj} to {fname}", comm=comm)
525539

526540
if self.write_hits:
527541
fname = f"{prefix}sky_hits.fits"
528542
if self.skip_existing and os.path.isfile(fname):
529-
log.info_rank(f"Skipping existing div in {fname}", comm=comm)
543+
log.info_rank(f"Skipping existing hits in {fname}", comm=comm)
530544
else:
531545
fname = signal_map.write(prefix, "hits", signal_map.hits)
532546
log.info_rank(f"Wrote hits to {fname}", comm=comm)
@@ -610,6 +624,17 @@ def _exec(self, data, detectors=None, **kwargs):
610624
gcomm = data.comm.comm_group
611625
timer.start()
612626

627+
if self.write_div == 'all':
628+
self.write_div = self.comps
629+
elif self.write_div == '':
630+
self.write_div = None
631+
elif self.write_div is not None:
632+
# Make sure all components in self.write_div is in self.comps
633+
for i in self.write_div:
634+
if i not in self.comps:
635+
msg = f"Component '{i}' in write_div={self.write_div} not present in comps={self.comps}"
636+
raise RuntimeError(msg)
637+
613638
if data.comm.group_size != 1:
614639
raise RuntimeError(
615640
"The ML mapmaker requires the TOAST process group size to be exactly one."

tests/test_mapmaker_pointing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def test_mapmaker_pointing(self):
139139
truncate_tod=False,
140140
write_hits=True,
141141
write_rhs=False,
142-
write_div=False,
142+
write_div="all",
143143
write_bin=True,
144144
deslope=False,
145145
)

0 commit comments

Comments
 (0)