Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 112 additions & 27 deletions ndcube/extra_coords/table_coord.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,14 @@ def _generate_generic_frame(naxes, unit, names=None, physical_types=None):
axes_names=names, name=name, axis_physical_types=physical_types)


def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, **kwargs):
def _generate_tabular(table_points, lookup_table, interpolation='linear', points_unit=u.pix, **kwargs):
"""
Generate a Tabular model class and instance.
"""
if not isinstance(lookup_table, u.Quantity):
raise TypeError("lookup_table must be a Quantity.") # pragma: no cover
if not isinstance(table_points, u.Quantity):
raise TypeError("table_points must be a Quantity.") # pragma: no cover

ndim = lookup_table.ndim

Expand All @@ -160,20 +162,15 @@ def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, *
else:
TabularND = tabular_model(ndim, name=f"Tabular{ndim}D")

# The integer location is at the centre of the pixel.
points = [(np.arange(size) - 0) * points_unit for size in lookup_table.shape]
if len(points) == 1:
points = points[0]

kwargs = {'bounds_error': False,
'fill_value': np.nan,
'method': interpolation,
**kwargs}

if len(lookup_table) == 1:
t = Length1Tabular(points, lookup_table, **kwargs)
t = Length1Tabular(table_points, lookup_table, **kwargs)
else:
t = TabularND(points, lookup_table, **kwargs)
t = TabularND(table_points, lookup_table, **kwargs)

# TODO: Remove this when there is a new gWCS release
# Work around https://github.com/spacetelescope/gwcs/pull/331
Expand All @@ -182,13 +179,13 @@ def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, *
return t


def _generate_compound_model(*lookup_tables, mesh=True):
def _generate_compound_model(points, *lookup_tables, mesh=True):
"""
Takes a set of quantities and returns a ND compound model.
"""
model = _generate_tabular(lookup_tables[0])
for lt in lookup_tables[1:]:
model = model & _generate_tabular(lt)
model = _generate_tabular(points[0], lookup_tables[0])
for pts, lt in zip(points[1:], lookup_tables[1:]):
model = model & _generate_tabular(pts, lt)

if mesh:
return model
Expand All @@ -198,11 +195,11 @@ def _generate_compound_model(*lookup_tables, mesh=True):
return models.Mapping(mapping) | model


def _model_from_quantity(lookup_tables, mesh=False):
def _model_from_quantity(points, lookup_tables, mesh=False):
if len(lookup_tables) > 1:
return _generate_compound_model(*lookup_tables, mesh=mesh)
return _generate_compound_model(points, *lookup_tables, mesh=mesh)

return _generate_tabular(lookup_tables[0])
return _generate_tabular(points[0], lookup_tables[0])


class BaseTableCoordinate(abc.ABC):
Expand Down Expand Up @@ -322,7 +319,7 @@ class QuantityTableCoordinate(BaseTableCoordinate):
a physical type must be given for each component.
"""

def __init__(self, *tables, names=None, physical_types=None):
def __init__(self, *tables, names=None, physical_types=None, grid_points="centers"):
if not all(isinstance(t, u.Quantity) for t in tables):
raise TypeError("All tables must be astropy Quantity objects")
if not all(t.unit.is_equivalent(tables[0].unit) for t in tables):
Expand All @@ -334,6 +331,17 @@ def __init__(self, *tables, names=None, physical_types=None):
"Currently all tables must be 1-D. If you need >1D support, please "
"raise an issue at https://github.con/sunpy/ndcube/issues")

# lookup table must be stored as values at corners and centers.
# If input tables only represent centers or corners, linearly interpolate to get other values.
# If centers or corners provided, the following assumes the tables are 1-D.
if grid_points == "centers":
tables = _get_grid_from_centers(tables)
elif grid_points == "corners":
tables = _get_grid_from_corners(tables)
elif grid_points != "centers and corners":
raise ValueError(f"Unrecognized value for grid_points: {grid_points}. "
"Must be 'centers', 'corners', or 'centers and corners'.")

if isinstance(names, str):
names = [names]
if names is not None and len(names) != ndim:
Expand Down Expand Up @@ -385,16 +393,20 @@ def __getitem__(self, item):
if not (len(item) == len(self.table) or len(item) == self.table[0].ndim):
raise ValueError("Can not slice with incorrect length")

# Since underlying tables store corner and center values, item must be changed to reflect this.
new_item = _convert_cube_item_to_table_item(item)

new_components = defaultdict(list)
new_components["dropped_world_dimensions"] = copy.deepcopy(self._dropped_world_dimensions)

for i, (ele, table) in enumerate(zip(item, self.table)):
self._slice_table(i, table, ele, new_components, whole_slice=item)
for i, (ele, table) in enumerate(zip(new_item, self.table)):
self._slice_table(i, table, ele, new_components, whole_slice=new_item)

names = new_components["names"] or None
physical_types = new_components["physical_types"] or None

ret_table = type(self)(*new_components["tables"], names=names, physical_types=physical_types)
ret_table = type(self)(*new_components["tables"], names=names, physical_types=physical_types,
grid_points="centers and corners")
ret_table._dropped_world_dimensions = new_components["dropped_world_dimensions"]
return ret_table

Expand All @@ -417,7 +429,11 @@ def model(self):
"""
Generate the Astropy Model for this LookupTable.
"""
return _model_from_quantity(self.table, True)
points_unit = u.pix
points = [(np.arange(-1, table.shape[0] - 1) / 2) * points_unit if len(table.shape) == 1
else [(np.arange(-1, size - 1) / 2) * points_unit for size in table.shape]
for table in self.table]
return _model_from_quantity(points, self.table, True)

@property
def ndim(self):
Expand Down Expand Up @@ -597,6 +613,10 @@ def model(self):
"""
Generate the Astropy Model for this LookupTable.
"""
points_unit = u.pix
points = [np.arange(table.shape[0]) * points_unit if len(table.shape) == 1
else [np.arange(size) for size in table.shape] * points_unit
for table in self._sliced_components]
return _model_from_quantity(self._sliced_components, mesh=self.mesh)

@property
Expand Down Expand Up @@ -719,10 +739,24 @@ class TimeTableCoordinate(BaseTableCoordinate):
Default is first time coordinate in table input.
"""

