Skip to content

Commit c28ad88

Browse files
committed
First commit for making QuantityTableCoordinate support pixel corners.
1 parent ce4d0b1 commit c28ad88

File tree

2 files changed

+76
-22
lines changed

2 files changed

+76
-22
lines changed

ndcube/extra_coords/extra_coords.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def add(self,
5050
name: str | Iterable[str],
5151
array_dimension: int | Iterable[int],
5252
lookup_table: Any,
53+
points: Iterable[float] = None,
5354
physical_types: str | Iterable[str] = None,
5455
**kwargs):
5556
"""
@@ -204,7 +205,7 @@ def from_lookup_tables(cls, names, pixel_dimensions, lookup_tables, physical_typ
204205

205206
return extra_coords
206207

207-
def add(self, name, array_dimension, lookup_table, physical_types=None, **kwargs):
208+
def add(self, name, array_dimension, lookup_table, points=None, physical_types=None, **kwargs):
208209
# docstring in ABC
209210

210211
if self._wcs is not None:
@@ -217,13 +218,13 @@ def add(self, name, array_dimension, lookup_table, physical_types=None, **kwargs
217218
if isinstance(lookup_table, BaseTableCoordinate):
218219
coord = lookup_table
219220
elif isinstance(lookup_table, Time):
220-
coord = TimeTableCoordinate(lookup_table, physical_types=physical_types, **kwargs)
221+
coord = TimeTableCoordinate(lookup_table, points=points, physical_types=physical_types, **kwargs)
221222
elif isinstance(lookup_table, SkyCoord):
222223
coord = SkyCoordTableCoordinate(lookup_table, physical_types=physical_types, **kwargs)
223224
elif isinstance(lookup_table, (list, tuple)):
224-
coord = QuantityTableCoordinate(*lookup_table, physical_types=physical_types, **kwargs)
225+
coord = QuantityTableCoordinate(*lookup_table, points=points, physical_types=physical_types, **kwargs)
225226
elif isinstance(lookup_table, u.Quantity):
226-
coord = QuantityTableCoordinate(lookup_table, physical_types=physical_types, **kwargs)
227+
coord = QuantityTableCoordinate(lookup_table, points=points, physical_types=physical_types, **kwargs)
227228
else:
228229
raise TypeError(f"The input type {type(lookup_table)} isn't supported")
229230

ndcube/extra_coords/table_coord.py

Lines changed: 71 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def _generate_generic_frame(naxes, unit, names=None, physical_types=None):
128128
axes_names=names, name=name, axis_physical_types=physical_types)
129129

130130

131-
def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, **kwargs):
131+
def _generate_tabular(lookup_table, points=None, interpolation='linear', points_unit=u.pix, **kwargs):
132132
"""
133133
Generate a Tabular model class and instance.
134134
"""
@@ -139,7 +139,11 @@ def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, *
139139
TabularND = tabular_model(ndim, name=f"Tabular{ndim}D")
140140

141141
# The integer location is at the centre of the pixel.
142-
points = [(np.arange(size) - 0) * points_unit for size in lookup_table.shape]
142+
if points is None:
143+
points = [(np.arange(size) - 0) * points_unit for size in lookup_table.shape]
144+
else:
145+
points = points.to(points_unit)
146+
[points] * ndim
143147
if len(points) == 1:
144148
points = points[0]
145149

@@ -160,27 +164,28 @@ def _generate_tabular(lookup_table, interpolation='linear', points_unit=u.pix, *
160164
return t
161165

162166

163-
def _generate_compound_model(*lookup_tables, mesh=True):
167+
def _generate_compound_model(*lookup_tables, points=None, mesh=True):
164168
"""
165169
Takes a set of quantities and returns a ND compound model.
166170
"""
167-
model = _generate_tabular(lookup_tables[0])
171+
model = _generate_tabular(lookup_tables[0], points=points)
168172
for lt in lookup_tables[1:]:
169-
model = model & _generate_tabular(lt)
173+
model = model & _generate_tabular(lt, points=points)
170174

171175
if mesh:
172176
return model
173177

174178
# If we are not meshing the inputs duplicate the inputs across all models
175-
mapping = list(range(lookup_tables[0].ndim)) * len(lookup_tables)
179+
#mapping = list(range(lookup_tables[0].ndim)) * len(lookup_tables)
180+
mapping = list(points) * len(lookup_tables)
176181
return models.Mapping(mapping) | model
177182

178183

179-
def _model_from_quantity(lookup_tables, mesh=False):
184+
def _model_from_quantity(lookup_tables, points=None, mesh=False):
180185
if len(lookup_tables) > 1:
181-
return _generate_compound_model(*lookup_tables, mesh=mesh)
186+
return _generate_compound_model(*lookup_tables, points=points, mesh=mesh)
182187

183-
return _generate_tabular(lookup_tables[0])
188+
return _generate_tabular(lookup_tables[0], points=points)
184189

185190

186191
class BaseTableCoordinate(abc.ABC):
@@ -196,8 +201,9 @@ class BaseTableCoordinate(abc.ABC):
196201
coordinates, meaning it can have multiple gWCS frames.
197202
"""
198203

