Skip to content

Commit 96a9331

Browse files
committed
Make QuantityTableCoordinate and TimeTableCoordinate stored tables as corners and centers.
1 parent ce4d0b1 commit 96a9331

File tree

1 file changed

+79
-20
lines changed

1 file changed

+79
-20
lines changed

ndcube/extra_coords/table_coord.py

Lines changed: 79 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -128,30 +128,27 @@ def _generate_generic_frame(naxes, unit, names=None, physical_types=None):
128128
axes_names=names, name=name, axis_physical_types=physical_types)
129129

130130

131-
def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, **kwargs):
131+
def _generate_tabular(table_points, lookup_table, interpolation='linear', points_unit=u.pix, **kwargs):
132132
"""
133133
Generate a Tabular model class and instance.
134134
"""
135135
if not isinstance(lookup_table, u.Quantity):
136136
raise TypeError("lookup_table must be a Quantity.") # pragma: no cover
137+
if not isinstance(table_points, u.Quantity):
138+
raise TypeError("table_points must be a Quantity.") # pragma: no cover
137139

138140
ndim = lookup_table.ndim
139141
TabularND = tabular_model(ndim, name=f"Tabular{ndim}D")
140142

141-
# The integer location is at the centre of the pixel.
142-
points = [(np.arange(size) - 0) * points_unit for size in lookup_table.shape]
143-
if len(points) == 1:
144-
points = points[0]
145-
146143
kwargs = {'bounds_error': False,
147144
'fill_value': np.nan,
148145
'method': interpolation,
149146
**kwargs}
150147

151148
if len(lookup_table) == 1:
152-
t = Length1Tabular(points, lookup_table, **kwargs)
149+
t = Length1Tabular(table_points, lookup_table, **kwargs)
153150
else:
154-
t = TabularND(points, lookup_table, **kwargs)
151+
t = TabularND(table_points, lookup_table, **kwargs)
155152

156153
# TODO: Remove this when there is a new gWCS release
157154
# Work around https://github.com/spacetelescope/gwcs/pull/331
@@ -160,13 +157,13 @@ def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, *
160157
return t
161158

162159

163-
def _generate_compound_model(*lookup_tables, mesh=True):
160+
def _generate_compound_model(points, *lookup_tables, mesh=True):
164161
"""
165162
Takes a set of quantities and returns a ND compound model.
166163
"""
167-
model = _generate_tabular(lookup_tables[0])
168-
for lt in lookup_tables[1:]:
169-
model = model & _generate_tabular(lt)
164+
model = _generate_tabular(points[0], lookup_tables[0])
165+
for pts, lt in zip(points[1:], lookup_tables[1:]):
166+
model = model & _generate_tabular(pts, lt)
170167

171168
if mesh:
172169
return model
@@ -176,11 +173,11 @@ def _generate_compound_model(*lookup_tables, mesh=True):
176173
return models.Mapping(mapping) | model
177174

178175

179-
def _model_from_quantity(lookup_tables, mesh=False):
176+
def _model_from_quantity(points, lookup_tables, mesh=False):
180177
if len(lookup_tables) > 1:
181-
return _generate_compound_model(*lookup_tables, mesh=mesh)
178+
return _generate_compound_model(points, *lookup_tables, mesh=mesh)
182179

183-
return _generate_tabular(lookup_tables[0])
180+
return _generate_tabular(points[0], lookup_tables[0])
184181

185182

186183
class BaseTableCoordinate(abc.ABC):
@@ -301,7 +298,7 @@ class QuantityTableCoordinate(BaseTableCoordinate):
301298
a physical type must be given for each component.
302299
"""
303300

304-
def __init__(self, *tables, names=None, physical_types=None):
301+
def __init__(self, *tables, names=None, physical_types=None, grid_points="centers"):
305302
if not all([isinstance(t, u.Quantity) for t in tables]):
306303
raise TypeError("All tables must be astropy Quantity objects")
307304
if not all([t.unit.is_equivalent(tables[0].unit) for t in tables]):
@@ -313,6 +310,17 @@ def __init__(self, *tables, names=None, physical_types=None):
313310
"Currently all tables must be 1-D. If you need >1D support, please "
314311
"raise an issue at https://github.con/sunpy/ndcube/issues")
315312

313+
# lookup table must be stored as values at corners and centers.
314+
# If input tables only represent centers or corners, linearly interpolate to get other values.
315+
# If centers or corners provided, the following assumes the tables are 1-D.
316+
if grid_points == "centers":
317+
tables = _get_grid_from_centers(tables)
318+
elif grid_points == "corners":
319+
tables = _get_grid_from_corners(tables)
320+
elif grid_points != "centers and corners":
321+
raise ValueError(f"Unrecognized value for grid_points: {grid_points}. "
322+
"Must be 'centers', 'corners', or 'centers and corners'.")
323+
316324
if isinstance(names, str):
317325
names = [names]
318326
if names is not None and len(names) != ndim:
@@ -396,7 +404,11 @@ def model(self):
396404
"""
397405
Generate the Astropy Model for this LookupTable.
398406
"""
399-
return _model_from_quantity(self.table, True)
407+
points_unit = u.pix
408+
points = [(np.arange(-1, table.shape[0] - 1) / 2) * points_unit if len(table.shape) == 1
409+
else [(np.arange(-1, size - 1) / 2) * points_unit for size in table.shape]
410+
for table in self.table]
411+
return _model_from_quantity(points, self.table, True)
400412

401413
@property
402414
def ndim(self):
@@ -577,6 +589,10 @@ def model(self):
577589
"""
578590
Generate the Astropy Model for this LookupTable.
579591
"""
592+
points_unit = u.pix
593+
points = [np.arange(table.shape[0]) * points_unit if len(table.shape) == 1
594+
else [np.arange(size) for size in table.shape] * points_unit
595+
for table in self._sliced_components]
580596
return _model_from_quantity(self._sliced_components, mesh=self.mesh)
581597

582598
@property
@@ -701,10 +717,24 @@ class TimeTableCoordinate(BaseTableCoordinate):
701717
Default is first time coordinate in table input.
702718
"""
703719

