Skip to content

Commit df8f1c4

Browse files
authored
Added in reindex_axis and handling for tdets (#1506)
* Added in reindex_axis and handling for tdets: Added in reindex_axis function in the AxisManager class which allows one to reindex all datasets that are assigned to a specified axis in an AxisManager. This is useful for dealing with tdets in get_obs. Also added in handling for tdets in get_obs. The handling merges the dets+tdets axis, inserts the tdets signal into the dets signal dataset, and reindexes every dataset in the obs_aman that is assigned to the dets axis (det_info... etc.). These datasets will have nans inserted where the tdets exist in the new merged dets axis.
* Updated order of axis existence check: Moved the axis existence check to happen before any aman copying takes place (if in_place=False). This will prevent an aman copy from being made right before a ValueError, which would cause a memory leak in certain circumstances.
* Fixed tdet_id cleanup: Fixed the portion of get_obs that added tdet_ids into the dets axis. Previously it would only operate on the highest-level aman's dets axis, so the sub-amans' dets axes wouldn't match. It's been rewritten to recursively (scary) delve through all lower-level amans to fix all dets axes.
* Fixed default reindexing behavior and load_book: Removed the default behavior of removing the tdets axis and data. This data will now never be removed. Added in `reindex_dets` arg to `get_obs`. `special_channels=True` will now load the special dets, while `reindex_dets=True` will reindex them into the `dets` axis. This will copy the special dets data into the dets signal, bands, channels, and readout_ids but will not remove the tdets data (in case anyone just wants to look at fixed tones). Used Kyohei's code for the special_dets `load_book` fix.
* Added check for tones and det_info: Added check for reindexing in `get_obs` to ensure tones and det_info data exist before attempting to reindex. Reindexing is logged and skipped if either is nonexistent.
* Changed logger behavior for missing tones: Changed the logger behavior from `error` to `info` if get_obs is called with `special_channels=True` and `reindex_dets=True` when there are no tones to load. Minor changes to if statements that retain the same behavior.
* Fixed np.full bug: Previous commit introduced a small bug on a `np.full` call. Fixed this issue.
1 parent 550bc46 commit df8f1c4

File tree

3 files changed

+254
-13
lines changed

3 files changed

+254
-13
lines changed

sotodlib/core/axisman.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,144 @@ def restrict_axes(self, axes, in_place=True):
809809
dest._fields[k] = v[sslice]
810810
return dest
811811

812+
def reindex_axis(self, axis, indexes, in_place=True):
    """
    Reindexes all data that is assigned to a specified axis
    with a new list/array of indexes.
    This is particularly useful if the number of detectors
    between the meta and obs data don't match.
    This function will recursively delve through all
    AxisManagers in aman and will reindex every
    data array that is found assigned to an axis
    matching the specified axis.

    Args:
        axis (str): The name of the axis in the aman to reindex.
        indexes (int array): an array of ints with length
            equal to the length of the new array
            and values equal to the idxs of the
            values in the data to be reindexed.
            Entries that are nan, negative, or past the end of
            the old data mark new, empty slots.

            For example:
                data = [1,3,5], indexes = [0, -1, 2, 1]
                would result in new_data = [1, nan, 5, 3]

            Empty slots are filled with nan for float/complex
            arrays and with the zero value of the dtype otherwise
            (0, False, ''); object arrays are left as allocated
            by ``np.empty``.

        in_place (bool): If in_place == True, the reindexing is
            applied to self. Otherwise, a new object is returned,
            with data copied out, and self is left untouched.

    Returns:
        The reindexed AxisManager (self when in_place is True).

    Raises:
        ValueError: if ``axis`` is not an axis of this AxisManager.
        TypeError: if the axis is not an IndexAxis, LabelAxis or
            OffsetAxis, or if a field tied to the axis is neither
            an AxisManager nor a numpy array. Both are raised
            before anything is copied or modified.

    Notes:
        For a LabelAxis every empty slot receives the same blank
        label, so the caller is expected to fill those labels in
        afterwards. For an OffsetAxis only the count changes; offset
        and origin_tag are carried over unchanged.
    """
    # Check if axis even exists first (before any copying happens).
    if axis not in self._axes:
        raise ValueError(
            f"Axis doesn't exist in aman! Can't re-index along {axis}")

    # Normalise the request once: nan -> -1 so that a single range
    # test identifies every "leave empty" entry.
    index_arr = np.asarray(indexes, dtype=float)
    n_new = len(index_arr)
    filled = np.where(np.isnan(index_arr), -1, index_arr)

    def _pairs(old_len):
        # (destination slots, source slots) for the valid entries.
        ok = (filled >= 0) & (filled < old_len)
        return np.flatnonzero(ok), filled[ok].astype(int)

    def _blank(shape, dtype):
        # Deterministic padding instead of uninitialised memory.
        if dtype.kind == 'O':
            return np.empty(shape, dtype=dtype)
        if np.issubdtype(dtype, np.inexact):
            return np.full(shape, np.nan, dtype=dtype)
        return np.zeros(shape, dtype=dtype)

    def _reindexed(arr, dim):
        # Reindex ``arr`` along dimension ``dim``. Copying slot by
        # slot avoids a second temporary the size of the data.
        shape = list(arr.shape)
        dst, src = _pairs(shape[dim])
        shape[dim] = n_new
        out = _blank(shape, arr.dtype)
        lead = (slice(None),) * dim
        for d, s in zip(dst, src):
            out[lead + (d,)] = arr[lead + (s,)]
        return out

    # Build the replacement axis up front so that an unsupported axis
    # type is reported before any data has been touched.
    old_axis = self._axes[axis]
    new_axis = None

    if isinstance(old_axis, IndexAxis):
        # Build a new axis that has a length equal to the indexes arg.
        new_axis = IndexAxis(name=axis, count=n_new)

    if isinstance(old_axis, LabelAxis):
        # Newly added slots get a blank label ('' for string labels).
        old_vals = old_axis.vals
        vals = _blank([n_new], old_vals.dtype)
        dst, src = _pairs(len(old_vals))
        vals[dst] = old_vals[src]
        new_axis = LabelAxis(name=axis, vals=vals)

    if isinstance(old_axis, OffsetAxis):
        new_axis = OffsetAxis(name=axis,
                              count=n_new,
                              offset=old_axis.offset,
                              origin_tag=old_axis.origin_tag)

    if new_axis is None:
        raise TypeError(
            f"Can't re-index axis '{axis}' of type "
            f"{type(old_axis).__name__}")

    # Every assignment that is tied to the axis in question.
    tied = [(name, list(axes))
            for name, axes in self._assignments.items()
            if axis in axes]

    # Refuse unsupported field types before copying or mutating.
    for name, _ in tied:
        if not isinstance(self[name], (AxisManager, np.ndarray)):
            raise TypeError(
                f"Can't re-index field '{name}' of type "
                f"{type(self[name]).__name__} along {axis}")

    # An axes-only copy carries no fields, so for the out-of-place
    # case take a full copy and reindex that copy in place.
    aman = self if in_place else self.copy(axes_only=False)

    reindexed = {}
    for name, axes in tied:
        v = aman[name]

        if isinstance(v, AxisManager):
            # If we hit an axis manager,
            # recursively reindex it as well. Scary!
            new_v = v.reindex_axis(axis, indexes)
        else:
            # Reindex along every dimension assigned to the axis,
            # not only the leading one.
            new_v = v
            for dim, ax_name in enumerate(axes):
                if ax_name == axis:
                    new_v = _reindexed(new_v, dim)

        reindexed[name] = (new_v, axes)

        # Destroy the old assignment
        aman.move(name=name, new_name=None)

    # We're done with the old axis now, swap in the reindexed one.
    del aman._axes[axis]
    aman.add_axis(new_axis)

    # Now we'll go through all the reindexed data and wrap it back in.
    for name, (new_v, axes) in reindexed.items():
        if isinstance(new_v, AxisManager):
            # Need to wrap aman's with no axismap
            aman.wrap(name=name, data=new_v)
        else:
            # Everything else needs an axismap: (dimension, axis name)
            # pairs, skipping dimensions tied to no axis.
            ax_map = [(dim, ax) for dim, ax in enumerate(axes)
                      if ax is not None]
            aman.wrap(name=name, data=new_v, axis_map=ax_map)

    # Everything is now reindexed and rewrapped. Done!
    return aman  # Return for rewrapping if recursively called.
949+
812950
@staticmethod
813951
def _broadcast_selector(sslice):
814952
"""sslice is a list of selectors, which will typically be slice(), or

sotodlib/core/context.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def get_obs(self,
174174
no_signal=None,
175175
no_headers=None,
176176
special_channels=None,
177+
reindex_dets=None,
177178
loader_type=None,
178179
):
179180
"""Load TOD and supporting metadata for some observation.
@@ -217,10 +218,20 @@ def get_obs(self,
217218
information that tags along with the signal.
218219
special_channels (bool): If True, load "special" readout
219220
channels that are normally skipped (e.g. fixed tones).
221+
reindex_dets (bool): If True, reindexes the axismanager
222+
"dets" axis to include any special dets that may have
223+
been loaded. All special det signals, bands, channels,
224+
and readout_ids will be inserted along the dets axis
225+
and respective data arrays. Does not destroy special
226+
dets axes like the "tdets" axis. WARNING: ~Doubles
227+
run time and instantaneous memory usage as the signal
228+
dataset has to effectively be copied! Memory returned
229+
after reindexing is complete.
220230
loader_type (str): Name of the registered TOD loader
221231
function to use (this will override whatever is specified
222232
in context.yaml).
223233
234+
224235
Notes:
225236
It is acceptable to pass the ``obs_id`` argument by position
226237
(first), but all other arguments should be passed by
@@ -329,6 +340,97 @@ def get_obs(self,
329340
_det_info.move(k, None)
330341
meta.det_info.merge(_det_info)
331342
aman.merge(meta)
343+
344+
# Deal with special channels, if they exist.
345+
# We will merge the dets and tdets axes
346+
# And merge the tdets signal into the dets signal, band, and channels.
347+
# nans will be inserted into every other dataset assigned to the dets axis.
348+
if special_channels and reindex_dets:
349+
# Check for special channels (tones).
350+
if 'tones' not in aman:
351+
logger.info('"tones" not found in aman, no special channels to reindex!')
352+
353+
# Check there is det_info to reindex.
354+
elif 'det_info' not in aman:
355+
logger.error('"det_info" not found in aman, no dets to reindex!')
356+
357+
# Both tones and det_info exist.
358+
else:
359+
# Grab all band and channel info for dets + tdets
360+
det_bands = aman.det_info.smurf.band
361+
det_channels = aman.det_info.smurf.channel
362+
tdet_bands = aman.tones.band
363+
tdet_channels = aman.tones.channel
364+
365+
# Create a sorted array of dets + tdets
366+
special_band_ch = [(b, c) for b, c in zip(tdet_bands, tdet_channels)]
367+
normal_band_ch = [(b, c) for b, c in zip(det_bands, det_channels)]
368+
band_ch = np.array(sorted(normal_band_ch + special_band_ch))
369+
370+
# Grab the det idxs from the det band + channels
371+
det_indexes = np.full(len(band_ch), np.nan)
372+
for i, (b, c) in enumerate(band_ch):
373+
w = np.where((det_bands == b) & (det_channels == c))[0]
374+
if len(w) == 0:
375+
continue
376+
377+
det_indexes[i] = w[0]
378+
379+
# Grab the tdet idxs from the tdet band + channels
380+
tdet_indexes = np.full(len(band_ch), np.nan)
381+
for i, (b, c) in enumerate(band_ch):
382+
w = np.where((tdet_bands == b) & (tdet_channels == c))[0]
383+
if len(w) == 0:
384+
continue
385+
386+
tdet_indexes[i] = w[0]
387+
388+
# Use the det idxs to reindex all datasets assigned to the dets axis
389+
# This will set the len to dets + tdets
390+
# And will insert nans where the tdet channels exist
391+
# in the sorted band_ch array
392+
aman.reindex_axis(axis='dets', indexes=det_indexes)
393+
394+
# Finally use the tdet idxs to fill in the tdet data
395+
# For the signal, band, and channels
396+
for i, tidx in enumerate(tdet_indexes):
397+
if np.isnan(tidx):
398+
continue
399+
400+
aman.signal[i] = aman.tones.signal[int(tidx)]
401+
aman.det_info.smurf.channel[i] = aman.tones.channel[int(tidx)]
402+
aman.det_info.smurf.band[i] = aman.tones.band[int(tidx)]
403+
404+
def add_tdet_ids(aman, tdet_indexes, tdet_ids):
    """
    Write the tdet ids into the "dets" axis labels of ``aman``
    and of every nested AxisManager that is assigned to "dets".

    ``tdet_indexes[i]`` is the position in ``tdet_ids`` whose id
    belongs in slot ``i`` of the dets axis; nan entries are skipped.
    """
    # Descend first, so nested amans are fixed before this one.
    # NOTE(review): recursion depth equals the nesting depth of the
    # amans; a self-referencing aman would never terminate.
    for name, assigned_axes in aman._assignments.items():
        child = aman[name]
        if isinstance(child, AxisManager) and "dets" in assigned_axes:
            add_tdet_ids(child, tdet_indexes, tdet_ids)

    # Only amans that own a dets axis have labels to update.
    if "dets" in aman._axes:
        for slot, tone_idx in enumerate(tdet_indexes):
            if not np.isnan(tone_idx):
                # NOTE(review): presumably vals is a fixed-width string
                # array, in which case a longer id would be truncated
                # on assignment -- confirm.
                aman.dets.vals[slot] = tdet_ids[int(tone_idx)]
428+
429+
# Add in tdet ids to all amans.
430+
add_tdet_ids(aman, tdet_indexes, aman.tdets.vals)
431+
432+
# obs_aman tdet info is merged!
433+
332434
return aman
333435

334436
def get_meta(self,

sotodlib/io/load_book.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -525,19 +525,20 @@ def _concat_filesets(results, ancil=None, timestamps=None,
525525
# Should look like this: sch_NONE_2_326
526526
b, c = map(int, k.split('_')[2:])
527527
tone_info.append((v['stream_id'], v['stream_id'] + f'_{b}_{c}', b, c))
528-
ts, tk, tb, tc = map(np.array, zip(*tone_info))
529-
tman = core.AxisManager(core.LabelAxis('tdets', tk),
530-
aman.samps)
531-
aman.wrap('tones', tman)
532-
aman.tones.wrap('stream_id', ts, axis_map=[(0, 'tdets')])
533-
aman.tones.wrap('band', tb, axis_map=[(0, 'tdets')])
534-
aman.tones.wrap('channel', tc, axis_map=[(0, 'tdets')])
535-
aman.tones.wrap_new('signal', shape=('tdets', 'samps'), dtype='float32')
536-
dets_ofs = 0
537-
for v in results.values():
538-
d = v['tones'].data
539-
aman.tones['signal'][dets_ofs:dets_ofs + len(d)] = d
540-
dets_ofs += len(d)
528+
if tone_info:
529+
ts, tk, tb, tc = map(np.array, zip(*tone_info))
530+
tman = core.AxisManager(core.LabelAxis('tdets', tk),
531+
aman.samps)
532+
aman.wrap('tones', tman)
533+
aman.tones.wrap('stream_id', ts, axis_map=[(0, 'tdets')])
534+
aman.tones.wrap('band', tb, axis_map=[(0, 'tdets')])
535+
aman.tones.wrap('channel', tc, axis_map=[(0, 'tdets')])
536+
aman.tones.wrap_new('signal', shape=('tdets', 'samps'), dtype='float32')
537+
dets_ofs = 0
538+
for v in results.values():
539+
d = v['tones'].data
540+
aman.tones['signal'][dets_ofs:dets_ofs + len(d)] = d
541+
dets_ofs += len(d)
541542

542543
# In sims, or if no_headers, the primary block may be unpopulated.
543544
if any([(v['primary'] is not None and v['primary'].data is not None)

0 commit comments

Comments
 (0)