199-
def __init__(self, *tables, mesh=False, names=None, physical_types=None):
204+
def __init__(self, *tables, points=None, mesh=False, names=None, physical_types=None):
200205
self.table = tables
206+
self.points = points
201207
self.mesh = mesh
202208
self.names = names if not isinstance(names, str) else [names]
203209
self.physical_types = physical_types if not isinstance(physical_types, str) else [physical_types]
@@ -290,6 +296,9 @@ class QuantityTableCoordinate(BaseTableCoordinate):
290296
multiple 1-D Quantities can be provided representing the different
291297
dimensions
292298
299+
points: `~astropy.units.Quantity` in pixel units.
300+
The points in the grid grid to which the values in the tables correspond.
301+
293302
names: `str` or `list` of `str`
294303
Custom names for the components of the QuantityTableCoord. If provided,
295304
a name must be given for each input Quantity.
@@ -301,7 +310,7 @@ class QuantityTableCoordinate(BaseTableCoordinate):
301310
a physical type must be given for each component.
302311
"""
303312

304-
def __init__(self, *tables, names=None, physical_types=None):
313+
def __init__(self, *tables, points=None, names=None, physical_types=None):
305314
if not all([isinstance(t, u.Quantity) for t in tables]):
306315
raise TypeError("All tables must be astropy Quantity objects")
307316
if not all([t.unit.is_equivalent(tables[0].unit) for t in tables]):
@@ -312,6 +321,11 @@ def __init__(self, *tables, names=None, physical_types=None):
312321
raise ValueError(
313322
"Currently all tables must be 1-D. If you need >1D support, please "
314323
"raise an issue at https://github.con/sunpy/ndcube/issues")
324+
if points is not None:
325+
if not points.unit.is_equivalent(u.pix):
326+
raise u.UnitsError("Points must have pixel units.")
327+
if points.shape != tables[0].shape:
328+
raise ValueError("Points must be same shape as table(s).")
315329

316330
if isinstance(names, str):
317331
names = [names]
@@ -324,7 +338,7 @@ def __init__(self, *tables, names=None, physical_types=None):
324338

325339
self.unit = tables[0].unit
326340

327-
super().__init__(*tables, mesh=True, names=names, physical_types=physical_types)
341+
super().__init__(*tables, points=points, mesh=True, names=names, physical_types=physical_types)
328342

329343
def _slice_table(self, i, table, item, new_components, whole_slice):
330344
"""
@@ -364,16 +378,43 @@ def __getitem__(self, item):
364378
if not (len(item) == len(self.table) or len(item) == self.table[0].ndim):
365379
raise ValueError("Can not slice with incorrect length")
366380

