diff --git a/ndcube/extra_coords/table_coord.py b/ndcube/extra_coords/table_coord.py index e975393d4..bae240183 100644 --- a/ndcube/extra_coords/table_coord.py +++ b/ndcube/extra_coords/table_coord.py @@ -140,12 +140,14 @@ def _generate_generic_frame(naxes, unit, names=None, physical_types=None): axes_names=names, name=name, axis_physical_types=physical_types) -def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, **kwargs): +def _generate_tabular(table_points, lookup_table, interpolation='linear', points_unit=u.pix, **kwargs): """ Generate a Tabular model class and instance. """ if not isinstance(lookup_table, u.Quantity): raise TypeError("lookup_table must be a Quantity.") # pragma: no cover + if not isinstance(table_points, u.Quantity): + raise TypeError("table_points must be a Quantity.") # pragma: no cover ndim = lookup_table.ndim @@ -160,20 +162,15 @@ def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, * else: TabularND = tabular_model(ndim, name=f"Tabular{ndim}D") - # The integer location is at the centre of the pixel. - points = [(np.arange(size) - 0) * points_unit for size in lookup_table.shape] - if len(points) == 1: - points = points[0] - kwargs = {'bounds_error': False, 'fill_value': np.nan, 'method': interpolation, **kwargs} if len(lookup_table) == 1: - t = Length1Tabular(points, lookup_table, **kwargs) + t = Length1Tabular(table_points, lookup_table, **kwargs) else: - t = TabularND(points, lookup_table, **kwargs) + t = TabularND(table_points, lookup_table, **kwargs) # TODO: Remove this when there is a new gWCS release # Work around https://github.com/spacetelescope/gwcs/pull/331 @@ -182,13 +179,13 @@ def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, * return t -def _generate_compound_model(*lookup_tables, mesh=True): +def _generate_compound_model(points, *lookup_tables, mesh=True): """ Takes a set of quantities and returns a ND compound model. """ - model = _generate_tabular(lookup_tables[0]) - for lt in lookup_tables[1:]: - model = model & _generate_tabular(lt) + model = _generate_tabular(points[0], lookup_tables[0]) + for pts, lt in zip(points[1:], lookup_tables[1:]): + model = model & _generate_tabular(pts, lt) if mesh: return model @@ -198,11 +195,11 @@ def _generate_compound_model(*lookup_tables, mesh=True): return models.Mapping(mapping) | model -def _model_from_quantity(lookup_tables, mesh=False): +def _model_from_quantity(points, lookup_tables, mesh=False): if len(lookup_tables) > 1: - return _generate_compound_model(*lookup_tables, mesh=mesh) + return _generate_compound_model(points, *lookup_tables, mesh=mesh) - return _generate_tabular(lookup_tables[0]) + return _generate_tabular(points[0], lookup_tables[0]) class BaseTableCoordinate(abc.ABC): @@ -322,7 +319,7 @@ class QuantityTableCoordinate(BaseTableCoordinate): a physical type must be given for each component. """ - def __init__(self, *tables, names=None, physical_types=None): + def __init__(self, *tables, names=None, physical_types=None, grid_points="centers"): if not all(isinstance(t, u.Quantity) for t in tables): raise TypeError("All tables must be astropy Quantity objects") if not all(t.unit.is_equivalent(tables[0].unit) for t in tables): @@ -334,6 +331,17 @@ def __init__(self, *tables, names=None, physical_types=None): "Currently all tables must be 1-D. If you need >1D support, please " "raise an issue at https://github.con/sunpy/ndcube/issues") + # lookup table must be stored as values at corners and centers. + # If input tables only represent centers or corners, linearly interpolate to get other values. + # If centers or corners provided, the following assumes the tables are 1-D. + if grid_points == "centers": + tables = _get_grid_from_centers(tables) + elif grid_points == "corners": + tables = _get_grid_from_corners(tables) + elif grid_points != "centers and corners": + raise ValueError(f"Unrecognized value for grid_points: {grid_points}. " + "Must be 'centers', 'corners', or 'centers and corners'.") + if isinstance(names, str): names = [names] if names is not None and len(names) != ndim: @@ -385,16 +393,20 @@ def __getitem__(self, item): if not (len(item) == len(self.table) or len(item) == self.table[0].ndim): raise ValueError("Can not slice with incorrect length") + # Since underlying tables store corner and center values, item must be changed to reflect this. + new_item = _convert_cube_item_to_table_item(item) + new_components = defaultdict(list) new_components["dropped_world_dimensions"] = copy.deepcopy(self._dropped_world_dimensions) - for i, (ele, table) in enumerate(zip(item, self.table)): - self._slice_table(i, table, ele, new_components, whole_slice=item) + for i, (ele, table) in enumerate(zip(new_item, self.table)): + self._slice_table(i, table, ele, new_components, whole_slice=new_item) names = new_components["names"] or None physical_types = new_components["physical_types"] or None - ret_table = type(self)(*new_components["tables"], names=names, physical_types=physical_types) + ret_table = type(self)(*new_components["tables"], names=names, physical_types=physical_types, + grid_points="centers and corners") ret_table._dropped_world_dimensions = new_components["dropped_world_dimensions"] return ret_table @@ -417,7 +429,11 @@ def model(self): """ Generate the Astropy Model for this LookupTable. """ - return _model_from_quantity(self.table, True) + points_unit = u.pix + points = [(np.arange(-1, table.shape[0] - 1) / 2) * points_unit if len(table.shape) == 1 + else [(np.arange(-1, size - 1) / 2) * points_unit for size in table.shape] + for table in self.table] + return _model_from_quantity(points, self.table, True) @property def ndim(self): @@ -597,6 +613,10 @@ def model(self): """ Generate the Astropy Model for this LookupTable. """ + points_unit = u.pix + points = [np.arange(table.shape[0]) * points_unit if len(table.shape) == 1 + else [np.arange(size) for size in table.shape] * points_unit + for table in self._sliced_components] return _model_from_quantity(self._sliced_components, mesh=self.mesh) @property @@ -719,10 +739,24 @@ class TimeTableCoordinate(BaseTableCoordinate): Default is first time coordinate in table input. """ - def __init__(self, *tables, names=None, physical_types=None, reference_time=None): + def __init__(self, *tables, names=None, physical_types=None, reference_time=None, grid_points="centers"): if not len(tables) == 1 and isinstance(tables[0], Time): raise ValueError("TimeLookupTable can only be constructed from a single Time object.") + # lookup table must be stored as values at corners and centers. + # If input tables only represent centers or corners, linearly interpolate to get other values. + # If centers or corners provided, the following assumes the tables are 1-D. + table = tables[0] + mjd_table = table.mjd + if grid_points in {"centers", "corners"}: + mjd = _get_grid_from_centers([mjd_table]) if grid_points == "centers" else _get_grid_from_corners([mjd_table]) + mjd = mjd[0] + t = Time(mjd, format="mjd", scale=table.scale) + tables = [Time(getattr(t, table.format), scale=table.scale)] + elif grid_points != "centers and corners": + raise ValueError(f"Unrecognized value for grid_points: {grid_points}. " + "Must be 'centers', 'corners', or 'centers and corners'.") + if isinstance(names, str): names = [names] if isinstance(physical_types, str): @@ -738,13 +772,17 @@ def __init__(self, *tables, names=None, physical_types=None, reference_time=None self.reference_time = reference_time or self.table[0] def __getitem__(self, item): - if not (isinstance(item, (slice, Integral)) or len(item) == 1): + if isinstance(item, (slice, Integral)): + item = (item,) + elif len(item) != 1: raise ValueError("Can not slice with incorrect length") - - return type(self)(self.table[item], + # Since table grid includes centers and corners, the input item must be changes accordingly. + new_item = _convert_cube_item_to_table_item(item)[0] + return type(self)(self.table[new_item], names=self.names, physical_types=self.physical_types, - reference_time=self.reference_time) + reference_time=self.reference_time, + grid_points="centers and corners") @property def n_inputs(self): @@ -770,8 +808,9 @@ def model(self): """ time = self.table deltas = (time - self.reference_time).to(u.s) - - return _model_from_quantity((deltas,), mesh=False) + points_unit = u.pix + points = ((np.arange(-1, self.table.shape[0] - 1) / 2) * points_unit,) + return _model_from_quantity(points, (deltas,), mesh=False) def interpolate(self, new_array_grids, **kwargs): """ @@ -1007,3 +1046,49 @@ def interpolate(self, new_array_grids, **kwargs): new_obj = type(self)(*new_table_coordinates) new_obj._dropped_coords = self._dropped_coords return new_obj + + +def _get_grid_from_centers(tables): + if not hasattr(tables, "__len__"): + return tables + new_tables = [] + for table in tables: + new_table = np.zeros(len(table) * 2 + 1) + tv = table.value if isinstance(table, u.Quantity) else table + new_table[0] = tv[0] - (tv[1] - tv[0]) / 2 + new_table[2:-1:2] = tv[:-1] + (tv[1:] - tv[:-1]) / 2 + new_table[1::2] = tv + new_table[-1] = tv[-1] + (tv[-1] - tv[-2]) / 2 + if isinstance(table, u.Quantity): + new_table *= table.unit + new_tables.append(new_table) + return tuple(new_tables) + + +def _get_grid_from_corners(tables): + if not hasattr(tables, "__len__"): + return tables + new_tables = [] + for table in tables: + new_table = np.zeros(len(table) * 2 - 1) + tv = table.value if isinstance(table, u.Quantity) else table + new_table[::2] = tv + new_table[1::2] = tv[:-1] + (tv[1:] - tv[:-1]) / 2 + if isinstance(table, u.Quantity): + new_table *= table.unit + new_tables.append(new_table) + return tuple(new_tables) + + +def _convert_cube_item_to_table_item(item): + new_item = [] + for idx in item: + # Assume entries in item must be integers or slices. Fancy indexing not supported. + if isinstance(idx, Integral): + new_idx = idx * 2 + 1 + else: + new_start = None if idx.start is None else idx.start * 2 + new_stop = None if idx.stop is None else idx.stop * 2 + 1 + new_idx = slice(new_start, new_stop) + new_item.append(new_idx) + return tuple(new_item)