def __init__(self, *tables, names=None, physical_types=None, reference_time=None):
def __init__(self, *tables, names=None, physical_types=None, reference_time=None, grid_points="centers"):
if not len(tables) == 1 and isinstance(tables[0], Time):
raise ValueError("TimeLookupTable can only be constructed from a single Time object.")

# lookup table must be stored as values at corners and centers.
# If input tables only represent centers or corners, linearly interpolate to get other values.
# If centers or corners provided, the following assumes the tables are 1-D.
table = tables[0]
mjd_table = table.mjd
if grid_points in {"centers", "corners"}:
mjd = _get_grid_from_centers([mjd_table]) if grid_points == "centers" else _get_grid_from_corners([mjd_table])
mjd = mjd[0]
t = Time(mjd, format="mjd", scale=table.scale)
tables = [Time(getattr(t, table.format), scale=table.scale)]
elif grid_points != "centers and corners":
raise ValueError(f"Unrecognized value for grid_points: {grid_points}. "
"Must be 'centers', 'corners', or 'centers and corners'.")

if isinstance(names, str):
names = [names]
if isinstance(physical_types, str):
Expand All @@ -738,13 +772,17 @@ def __init__(self, *tables, names=None, physical_types=None, reference_time=None
self.reference_time = reference_time or self.table[0]

def __getitem__(self, item):
if not (isinstance(item, (slice, Integral)) or len(item) == 1):
if isinstance(item, (slice, Integral)):
item = (item,)
elif len(item) != 1:
raise ValueError("Can not slice with incorrect length")

return type(self)(self.table[item],
# Since table grid includes centers and corners, the input item must be changes accordingly.
new_item = _convert_cube_item_to_table_item(item)[0]
return type(self)(self.table[new_item],
names=self.names,
physical_types=self.physical_types,
reference_time=self.reference_time)
reference_time=self.reference_time,
grid_points="centers and corners")

@property
def n_inputs(self):
Expand All @@ -770,8 +808,9 @@ def model(self):
"""
time = self.table
deltas = (time - self.reference_time).to(u.s)

return _model_from_quantity((deltas,), mesh=False)
points_unit = u.pix
points = ((np.arange(-1, self.table.shape[0] - 1) / 2) * points_unit,)
return _model_from_quantity(points, (deltas,), mesh=False)

def interpolate(self, new_array_grids, **kwargs):
"""
Expand Down Expand Up @@ -1007,3 +1046,49 @@ def interpolate(self, new_array_grids, **kwargs):
new_obj = type(self)(*new_table_coordinates)
new_obj._dropped_coords = self._dropped_coords
return new_obj


def _get_grid_from_centers(tables):
if not hasattr(tables, "__len__"):
return tables
new_tables = []
for table in tables:
new_table = np.zeros(len(table) * 2 + 1)
tv = table.value if isinstance(table, u.Quantity) else table
new_table[0] = tv[0] - (tv[1] - tv[0]) / 2
new_table[2:-1:2] = tv[:-1] + (tv[1:] - tv[:-1]) / 2
new_table[1::2] = tv
new_table[-1] = tv[-1] + (tv[-1] - tv[-2]) / 2
if isinstance(table, u.Quantity):
new_table *= table.unit
new_tables.append(new_table)
return tuple(new_tables)


def _get_grid_from_corners(tables):
if not hasattr(tables, "__len__"):
return tables
new_tables = []
for table in tables:
new_table = np.zeros(len(table) * 2 - 1)
tv = table.value if isinstance(table, u.Quantity) else table
new_table[::2] = tv
new_table[1::2] = tv[:-1] + (tv[1:] - tv[:-1]) / 2
if isinstance(table, u.Quantity):
new_table *= table.unit
new_tables.append(new_table)
return tuple(new_tables)


def _convert_cube_item_to_table_item(item):
new_item = []
for idx in item:
# Assume entries in item must be integers or slices. Fancy indexing not supported.
if isinstance(idx, Integral):
new_idx = idx * 2 + 1
else:
new_start = None if idx.start is None else idx.start * 2
new_stop = None if idx.stop is None else idx.stop * 2 + 1
new_idx = slice(new_start, new_stop)
new_item.append(new_idx)
return tuple(new_item)
Loading