381+
# Convert item to table item based on points grid.
382+
new_item = []
383+
for idx in item:
384+
if isinstance(idx, Integral):
385+
new_idx = np.where(self.points == idx)
386+
if new_idx == ():
387+
raise NotImplementedError("Indexing QuantityTableCoordinate at inter-grid locations not supported.")
388+
else:
389+
new_item.append(new_idx[0][0])
390+
elif isinstance(idx, slice):
391+
new_start = np.where(self.points > slice.start - 1)
392+
new_start = 0 if new_start == () else new_start[0][0]
393+
new_stop = np.where(self.points >= slice.stop)
394+
new_stop = len(self.points) if new_stop == () else new_stop[0][0]
395+
new_item.append(slice(new_start, new_stop))
396+
else:
397+
new_idx = []
398+
for i in idx:
399+
new_i = np.where(self.points == i)
400+
if new_i == ():
401+
raise NotImplementedError("Indexing QuantityTableCoordinate at inter-grid locations not supported.")
402+
else:
403+
new_idx.append(new_i[0][0])
404+
new_item.append(np.asarray(new_idx))
405+
new_item = tuple(new_item)
406+
367407
new_components = defaultdict(list)
368408
new_components["dropped_world_dimensions"] = copy.deepcopy(self._dropped_world_dimensions)
369409

370410
for i, (ele, table) in enumerate(zip(item, self.table)):
371-
self._slice_table(i, table, ele, new_components, whole_slice=item)
411+
self._slice_table(i, table, ele, new_components, whole_slice=new_item)
372412

413+
points = None if self.points is None else self.points[new_item]
373414
names = new_components["names"] or None
374415
physical_types = new_components["physical_types"] or None
375416

376-
ret_table = type(self)(*new_components["tables"], names=names, physical_types=physical_types)
417+
ret_table = type(self)(*new_components["tables"], points=points, names=names, physical_types=physical_types)
377418
ret_table._dropped_world_dimensions = new_components["dropped_world_dimensions"]
378419
return ret_table
379420

@@ -396,7 +437,7 @@ def model(self):
396437
"""
397438
Generate the Astropy Model for this LookupTable.
398439
"""
399-
return _model_from_quantity(self.table, True)
440+
return _model_from_quantity(self.table, self.points, True)
400441

401442
@property
402443
def ndim(self):
@@ -418,7 +459,7 @@ def shape(self):
418459
"""
419460
return tuple(len(t) for t in self.table)
420461

421-
def interpolate(self, *new_array_grids, **kwargs):
462+
def interpolate(self, *new_array_grids, new_points=None, **kwargs):
422463
"""
423464
Interpolate QuantityTableCoordinate to new array index grids.
424465
@@ -431,6 +472,10 @@ def interpolate(self, *new_array_grids, **kwargs):
431472
represent a single location in the pixel grid. Therefore, array grids
432473
must all have the same shape.
433474
475+
new_points: `~astropy.units.Quantity` in pixel units of `str`
476+
The new pixel grid points to which the nely interpolating values will correspond.
477+
Default=None implies they will correspond to pixel centers.
478+
434479
kwargs
435480
All remaining kwargs are passed to underlying interpolation function.
436481
@@ -455,8 +500,16 @@ def interpolate(self, *new_array_grids, **kwargs):
455500
new_tables = [
456501
np.interp(new_grid, old_grid, t.value, **kwargs) * t.unit
457502
for new_grid, old_grid, t in zip(new_array_grids, old_array_grids, self.table)]
503+
if new_points is None:
504+
new_points = list(range(len(new_array_grids))) * u.pix
505+
elif not isinstance(new_points, u.Quantity):
506+
raise TypeError("new_points must be an astropy Quantity.")
507+
elif not new_points.unit.is_equivalent(u.pix):
508+
raise u.UnitError("new_points must have pixel units.")
509+
elif new_points.shape == new_tables[0].shape):
510+
raise ValueError("new_points must be same shape as tables.")
458511
# Rebuild return interpolated coord.
459-
new_coord = type(self)(*new_tables, names=self.names, physical_types=self.physical_types)
512+
new_coord = type(self)(*new_tables, points=new_points, names=self.names, physical_types=self.physical_types)
460513
new_coord._dropped_world_dimensions = self._dropped_world_dimensions
461514
return new_coord
462515

@@ -577,7 +630,7 @@ def model(self):
577630
"""
578631
Generate the Astropy Model for this LookupTable.
579632
"""
580-
return _model_from_quantity(self._sliced_components, mesh=self.mesh)
633+
return _model_from_quantity(self._sliced_components, self.points, mesh=self.mesh)
581634

582635
@property
583636
def ndim(self):

0 commit comments

Comments
 (0)