704-
def __init__(self, *tables, names=None, physical_types=None, reference_time=None):
720+
def __init__(self, *tables, names=None, physical_types=None, reference_time=None, grid_points="centers"):
705721
if not len(tables) == 1 and isinstance(tables[0], Time):
706722
raise ValueError("TimeLookupTable can only be constructed from a single Time object.")
707723

724+
# lookup table must be stored as values at corners and centers.
725+
# If input tables only represent centers or corners, linearly interpolate to get other values.
726+
# If centers or corners provided, the following assumes the tables are 1-D.
727+
table = tables[0]
728+
mjd_table = table.mjd
729+
if grid_points in {"centers", "corners"}:
730+
mjd = _get_grid_from_centers([mjd_table]) if grid_points == "centers" else _get_grid_from_corners([mjd_table])
731+
mjd = mjd[0]
732+
t = Time(mjd, format="mjd", scale=table.scale)
733+
tables = [Time(getattr(t, table.format), scale=table.scale)]
734+
elif grid_points != "centers and corners":
735+
raise ValueError(f"Unrecognized value for grid_points: {grid_points}. "
736+
"Must be 'centers', 'corners', or 'centers and corners'.")
737+
708738
if isinstance(names, str):
709739
names = [names]
710740
if isinstance(physical_types, str):
@@ -752,8 +782,9 @@ def model(self):
752782
"""
753783
time = self.table
754784
deltas = (time - self.reference_time).to(u.s)
755-
756-
return _model_from_quantity((deltas,), mesh=False)
785+
points_unit = u.pix
786+
points = ((np.arange(-1, self.table.shape[0] - 1) / 2) * points_unit,)
787+
return _model_from_quantity(points, (deltas,), mesh=False)
757788

758789
def interpolate(self, new_array_grids, **kwargs):
759790
"""
@@ -970,3 +1001,31 @@ def interpolate(self, new_array_grids, **kwargs):
9701001
new_obj = type(self)(*new_table_coordinates)
9711002
new_obj._dropped_coords = self._dropped_coords
9721003
return new_obj
1004+
1005+
1006+
def _get_grid_from_centers(tables):
1007+
new_tables = []
1008+
for table in tables:
1009+
new_table = np.zeros(len(table) * 2 + 1)
1010+
tv = table.value if isinstance(table, u.Quantity) else table
1011+
new_table[0] = tv[0] - (tv[1] - tv[0]) / 2
1012+
new_table[2:-1:2] = tv[:-1] + (tv[1:] - tv[:-1]) / 2
1013+
new_table[1::2] = tv
1014+
new_table[-1] = tv[-1] + (tv[-1] - tv[-2]) / 2
1015+
if isinstance(table, u.Quantity):
1016+
new_table *= table.unit
1017+
new_tables.append(new_table)
1018+
return tuple(new_tables)
1019+
1020+
1021+
def _get_grid_from_corners(tables):
1022+
new_tables = []
1023+
for table in tables:
1024+
new_table = np.zeros(len(table) * 2 - 1)
1025+
tv = table.value if isinstance(table, u.Quantity) else table
1026+
new_table[::2] = tv
1027+
new_table[1::2] = tv[:-1] + (tv[1:] - tv[:-1]) / 2
1028+
if isinstance(table, u.Quantity):
1029+
new_table *= table.unit
1030+
new_tables.append(new_table)
1031+
return tuple(new_tables)

0 commit comments

Comments
 (0)