diff --git a/pyresample/geometry.py b/pyresample/geometry.py index 07dddc79e..c026bfe80 100644 --- a/pyresample/geometry.py +++ b/pyresample/geometry.py @@ -27,6 +27,7 @@ from typing import Optional import numpy as np +import pyproj import yaml from pyproj import Geod, transform @@ -41,6 +42,7 @@ proj4_dict_to_str, proj4_radius_parameters, ) +from pyresample.utils.array import _convert_2D_array try: from xarray import DataArray @@ -713,43 +715,301 @@ def copy(self): @staticmethod def _do_transform(src, dst, lons, lats, alt): - """Run pyproj.transform and stack the results.""" - x, y, z = transform(src, dst, lons, lats, alt) + """Perform pyproj transformation and stack the results. + + If using pyproj >= 3.1, it employs thread-safe pyproj.transformer.Transformer. + If using pyproj < 3.1, it employs pyproj.transform. + + Docs: https://pyproj4.github.io/pyproj/stable/advanced_examples.html#multithreading + """ + if float(pyproj.__version__[0:3]) >= 3.1: + from pyproj import Transformer + transformer = Transformer.from_crs(src.crs, dst.crs) + x, y, z = transformer.transform(lons, lats, alt, radians=False) + else: + x, y, z = transform(src, dst, lons, lats, alt) return np.dstack((x, y, z)) def aggregate(self, **dims): - """Aggregate the current swath definition by averaging. - - For example, averaging over 2x2 windows: - `sd.aggregate(x=2, y=2)` + """Return an aggregated version of the area.""" + warnings.warn("'aggregate' is deprecated, use 'downsample' instead.", PendingDeprecationWarning) + return self.downsample(x=dims.get('x', 1), y=dims.get('y', 1)) + + def downsample(self, x=1, y=1, **kwargs): + """Downsample the SwathDefinition along x (columns) and y (lines) dimensions. + + Builds upon xarray.DataArray.coarsen averaging function. + To downsample of a factor of 2, call swath_def.downsample(x=2, y=2) + swath_def.downsample(x=1, y=1) simply returns the current swath_def. + By default, it raise a ValueError if the dimension size is not a multiple of the window size. + This can be changed by passing boundary="trim" or boundary="pad", but behaviour within pyresample is undefined. + See https://xarray.pydata.org/en/stable/generated/xarray.DataArray.coarsen.html for further details. """ import dask.array as da - import pyproj + import xarray as xr + + # Check input validity + x = int(x) + y = int(y) + if x < 1 or y < 1: + raise ValueError('SwathDefinition.downsample expects (integer) aggregation factors >=1 .') + + # Return SwathDefinition if nothing to downsample + if x == 1 and y == 1: + return self + # Define geodetic and geocentric projection geocent = pyproj.Proj(proj='geocent') latlong = pyproj.Proj(proj='latlong') + + # Get xr.DataArray with dask array + # - If input lats/lons are xr.DataArray, the specified dims ['x','y'] are ignored + src_lons, src_lons_format = _convert_2D_array(self.lons, to='DataArray_Dask', dims=['y', 'x']) + src_lats, src_lats_format = _convert_2D_array(self.lats, to='DataArray_Dask', dims=['y', 'x']) + + # Conversion to Geocentric Cartesian (x,y,z) CRS res = da.map_blocks(self._do_transform, latlong, geocent, - self.lons.data, self.lats.data, - da.zeros_like(self.lons.data), new_axis=[2], - chunks=(self.lons.chunks[0], self.lons.chunks[1], 3)) - res = DataArray(res, dims=['y', 'x', 'coord'], coords=self.lons.coords) - res = res.coarsen(**dims).mean() + src_lons.data, + src_lats.data, + da.zeros_like(src_lons), # altitude + new_axis=[2], + chunks=(src_lons.chunks[0], src_lons.chunks[1], 3)) + res = xr.DataArray(res, dims=['y', 'x', 'xyz'], coords=src_lons.coords) + + # Aggregating + res = res.coarsen(x=x, y=y, **kwargs).mean() + + # Back-conversion to geographic CRS lonlatalt = da.map_blocks(self._do_transform, geocent, latlong, - res[:, :, 0].data, res[:, :, 1].data, - res[:, :, 2].data, new_axis=[2], + res[:, :, 0].data, # x + res[:, :, 1].data, # y + res[:, :, 2].data, # z + new_axis=[2], chunks=res.data.chunks) - lons = DataArray(lonlatalt[:, :, 0], dims=self.lons.dims, - coords=res.coords, attrs=self.lons.attrs.copy()) - lats = DataArray(lonlatalt[:, :, 1], dims=self.lons.dims, - coords=res.coords, attrs=self.lons.attrs.copy()) - try: - resolution = lons.attrs['resolution'] * ((dims.get('x', 1) + dims.get('y', 1)) / 2) - lons.attrs['resolution'] = resolution - lats.attrs['resolution'] = resolution - except KeyError: - pass + + # Back-conversion array as input format + lons, _ = _convert_2D_array(lonlatalt[:, :, 0], to=src_lons_format, dims=src_lons.dims) + lats, _ = _convert_2D_array(lonlatalt[:, :, 1], to=src_lats_format, dims=src_lats.dims) + + # Add additional info if the source array is a DataArray + if isinstance(self.lats, xr.DataArray) and isinstance(self.lons, xr.DataArray): + lats = lats.assign_coords(res.coords) + lons = lons.assign_coords(res.coords) + lats.attrs = self.lats.attrs.copy() + lons.attrs = self.lons.attrs.copy() + try: + resolution = lons.attrs['resolution'] * ((x + y) / 2) + lons.attrs['resolution'] = resolution + lats.attrs['resolution'] = resolution + except KeyError: + pass + + # Return the downsampled swath definition + return SwathDefinition(lons, lats) + + def upsample(self, x=1, y=1): + """Upsample the SwathDefinition along x (columns) and y (lines) dimensions. + + To upsample of a factor of 2 (each pixel splitted in 2x2 pixels), + call swath_def.upsample(x=2, y=2). + swath_def.upsample(x=1, y=1) simply returns the current swath_def. + """ + # TODO: An alternative would be to use geotiepoints.geointerpolator.GeoInterpolator + # But I have some problem using it, see code snippet in a comment of the PR. + # TODO: Should we upsample also possible coords of lons/lats input xr.DataArray? + import dask.array as da + import numpy as np + import xarray as xr + from xarray.plot.utils import _infer_interval_breaks + + # https://github.com/pydata/xarray/blob/main/xarray/plot/utils.py#L784 + # Check input validity + x = int(x) + y = int(y) + if x < 1 or y < 1: + raise ValueError("SwathDefinition.upsample expects (integer) upscaling factors >=1 .") + # Return SwathDefinition if nothing to upsample + if x == 1 and y == 1: + return self + # --------------------------------------------------------------------. + # TODO: + # - Refactor for dask-compatibility + # - Should we make _infer_interval_breaks dask-compatible? + + def _get_corners_from_centroids(centroids): + breaks_xx = _infer_interval_breaks(centroids, axis=1) + corners = _infer_interval_breaks(breaks_xx, axis=0) + return corners + + # TODO: choose one of the two function below + # - What is the best way to apply _upsample_centroid along each x-y plane with dask + def _upsample_centroid(centroid, x=1, y=1): + corners = _get_corners_from_centroids(centroid) + # Retrieve corners of the the upsampled grid + new_corners = _linspace2D_between_values(corners, num_x=x - 1, num_y=y - 1) + # Get centroids from corners + new_centroids = (new_corners[:-1, :-1] + new_corners[1:, 1:]) / 2 + return new_centroids + + def upsample_centroids(centroid_x, centroid_y, centroid_z, x=1, y=1): + x_new_centroids = _upsample_centroid(centroid_x, x=x, y=y) + y_new_centroids = _upsample_centroid(centroid_y, x=x, y=y) + z_new_centroids = _upsample_centroid(centroid_z, x=x, y=y) + return x_new_centroids, y_new_centroids, z_new_centroids + + # --------------------------------------------------------------------. + # Define geodetic and geocentric projection + geocent = pyproj.Proj(proj='geocent') + latlong = pyproj.Proj(proj='latlong') + + # Get xr.DataArray with dask array + # - If input lats/lons are xr.DataArray, the specified dims ['x','y'] are ignored + src_lons, src_lons_format = _convert_2D_array(self.lons, to='DataArray_Dask', dims=['y', 'x']) + src_lats, src_lats_format = _convert_2D_array(self.lats, to='DataArray_Dask', dims=['y', 'x']) + + # Conversion to Geocentric Cartesian (x,y,z) CRS + res = da.map_blocks(self._do_transform, latlong, geocent, + src_lons.data, + src_lats.data, + da.zeros_like(src_lons), # altitude + new_axis=[2], + chunks=(src_lons.chunks[0], src_lons.chunks[1], 3)) + res = xr.DataArray(res, dims=['y', 'x', 'xyz']) + + # Retrieve new centroids + # TODO: make it dask compatible using _upsample_centroid_dask [HELP WANTED] + # res1 = da.apply_along_axis(_upsample_centroid_dask, + # 2, + # res.data, + # x, + # y) + # res1 = xr.DataArray(res1, dims=['y', 'x', 'coord'], coords=src_lons.coords) + res = np.stack(upsample_centroids(res[:, :, 0].data, + res[:, :, 1].data, + res[:, :, 2].data, x=x, y=y), axis=2) + new_centroids = xr.DataArray(da.from_array(res), dims=['y', 'x', 'xyz']) + + # Back-conversion to geographic CRS + lonlatalt = da.map_blocks(self._do_transform, geocent, latlong, + new_centroids[:, :, 0].data, # x + new_centroids[:, :, 1].data, # y + new_centroids[:, :, 2].data, # z + new_axis=[2], + chunks=new_centroids.data.chunks) + + # Back-conversion array as input format + lons, _ = _convert_2D_array(lonlatalt[:, :, 0], to=src_lons_format, dims=src_lons.dims) + lats, _ = _convert_2D_array(lonlatalt[:, :, 1], to=src_lats_format, dims=src_lats.dims) + + # Add additional info if the source array is a DataArray + if isinstance(self.lats, xr.DataArray) and isinstance(self.lons, xr.DataArray): + lats.attrs = self.lats.attrs.copy() + lons.attrs = self.lons.attrs.copy() + try: + resolution = lons.attrs['resolution'] / ((x + y) / 2) + lons.attrs['resolution'] = resolution + lats.attrs['resolution'] = resolution + except KeyError: + pass + + # Return the downsampled swath definition return SwathDefinition(lons, lats) + def extend(self, left=0, right=0, bottom=0, top=0): + """Extend the SwathDefinition of n pixels on specific boundary sides. + + By default, it does not extend on any side. + """ + import xarray as xr + + # Check input validity + left = int(left) + right = int(right) + bottom = int(bottom) + top = int(top) + if left < 0 or right < 0 or bottom < 0 or top < 0: + raise ValueError('SwathDefinition.extend expects positive numbers of pixels.') + + # Return SwathDefinition if nothing to extend + if left == 0 and right == 0 and bottom == 0 and top == 0: + return self + + # Get lats/lons numpy arrays + src_lats, src_lats_format = _convert_2D_array(self.lats, to='numpy', dims=['y', 'x']) + src_lons, src_lons_format = _convert_2D_array(self.lons, to='numpy', dims=['y', 'x']) + + dst_lats = src_lats + dst_lons = src_lons + + # Extend swath sides + if top > 0: + list_side0 = (dst_lons[1, :], dst_lats[1, :], dst_lons[0, :], dst_lats[0, :]) + extended_side0_lonlats = _get_extended_lonlats(*list_side0, npts=top) + dst_lats = np.concatenate((extended_side0_lonlats[1][::-1, :], dst_lats), axis=0) + dst_lons = np.concatenate((extended_side0_lonlats[0][::-1, :], dst_lons), axis=0) + + if bottom > 0: + list_side2 = (dst_lons[-2, :], dst_lats[-2, :], dst_lons[-1, :], dst_lats[-1, :]) + extended_side2_lonlats = _get_extended_lonlats(*list_side2, npts=bottom) + dst_lats = np.concatenate((dst_lats, extended_side2_lonlats[1]), axis=0) + dst_lons = np.concatenate((dst_lons, extended_side2_lonlats[0]), axis=0) + + if right > 0: + list_side1 = (dst_lons[:, -2], dst_lats[:, -2], dst_lons[:, -1], dst_lats[:, -1]) + extended_side1_lonlats = _get_extended_lonlats(*list_side1, npts=right, transpose=False) + dst_lats = np.concatenate((dst_lats, extended_side1_lonlats[1]), axis=1) + dst_lons = np.concatenate((dst_lons, extended_side1_lonlats[0]), axis=1) + + if left > 0: + list_side3 = (dst_lons[:, 1], dst_lats[:, 1], dst_lons[:, 0], dst_lats[:, 0]) + extended_side3_lonlats = _get_extended_lonlats(*list_side3, npts=left, transpose=False) + dst_lats = np.concatenate((extended_side3_lonlats[1][:, ::-1], dst_lats), axis=1) + dst_lons = np.concatenate((extended_side3_lonlats[0][:, ::-1], dst_lons), axis=1) + + # Back-conversion array as input format + lons, _ = _convert_2D_array(dst_lons, to=src_lons_format, dims=['y', 'x']) + lats, _ = _convert_2D_array(dst_lats, to=src_lats_format, dims=['y', 'x']) + + # Add additional info if the source array is a DataArray + if isinstance(self.lats, xr.DataArray) and isinstance(self.lons, xr.DataArray): + lats.attrs = self.lats.attrs.copy() + lons.attrs = self.lons.attrs.copy() + + # Return the extended SwathDefinition + return SwathDefinition(lons, lats) + + def shrink(self, left=0, right=0, bottom=0, top=0): + """Shrink the SwathDefinition of n pixels on specific boundary sides. + + By default, it does not shrink on any side. + """ + # Check input validity + left = int(left) + right = int(right) + bottom = int(bottom) + top = int(top) + if left < 0 or right < 0 or bottom < 0 or top < 0: + raise ValueError('SwathDefinition.shrink expects positive numbers of pixels.') + + # Return SwathDefinition if nothing to shrink + if left == 0 and right == 0 and bottom == 0 and top == 0: + return self + + # Ensure shrinked area is at least 2x2 + height = self.lats.shape[0] + width = self.lats.shape[1] + x_max_shrink = width - 2 + y_max_shrink = height - 2 + if (left + right) > x_max_shrink: + raise ValueError("SwathDefinition.shrink can drop maximum {} pixels " + "along the x direction.".format(x_max_shrink)) + if (top + bottom) > y_max_shrink: + raise ValueError("SwathDefinition.shrink can drop maximum {} pixels " + "along the y direction.".format(y_max_shrink)) + + # Return the shrinked SwathDefinition + return self[slice(top, height - bottom), slice(left, width - right)] + def __hash__(self): """Compute the hash of this object.""" if self.hash is None: @@ -892,6 +1152,97 @@ def compute_optimal_bb_area(self, proj_dict=None): return area.freeze((lons, lats), shape=(height, width)) +def _linspace2D_between_values(arr, num_x=0, num_y=0): + """Dask-friendly function linearly interpolating values between each 2D array values. + + This function does not perform extrapolation. + It expects a 2D array as input! + + Parameters + ---------- + arr : (np.ndarray, dask.array.Array) + Numpy or Dask Array to be linearly interpolated between values. + num_x : int, optional + The number of linearly spaced values to infer between array values (along x). + . The default is 0. + num_y : int, optional + The number of linearly spaced values to infer between array values (along y). + The default is 0. + + Returns + ------- + arr : (np.ndarray, dask.array.Array) + Numpy or Dask Array with in-between linearly interpolated values. + + Example + ------- + + Function call: _linspace2D_between_values(arr, num_x=1, num_y=1) + Input: + np.array([[5.0, 7.0], + [7.0, 9.0]]) + Output: + np.array([[5.0, 6.0, 7.0], + [6.0, 7.0, 8.0], + [7.0, 8.0, 9.0]]) + """ + import xarray as xr + + # Check input validity + if arr.ndim != 2: + raise ValueError("'_linspace2D_between_values' expects a 2D array.") + num_x = int(num_x) + num_y = int(num_y) + if num_x < 0 or num_y < 0: + raise ValueError("'x' and 'y' must be an integer equal or larger than 0.") + if num_x == 0 and num_y == 0: + return arr + # Define src and dst ties + shape = arr.shape + Nx_dst = (shape[1] - 1) * (num_x + 1) + 1 + Ny_dst = (shape[0] - 1) * (num_y + 1) + 1 + + src_ties_x = np.arange(Nx_dst, step=num_x + 1) + src_ties_y = np.arange(Ny_dst, step=num_y + 1) + dst_ties_x = np.arange(Nx_dst) + dst_ties_y = np.arange(Ny_dst) + # Interpolate + da = xr.DataArray( + data=arr, + dims=("y", "x"), + coords={"y": src_ties_y, "x": src_ties_x} + ) + da_interp = da.interp({"y": dst_ties_y, "x": dst_ties_x}, method="linear") + return da_interp.data + + +def _get_extended_lonlats(lon_start, lat_start, lon_end, lat_end, npts, + ellps="sphere", + transpose=True): + """Utils employed by SwathDefinition.extend. + + It extrapolate npts following the forward azimuth with an interdistance + equal to the distance between the starting point and the end point. + """ + geod = pyproj.Geod(ellps=ellps) + # geod = pyproj.Geod(ellps='WGS84') # sphere + az12_arr, _, dist_arr = geod.inv(lon_start, lat_start, lon_end, lat_end) + list_lat = [] + list_lon = [] + for lon, lat, az12, dist in zip(lon_end, lat_end, az12_arr, dist_arr): + points = geod.fwd_intermediate(lon, lat, az12, del_s=dist, npts=npts, + out_lons=None, out_lats=None, radians=False) + list_lat.append(points.lats) + list_lon.append(points.lons) + + new_lats = np.stack(list_lat) + new_lons = np.stack(list_lon) + if transpose: + new_lats = new_lats.T + new_lons = new_lons.T + return new_lons, new_lats + + class DynamicAreaDefinition(object): """An AreaDefintion containing just a subset of the needed parameters. @@ -1422,11 +1773,120 @@ def copy(self, **override_kwargs): return AreaDefinition(**kwargs) def aggregate(self, **dims): - """Return an aggregated version of the area.""" - width = int(self.width / dims.get('x', 1)) - height = int(self.height / dims.get('y', 1)) + """Return an aggregated version of the area. + + Aggregate allows to mix between downsample and upsample in different directions. + Example: area_def.aggregate(x=2, y=0.5) <-> area_def.downsample(x=2).upsample(y=2). + """ + x = dims.get('x', 1) + y = dims.get('y', 1) + if x <= 0 or y <= 0: + raise ValueError('AreaDefinition.aggregate x and y arguments must be > 0.') + if x == 1 and y == 1: + return self + width = int(self.width / x) + height = int(self.height / y) return self.copy(height=height, width=width) + def downsample(self, x=1, y=1): + """Return a downsampled version of the area.""" + # Check input validity + if x == 1 and y == 1: + return self + if x < 1 or y < 1: + raise ValueError('AreaDefinition.downsample x and y arguments must be >= 1.') + # Downsample + return self.aggregate(x=x, y=y) + + def upsample(self, x=1, y=1): + """Return an upsampled version of the area.""" + # Check input validity + if x == 1 and y == 1: + return self + if x < 1 or y < 1: + raise ValueError('AreaDefinition.upsample x and y arguments must be >= 1.') + # Upsample + return self.aggregate(x=1 / x, y=1 / y) + + def extend(self, left=0, right=0, bottom=0, top=0): + """Extend AreaDefinition by n pixels on specific boundary sides. + + By default, it does not extend on any side. + """ + if self.is_geostationary: + raise NotImplementedError("AreaDefinition.extend method is not implemented for GEO AreaDefinition.") + # Check input validity + left = int(left) + right = int(right) + bottom = int(bottom) + top = int(top) + if left < 0 or right < 0 or bottom < 0 or top < 0: + raise ValueError('AreaDefinition.extend expects positive numbers of pixels.') + + # Return AreaDefinition if nothing to extend + if left == 0 and right == 0 and bottom == 0 and top == 0: + return self + + # Retrieve pixel and area info + new_width = self.width + left + right + new_height = self.height + bottom + top + pixel_size_x = self.pixel_size_x + pixel_size_y = self.pixel_size_y + + # Extend area_extent (lower_left_x, lower_left_y, upper_right_x, upper_right_y) + area_extent = self._area_extent + new_area_extent = list(area_extent) + new_area_extent[0] = new_area_extent[0] - pixel_size_x * left + new_area_extent[1] = new_area_extent[1] - pixel_size_y * bottom + new_area_extent[2] = new_area_extent[2] + pixel_size_x * right + new_area_extent[3] = new_area_extent[3] + pixel_size_y * top + + # Define new AreaDefinition + projection = self.crs_wkt + area_def = AreaDefinition(self.area_id, self.description, self.proj_id, + projection=projection, + width=new_width, + height=new_height, + area_extent=new_area_extent, + rotation=self.rotation, + nprocs=self.nprocs, + dtype=self.dtype) + + return area_def + + def shrink(self, left=0, right=0, bottom=0, top=0): + """Shrink AreaDefinition by n pixels on specific boundary sides. + + By default, it does not shrink on any side. + """ + if self.is_geostationary: + raise NotImplementedError("AreaDefinition.shrink method is not implemented for GEO AreaDefinition.") + # Check input validity + left = int(left) + right = int(right) + bottom = int(bottom) + top = int(top) + if left < 0 or right < 0 or bottom < 0 or top < 0: + raise ValueError('AreaDefinition.shrink expects positive numbers of pixels.') + + # Return AreaDefinition if nothing to extend + if left == 0 and right == 0 and bottom == 0 and top == 0: + return self + + # Ensure shrinked area is at least 2x2 + width = self.width + height = self.height + x_max_shrink = width - 2 + y_max_shrink = height - 2 + if (left + right) > x_max_shrink: + raise ValueError("AreaDefinition.shrink can drop maximum {} pixels " + "along the x direction.".format(x_max_shrink)) + if (top + bottom) > y_max_shrink: + raise ValueError("AreaDefinition.shrink can drop maximum {} pixels " + "along the y direction.".format(y_max_shrink)) + + return self[slice(top, height - bottom), slice(left, width - right)] + @property def resolution(self): """Return area resolution in X and Y direction.""" diff --git a/pyresample/test/test_geometry.py b/pyresample/test/test_geometry.py index b3d0b9202..14b012abd 100644 --- a/pyresample/test/test_geometry.py +++ b/pyresample/test/test_geometry.py @@ -1872,26 +1872,301 @@ def test_compute_optimal_bb(self): assert_np_dict_allclose(res.proj_dict, proj_dict) self.assertEqual(res.shape, (6, 3)) - def test_aggregation(self): + def test_downsampling(self): """Test aggregation on SwathDefinitions.""" import dask.array as da import numpy as np import xarray as xr + + from pyresample.geometry import SwathDefinition window_size = 2 resolution = 3 lats = np.array([[0, 0, 0, 0], [1, 1, 1, 1.0]]) lons = np.array([[178.5, 179.5, -179.5, -178.5], [178.5, 179.5, -179.5, -178.5]]) - xlats = xr.DataArray(da.from_array(lats, chunks=2), dims=['y', 'x'], - attrs={'resolution': resolution}) - xlons = xr.DataArray(da.from_array(lons, chunks=2), dims=['y', 'x'], - attrs={'resolution': resolution}) + lats_dask = da.from_array(lats, chunks=2) + lons_dask = da.from_array(lons, chunks=2) + lats_xr = xr.DataArray(lats, dims=['y', 'x'], + attrs={'resolution': resolution}) + lons_xr = xr.DataArray(lons, dims=['y', 'x'], + attrs={'resolution': resolution}) + lats_xr_dask = xr.DataArray(lats_dask, dims=['y', 'x'], + attrs={'resolution': resolution}) + lons_xr_dask = xr.DataArray(lons_dask, dims=['y', 'x'], + attrs={'resolution': resolution}) + sd_np = SwathDefinition(lons, lats) + sd_xr = SwathDefinition(lons_xr, lats_xr) + sd_xr_dask = SwathDefinition(lons_xr_dask, lats_xr_dask) + + res_np = sd_np.downsample(y=window_size, x=window_size) + res_xr = sd_xr.downsample(y=window_size, x=window_size) + res_xr_dask = sd_xr_dask.downsample(y=window_size, x=window_size) + + assert isinstance(res_np.lats, np.ndarray) + assert isinstance(res_xr.lats, xr.DataArray) + assert isinstance(res_xr_dask.lats, xr.DataArray) + assert isinstance(res_xr.lats.data, np.ndarray) + assert isinstance(res_xr_dask.lats.data, da.Array) + + np.testing.assert_allclose(res_np.lons, [[179, -179]]) + np.testing.assert_allclose(res_np.lats, [[0.5, 0.5]], atol=2e-5) + np.testing.assert_allclose(res_xr.lons.data, res_np.lons) + np.testing.assert_allclose(res_xr.lats.data, res_np.lats) + np.testing.assert_allclose(res_xr_dask.lons.values, res_np.lons) + np.testing.assert_allclose(res_xr_dask.lats.values, res_np.lats) + + self.assertAlmostEqual(res_xr.lons.resolution, resolution * window_size) + self.assertAlmostEqual(res_xr.lats.resolution, resolution * window_size) + # Test skip aggregation + np.testing.assert_allclose(sd_np.downsample(y=1, x=1).lats, sd_np.lats) + np.testing.assert_allclose(sd_np.downsample(y=1, x=1).lons, sd_np.lons) + # Test invalid arguments + self.assertRaises(ValueError, sd_np.downsample, 0, 0) + self.assertRaises(ValueError, sd_np.downsample, -1, -1) + # Test works with DataArray also without attrs + lats_xr = xr.DataArray(lats, dims=['y', 'x']) + lons_xr = xr.DataArray(lons, dims=['y', 'x']) + sd_xr1 = SwathDefinition(lons_xr, lats_xr) + res_xr1 = sd_xr1.downsample(y=window_size, x=window_size) + np.testing.assert_allclose(res_xr1.lats.data, res_xr.lats.data) + + def test_upsampling(self): + """Test upsampling on SwathDefinitions.""" + import dask.array as da + import numpy as np + import xarray as xr + from pyresample.geometry import SwathDefinition - sd = SwathDefinition(xlons, xlats) - res = sd.aggregate(y=window_size, x=window_size) - np.testing.assert_allclose(res.lons, [[179, -179]]) - np.testing.assert_allclose(res.lats, [[0.5, 0.5]], atol=2e-5) - self.assertAlmostEqual(res.lons.resolution, window_size * resolution) - self.assertAlmostEqual(res.lats.resolution, window_size * resolution) + window_size = 2 + resolution = 4 + + lons = np.array([5.0, 9.0]) + lats = np.array([6.0, 4.0]) + lons, lats = np.meshgrid(lons, lats) + + lats_dask = da.from_array(lats, chunks=2) + lons_dask = da.from_array(lons, chunks=2) + lats_xr = xr.DataArray(lats, dims=['y', 'x'], + attrs={'resolution': resolution}) + lons_xr = xr.DataArray(lons, dims=['y', 'x'], + attrs={'resolution': resolution}) + lats_xr_dask = xr.DataArray(lats_dask, dims=['y', 'x'], + attrs={'resolution': resolution}) + lons_xr_dask = xr.DataArray(lons_dask, dims=['y', 'x'], + attrs={'resolution': resolution}) + sd_np = SwathDefinition(lons, lats) + sd_xr = SwathDefinition(lons_xr, lats_xr) + sd_xr_dask = SwathDefinition(lons_xr_dask, lats_xr_dask) + + res_np = sd_np.upsample(y=window_size, x=window_size) + res_xr = sd_xr.upsample(y=window_size, x=window_size) + res_xr_dask = sd_xr_dask.upsample(y=window_size, x=window_size) + + assert isinstance(res_np.lons, np.ndarray) + assert isinstance(res_xr.lats, xr.DataArray) + assert isinstance(res_xr_dask.lats, xr.DataArray) + assert isinstance(res_xr.lats.data, np.ndarray) + assert isinstance(res_xr_dask.lats.data, da.Array) + + np.testing.assert_allclose(res_np.lons[0, :], [4, 6, 8, 10], atol=1e-2) + np.testing.assert_allclose(res_np.lats[:, 1], [6.5, 5.5, 4.5, 3.5], atol=1e-2) + np.testing.assert_allclose(res_xr.lons.data, res_np.lons) + np.testing.assert_allclose(res_xr.lats.data, res_np.lats) + np.testing.assert_allclose(res_xr_dask.lons.values, res_np.lons) + np.testing.assert_allclose(res_xr_dask.lats.values, res_np.lats) + + self.assertAlmostEqual(res_xr.lons.resolution, resolution / window_size) + self.assertAlmostEqual(res_xr.lats.resolution, resolution / window_size) + + # Test skip upsampling + np.testing.assert_allclose(sd_np.upsample(y=1, x=1).lats, sd_np.lats) + np.testing.assert_allclose(sd_np.upsample(y=1, x=1).lons, sd_np.lons) + # Test invalid arguments + self.assertRaises(ValueError, sd_np.upsample, 0, 0) + self.assertRaises(ValueError, sd_np.upsample, -1, -1) + # Test works with DataArray also without attrs + lats_xr = xr.DataArray(lats, dims=['y', 'x']) + lons_xr = xr.DataArray(lons, dims=['y', 'x']) + sd_xr1 = SwathDefinition(lons_xr, lats_xr) + res_xr1 = sd_xr1.upsample(y=window_size, x=window_size) + np.testing.assert_allclose(res_xr1.lats.data, res_xr.lats.data) + + def test_extend(self): + """Test extend on SwathDefinitions.""" + import dask.array as da + import numpy as np + import xarray as xr + + from pyresample.geometry import SwathDefinition + top = 2 + bottom = 2 + left = 2 + right = 2 + resolution = 4 + + lons = np.arange(-179.5, -178.5, 0.5) + lats = np.arange(-89.5, -88.5, 0.5) + lons, lats = np.meshgrid(lons, lats) + + lats_dask = da.from_array(lats, chunks=2) + lons_dask = da.from_array(lons, chunks=2) + lats_xr = xr.DataArray(lats, dims=['y', 'x'], + attrs={'resolution': resolution}) + lons_xr = xr.DataArray(lons, dims=['y', 'x'], + attrs={'resolution': resolution}) + lats_xr_dask = xr.DataArray(lats_dask, dims=['y', 'x'], + attrs={'resolution': resolution}) + lons_xr_dask = xr.DataArray(lons_dask, dims=['y', 'x'], + attrs={'resolution': resolution}) + sd_np = SwathDefinition(lons, lats) + sd_xr = SwathDefinition(lons_xr, lats_xr) + sd_xr_dask = SwathDefinition(lons_xr_dask, lats_xr_dask) + + res_np = sd_np.extend(left, right, bottom, top) + res_xr = sd_xr.extend(left, right, bottom, top) + res_xr_dask = sd_xr_dask.extend(left, right, bottom, top) + + assert isinstance(res_np.lons, np.ndarray) + assert isinstance(res_xr.lats, xr.DataArray) + assert isinstance(res_xr_dask.lats, xr.DataArray) + assert isinstance(res_xr.lats.data, np.ndarray) + assert isinstance(res_xr_dask.lats.data, da.Array) + + np.testing.assert_allclose(res_np.lons[2, 0:3], [179.5, 180, -179.5], atol=1e-4) + np.testing.assert_allclose(res_np.lons[0, :], [-0.5, 0, 0.5, 1, 1.5, 2], atol=1e-4) + np.testing.assert_allclose(res_np.lats[:, 0], [-89.5, -90.0, -89.5, -89, -88.5, -88.0], atol=1e-3) + np.testing.assert_allclose(res_xr.lons.data, res_np.lons) + np.testing.assert_allclose(res_xr.lats.data, res_np.lats) + np.testing.assert_allclose(res_xr_dask.lons.values, res_np.lons) + np.testing.assert_allclose(res_xr_dask.lats.values, res_np.lats) + + self.assertAlmostEqual(res_xr.lons.resolution, resolution) + self.assertAlmostEqual(res_xr.lats.resolution, resolution) + + # Test skip extension + np.testing.assert_allclose(sd_np.extend().lats, sd_np.lats) + np.testing.assert_allclose(sd_np.extend().lons, sd_np.lons) + # Test invalid arguments + self.assertRaises(ValueError, sd_np.extend, -1, -1, 0, 0) + # Test works with DataArray also without attrs + lats_xr = xr.DataArray(lats, dims=['y', 'x']) + lons_xr = xr.DataArray(lons, dims=['y', 'x']) + sd_xr1 = SwathDefinition(lons_xr, lats_xr) + res_xr1 = sd_xr1.extend(left, right, bottom, top) + np.testing.assert_allclose(res_xr1.lats.data, res_xr.lats.data) + + def test_shrink(self): + """Test shrink on SwathDefinitions.""" + import dask.array as da + import numpy as np + import xarray as xr + + from pyresample.geometry import SwathDefinition + right = 1 + left = 1 + bottom = 0 + top = 0 + resolution = 4 + + lons = np.arange(-179.5, -177.5, 0.5) + lats = np.arange(-89.5, -88.0, 0.5) + lons, lats = np.meshgrid(lons, lats) + + lats_dask = da.from_array(lats, chunks=2) + lons_dask = da.from_array(lons, chunks=2) + lats_xr = xr.DataArray(lats, dims=['y', 'x'], + attrs={'resolution': resolution}) + lons_xr = xr.DataArray(lons, dims=['y', 'x'], + attrs={'resolution': resolution}) + lats_xr_dask = xr.DataArray(lats_dask, dims=['y', 'x'], + attrs={'resolution': resolution}) + lons_xr_dask = xr.DataArray(lons_dask, dims=['y', 'x'], + attrs={'resolution': resolution}) + sd_np = SwathDefinition(lons, lats) + sd_xr = SwathDefinition(lons_xr, lats_xr) + sd_xr_dask = SwathDefinition(lons_xr_dask, lats_xr_dask) + + res_np = sd_np.shrink(left, right, bottom, top) + res_xr = sd_xr.shrink(left, right, bottom, top) + res_xr_dask = sd_xr_dask.shrink(left, right, bottom, top) + + assert isinstance(res_np.lons, np.ndarray) + assert isinstance(res_xr.lats, xr.DataArray) + assert isinstance(res_xr_dask.lats, xr.DataArray) + assert isinstance(res_xr.lats.data, np.ndarray) + assert isinstance(res_xr_dask.lats.data, da.Array) + + np.testing.assert_allclose(res_np.lons[:, 0], lons[:, left]) + np.testing.assert_allclose(res_np.lons[:, -1], lons[:, -right - 1]) + np.testing.assert_allclose(res_xr.lons.data, res_np.lons) + np.testing.assert_allclose(res_xr.lats.data, res_np.lats) + np.testing.assert_allclose(res_xr_dask.lons.values, res_np.lons) + np.testing.assert_allclose(res_xr_dask.lats.values, res_np.lats) + + self.assertAlmostEqual(res_xr.lons.resolution, resolution) + self.assertAlmostEqual(res_xr.lats.resolution, resolution) + + # Test skip reduction + np.testing.assert_allclose(sd_np.shrink().lats, sd_np.lats) + np.testing.assert_allclose(sd_np.shrink().lons, sd_np.lons) + # Test invalid arguments + self.assertRaises(ValueError, sd_np.shrink, -1, -1, 0, 0) + # Test works with DataArray also without attrs + lats_xr = xr.DataArray(lats, dims=['y', 'x']) + lons_xr = xr.DataArray(lons, dims=['y', 'x']) + sd_xr1 = SwathDefinition(lons_xr, lats_xr) + res_xr1 = sd_xr1.shrink(left, right, bottom, top) + np.testing.assert_allclose(res_xr1.lats.data, res_xr.lats.data) + # Test it raise Error if x or y are too large and not ensure output to be 2x2 at least + self.assertRaises(ValueError, sd_np.shrink, lons.shape[1] / 2, lons.shape[1] / 2, 0, 0) + self.assertRaises(ValueError, sd_np.shrink, 0, 0, lons.shape[0] / 2, lons.shape[0] / 2) + + def test_linspace2D_between_values(self): + """Test linspace2D_between_values.""" + import dask.array as da + import numpy as np + + from pyresample.geometry import _linspace2D_between_values + arr_np = np.array([[5.0, 7.0], + [7.0, 9.0]]) + arr_dask = da.from_array(arr_np) + + res_np = _linspace2D_between_values(arr_np, num_x=1, num_y=3) + res_dask = _linspace2D_between_values(arr_dask, num_x=1, num_y=3) + + output_expected = np.array([[5., 6., 7.], + [5.5, 6.5, 7.5], + [6., 7., 8.], + [6.5, 7.5, 8.5], + [7., 8., 9.]]) + np.testing.assert_allclose(res_np, output_expected) + np.testing.assert_allclose(res_dask, output_expected) + assert isinstance(res_np, np.ndarray) + assert isinstance(res_dask, da.Array) + + # Test for no interpolation inbetween values + res = _linspace2D_between_values(arr_np, num_x=0, num_y=0) + np.testing.assert_allclose(res, arr_np) + + # Test for valid inputs + self.assertRaises(ValueError, _linspace2D_between_values, arr_np, -1, 0) + self.assertRaises(ValueError, _linspace2D_between_values, arr_np, 0, -1) + self.assertRaises(ValueError, _linspace2D_between_values, arr_np[0, :], 0, 0) + + def test_get_extended_lonlats(self): + import numpy as np + + from pyresample.geometry import _get_extended_lonlats + lon_start = np.array([10, 20]) + lon_end = np.array([20, 30]) + lat_start = np.array([0, 1]) + lat_end = np.array([0, 1]) + + npts = 2 + ext_lons, ext_lats = _get_extended_lonlats(lon_start, lat_start, + lon_end, lat_end, + npts, transpose=True) + np.testing.assert_allclose(ext_lons[0, :], [30.0, 40.0]) + np.testing.assert_allclose(ext_lons[1, :], [40.0, 50.0], atol=1e4) def test_striding(self): """Test striding.""" @@ -2562,6 +2837,265 @@ def test_aggregate(self): self.assertEqual(res.shape[0], area.shape[0] / 2) self.assertEqual(res.shape[1], area.shape[1] / 4) + # Test skip aggregate + res = area.aggregate(x=1, y=1) + np.testing.assert_allclose(res.shape, area.shape) + assert res == area + + # Test raise error + self.assertRaises(ValueError, area.aggregate, x=0, y=0) + self.assertRaises(ValueError, area.aggregate, x=-1, y=-1) + + def test_downsample(self): + """Test downsampling of AreaDefinitions.""" + area = geometry.AreaDefinition('areaD', 'Europe (3km, HRV, VTC)', 'areaD', + {'a': '6378144.0', + 'b': '6356759.0', + 'lat_0': '50.00', + 'lat_ts': '50.00', + 'lon_0': '8.00', + 'proj': 'stere'}, + 800, + 800, + [-1370912.72, + -909968.64000000001, + 1029087.28, + 1490031.3600000001]) + res = area.downsample(x=4, y=2) + self.assertDictEqual(res.proj_dict, area.proj_dict) + np.testing.assert_allclose(res.area_extent, area.area_extent) + self.assertEqual(res.shape[0], area.shape[0] / 2) + self.assertEqual(res.shape[1], area.shape[1] / 4) + + # Test skip downsampling + res = area.downsample(x=1, y=1) + np.testing.assert_allclose(res.shape, area.shape) + assert res == area + + # Test invalid arguments + self.assertRaises(ValueError, area.downsample, 0, 0) + self.assertRaises(ValueError, area.downsample, -1, -1) + self.assertRaises(ValueError, area.downsample, 0.5, 1) + self.assertRaises(ValueError, area.downsample, 1, 0.5) + + def test_upsample(self): + """Test upsampling of AreaDefinitions.""" + area = geometry.AreaDefinition('areaD', 'Europe (3km, HRV, VTC)', 'areaD', + {'a': '6378144.0', + 'b': '6356759.0', + 'lat_0': '50.00', + 'lat_ts': '50.00', + 'lon_0': '8.00', + 'proj': 'stere'}, + 800, + 800, + [-1370912.72, + -909968.64000000001, + 1029087.28, + 1490031.3600000001]) + res = area.upsample(x=4, y=2) + self.assertDictEqual(res.proj_dict, area.proj_dict) + np.testing.assert_allclose(res.area_extent, area.area_extent) + self.assertEqual(res.shape[0], area.shape[0] * 2) + self.assertEqual(res.shape[1], area.shape[1] * 4) + + # Test skip upsampling + res = area.upsample(x=1, y=1) + np.testing.assert_allclose(res.shape, area.shape) + assert res == area + + # Test invalid arguments + self.assertRaises(ValueError, area.upsample, 0, 0) + self.assertRaises(ValueError, area.upsample, 0, 1) + self.assertRaises(ValueError, area.upsample, 0.5, 1) + self.assertRaises(ValueError, area.upsample, 1, 0.5) + + def test_shrink(self): + """Test shrinkage of AreaDefinitions.""" + from pyresample import geometry + area = geometry.AreaDefinition('areaD', 'Europe (3km, HRV, VTC)', 'areaD', + {'a': '6378144.0', + 'b': '6356759.0', + 'lat_0': '50.00', + 'lat_ts': '50.00', + 'lon_0': '8.00', + 'proj': 'stere'}, + 800, + 800, + [-1370912.72, + -909968.64000000001, + 1029087.28, + 1490031.3600000001]) + left = 8 + right = 8 + bottom = 2 + top = 2 + res = area.shrink(left, right, bottom, top) + self.assertDictEqual(res.proj_dict, area.proj_dict) + self.assertEqual(res.shape[0], area.shape[0] - bottom - top) + self.assertEqual(res.shape[1], area.shape[1] - left - right) + + # Test skip reduction + res = area.shrink() + np.testing.assert_allclose(res.area_extent, area.area_extent) + assert res == area + + # Test invalid arguments + self.assertRaises(ValueError, area.shrink, -1, -1, -1, -1) + self.assertRaises(ValueError, area.shrink, area.shape[0] / 2, area.shape[0] / 2, 0, 0) + self.assertRaises(ValueError, area.shrink, 0, 0, area.shape[1] / 2, area.shape[1] / 2) + + # Test raise NotImplementedError for GEO + area_geo = geometry.AreaDefinition(area_id='seviri', + description='SEVIRI HRIT like (flipped, south up)', + proj_id='seviri', + projection={'proj': 'geos', + 'lon_0': 0.0, + 'a': 6378169.00, + 'b': 6356583.80, + 'h': 35785831.00, + 'units': 'm'}, + width=123, height=123, + area_extent=[5500000, 5500000, -5500000, -5500000]) + self.assertRaises(NotImplementedError, area_geo.shrink, 1, 1, 0, 0) + + def test_extend(self): + """Test extension of AreaDefinitions.""" + from pyresample import geometry + area = geometry.AreaDefinition('areaD', 'Europe (3km, HRV, VTC)', 'areaD', + {'a': '6378144.0', + 'b': '6356759.0', + 'lat_0': '50.00', + 'lat_ts': '50.00', + 'lon_0': '8.00', + 'proj': 'stere'}, + 800, + 800, + [-1370912.72, + -909968.64000000001, + 1029087.28, + 1490031.3600000001]) + left = 8 + right = 8 + bottom = 2 + top = 2 + res = area.extend(left, right, bottom, top) + self.assertDictEqual(res.proj_dict, area.proj_dict) + self.assertEqual(res.shape[0], area.shape[0] + top + bottom) + self.assertEqual(res.shape[1], area.shape[1] + left + right) + + # Test skip extension + res = area.extend() + np.testing.assert_allclose(res.area_extent, area.area_extent) + assert res == area + + # Test invalid arguments + self.assertRaises(ValueError, area.extend, -1, -1, -1, -1) + + # Test raise NotImplementedError for GEO + area_geo = geometry.AreaDefinition(area_id='seviri', + description='SEVIRI HRIT like (flipped, south up)', + proj_id='seviri', + projection={'proj': 'geos', + 'lon_0': 0.0, + 'a': 6378169.00, + 'b': 6356583.80, + 'h': 35785831.00, + 'units': 'm'}, + width=123, height=123, + area_extent=[5500000, 5500000, -5500000, -5500000]) + self.assertRaises(NotImplementedError, area_geo.extend, 1, 1, 0, 0) + + +class TestSwathDefinitionDownsampling(unittest.TestCase): + """Test Downsampling SwathDefinition.""" + + # TODO: Martin + @classmethod + def setUpClass(cls): + """Do some setup for the test class.""" + import dask.array as da + import numpy as np + import xarray as xr + + from pyresample.geometry import SwathDefinition + resolution = 3 + lats = np.array([[0, 0, 0, 0], [1, 1, 1, 1.0]]) + lons = np.array([[178.5, 179.5, -179.5, -178.5], [178.5, 179.5, -179.5, -178.5]]) + lats_dask = da.from_array(lats, chunks=2) + lons_dask = da.from_array(lons, chunks=2) + lats_xr = xr.DataArray(lats, dims=['y', 'x'], + attrs={'resolution': resolution}) + lons_xr = xr.DataArray(lons, dims=['y', 'x'], + attrs={'resolution': resolution}) + lats_xr_dask = xr.DataArray(lats_dask, dims=['y', 'x'], + attrs={'resolution': resolution}) + lons_xr_dask = xr.DataArray(lons_dask, dims=['y', 'x'], + attrs={'resolution': resolution}) + sd_np = SwathDefinition(lons, lats) + sd_xr = SwathDefinition(lons_xr, lats_xr) + sd_xr_dask = SwathDefinition(lons_xr_dask, lats_xr_dask) + + cls.lons = lons + cls.lats = lats + cls.lons_dask = lons_dask + cls.lats_dask = lats_dask + cls.lons_xr = lons_xr + cls.lats_xr = lons_xr + cls.lons_xr_dask = lons_xr_dask + cls.lats_xr_dask = lats_xr_dask + cls.sd_np = sd_np + cls.sd_xr = sd_xr + cls.sd_xr_dask = sd_xr_dask + cls.resolution = resolution + + def test_downsampling_keeps_arrays(self): + """Test array format is kept.""" + import dask.array as da + import numpy as np + import xarray as xr + + window_size = 2 + res_np = self.sd_np.downsample(y=window_size, x=window_size) + res_xr = self.sd_xr.downsample(y=window_size, x=window_size) + res_xr_dask = self.sd_xr_dask.downsample(y=window_size, x=window_size) + + assert isinstance(res_np.lats, np.ndarray) + assert isinstance(res_xr.lats, xr.DataArray) and isinstance(res_xr.lats.data, np.ndarray) + assert isinstance(res_xr_dask.lats, xr.DataArray) and isinstance(res_xr_dask.lats.data, da.Array) + + def test_downsampling_results_consistency(self): + """Test array format is kept.""" + import numpy as np + window_size = 2 + res_np = self.sd_np.downsample(y=window_size, x=window_size) + res_xr = self.sd_xr.downsample(y=window_size, x=window_size) + res_xr_dask = self.sd_xr_dask.downsample(y=window_size, x=window_size) + + np.testing.assert_allclose(res_np.lons, [[179, -179]]) + np.testing.assert_allclose(res_np.lats, [[0.5, 0.5]], atol=2e-5) + np.testing.assert_allclose(res_xr.lons.data, res_np.lons) + np.testing.assert_allclose(res_xr.lats.data, res_np.lats) + np.testing.assert_allclose(res_xr_dask.lons.values, res_np.lons) + np.testing.assert_allclose(res_xr_dask.lats.values, res_np.lats) + + def test_downsampling_modify_resolution_attrs(self): + window_size = 2 + res_xr = self.sd_xr.downsample(y=window_size, x=window_size) + self.assertAlmostEqual(res_xr.lons.resolution, self.resolution * window_size) + self.assertAlmostEqual(res_xr.lats.resolution, self.resolution * window_size) + + def test_downsampling_default_skip(self): + import numpy as np + res_np = self.sd_np.downsample(y=1, x=1) + np.testing.assert_allclose(res_np.lats, self.sd_np.lats) + np.testing.assert_allclose(res_np.lons, self.sd_np.lons) + + def test_downsampling_valid_args(self): + # Test invalid arguments + self.assertRaises(ValueError, self.sd_np.downsample, 0, 0) + self.assertRaises(ValueError, self.sd_np.downsample, -1, -1) + def test_enclose_areas(): """Test enclosing areas.""" diff --git a/pyresample/test/test_utils_array.py b/pyresample/test/test_utils_array.py new file mode 100644 index 000000000..a43506f43 --- /dev/null +++ b/pyresample/test/test_utils_array.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# pyresample, Resampling of remote sensing image data in python +# +# Copyright (C) 2010-2022 Pyresample developers +# +# This program is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 3 of the License, or (at your option) any +# later version. +# +# This program is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . +"""Test utils array .""" +import unittest + +import dask.array as da +import numpy as np +import xarray as xr + +VALID_FORMAT = ['Numpy', 'Dask', 'DataArray_Numpy', 'DataArray_Dask'] + +lons_np = np.arange(-179.5, -177.5, 0.5) +lats_np = np.arange(-89.5, -88.0, 0.5) +lons_np, lats_np = np.meshgrid(lons_np, lats_np) +lats_dask = da.from_array(lats_np, chunks=2) +lats_xr = xr.DataArray(lats_np, dims=['y', 'x']) +lats_xr_dask = xr.DataArray(lats_dask, dims=['y', 'x']) +dict_format = {'Numpy': lats_np, + 'Dask': lats_dask, + 'DataArray_Numpy': lats_xr, + 'DataArray_Dask': lats_xr_dask + } + + +class TestArrayConversion(unittest.TestCase): + """Unit testing the array conversion.""" + + def test_numpy_conversion(self): + from pyresample.utils.array import _numpy_conversion + in_format = "Numpy" + dims = None + for out_format in VALID_FORMAT: + out_arr = _numpy_conversion(dict_format[in_format], to=out_format, dims=dims) + assert isinstance(out_arr, type(dict_format[out_format])) + if out_format.lower() == "dataarray_numpy": + assert isinstance(out_arr.data, type(dict_format['Numpy'])) + if out_format.lower() == "dataarray_dask": + assert isinstance(out_arr.data, type(dict_format['Dask'])) + # Test raise errors + self.assertRaises(TypeError, _numpy_conversion, dict_format[in_format], ['unvalid_type']) + self.assertRaises(TypeError, _numpy_conversion, ['unvalid_type'], in_format) + self.assertRaises(ValueError, _numpy_conversion, dict_format[in_format], 'unvalid_format') + + def test_dask_conversion(self): + from pyresample.utils.array import _dask_conversion + in_format = "Dask" + dims = None + for out_format in VALID_FORMAT: + out_arr = _dask_conversion(dict_format[in_format], to=out_format, dims=dims) + assert isinstance(out_arr, type(dict_format[out_format])) + if out_format.lower() == "dataarray_numpy": + assert isinstance(out_arr.data, type(dict_format['Numpy'])) + if out_format.lower() == "dataarray_dask": + assert isinstance(out_arr.data, type(dict_format['Dask'])) + # Test raise errors + self.assertRaises(TypeError, _dask_conversion, dict_format[in_format], ['unvalid_type']) + self.assertRaises(TypeError, _dask_conversion, ['unvalid_type'], in_format) + self.assertRaises(ValueError, _dask_conversion, dict_format[in_format], 'unvalid_format') + + def test_xr_numpy_conversion(self): + from pyresample.utils.array import _xr_numpy_conversion + in_format = "DataArray_Numpy" + for out_format in VALID_FORMAT: + out_arr = _xr_numpy_conversion(dict_format[in_format], to=out_format) + assert isinstance(out_arr, type(dict_format[out_format])) + if out_format.lower() == "dataarray_numpy": + assert isinstance(out_arr.data, type(dict_format['Numpy'])) + if out_format.lower() == "dataarray_dask": + assert isinstance(out_arr.data, type(dict_format['Dask'])) + # Test raise errors + self.assertRaises(TypeError, _xr_numpy_conversion, dict_format[in_format], ['unvalid_type']) + self.assertRaises(TypeError, _xr_numpy_conversion, ['unvalid_type'], in_format) + self.assertRaises(ValueError, _xr_numpy_conversion, dict_format[in_format], 'unvalid_format') + + def test_xr_dask_conversion(self): + from pyresample.utils.array import _xr_dask_conversion + in_format = "DataArray_Dask" + for out_format in VALID_FORMAT: + out_arr = _xr_dask_conversion(dict_format[in_format], to=out_format) + assert isinstance(out_arr, type(dict_format[out_format])) + if out_format.lower() == "dataarray_numpy": + assert isinstance(out_arr.data, type(dict_format['Numpy'])) + if out_format.lower() == "dataarray_dask": + assert isinstance(out_arr.data, type(dict_format['Dask'])) + # Test raise errors + self.assertRaises(TypeError, _xr_dask_conversion, dict_format[in_format], ['unvalid_type']) + self.assertRaises(TypeError, _xr_dask_conversion, ['unvalid_type'], in_format) + self.assertRaises(ValueError, _xr_dask_conversion, dict_format[in_format], 'unvalid_format') + + def test_convert_2D_array(self): + """Test conversion of 2D arrays between various formats.""" + from pyresample.utils.array import _convert_2D_array + dims = None + for in_format in VALID_FORMAT: + for out_format in VALID_FORMAT: + out_arr, src_format = _convert_2D_array(dict_format[in_format], to=out_format, dims=dims) + assert isinstance(out_arr, type(dict_format[out_format])) + assert src_format.lower() == in_format.lower() + if out_format.lower() == "dataarray_numpy": + assert isinstance(out_arr.data, type(dict_format['Numpy'])) + if out_format.lower() == "dataarray_dask": + assert isinstance(out_arr.data, type(dict_format['Dask'])) + # Test raise errors + self.assertRaises(TypeError, _convert_2D_array, dict_format['Numpy'], ['unvalid_type']) + self.assertRaises(TypeError, _convert_2D_array, [dict_format['Numpy']], 'numpy') + self.assertRaises(ValueError, _convert_2D_array, dict_format['Numpy'], 'unvalid_format') diff --git a/pyresample/utils/array.py b/pyresample/utils/array.py new file mode 100644 index 000000000..0d91986f5 --- /dev/null +++ b/pyresample/utils/array.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019-2021 Pyresample developers +# +# This program is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 3 of the License, or (at your option) any +# later version. +# +# This program is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with this program. If not, see . +"""Utilities for converting array.""" + +import dask.array as da +import numpy as np +import xarray as xr + +VALID_FORMAT = ['Numpy', 'Dask', 'DataArray_Numpy', 'DataArray_Dask'] + + +def _numpy_conversion(arr, to, dims=None): + if not isinstance(to, str): + raise TypeError("'to' must be a string indicating the conversion array format.") + if not isinstance(arr, np.ndarray): + raise TypeError("_numpy_conversion expects a np.ndarray as input.") + if not np.isin(to.lower(), np.char.lower(VALID_FORMAT)): + raise ValueError("Valid _numpy_conversion array formats are {}".format(VALID_FORMAT)) + if to.lower() == 'numpy': + dst_arr = arr + elif to.lower() == 'dask': + dst_arr = da.from_array(arr) + elif to.lower() == 'dataarray_numpy': + dst_arr = xr.DataArray(arr, dims=dims) + else: # to.lower() == 'dataarray_dask': + dst_arr = xr.DataArray(da.from_array(arr), dims=dims) + return dst_arr + + +def _dask_conversion(arr, to, dims=None): + if not isinstance(to, str): + raise TypeError("'to' must be a string indicating the conversion array format.") + if not isinstance(arr, da.Array): + raise TypeError("_dask_conversion expects a dask.Array as input.") + if not np.isin(to.lower(), np.char.lower(VALID_FORMAT)): + raise ValueError("Valid _dask_conversion array formats are {}".format(VALID_FORMAT)) + if to.lower() == 'numpy': + dst_arr = arr.compute() + elif to.lower() == 'dask': + dst_arr = arr + elif to.lower() == 'dataarray_numpy': + dst_arr = xr.DataArray(arr.compute(), dims=dims) + else: # to.lower() == 'dataarray_dask': + dst_arr = xr.DataArray(arr, dims=dims) + return dst_arr + + +def _xr_numpy_conversion(arr, to): + if not isinstance(to, str): + raise TypeError("'to' must be a string indicating the conversion array format.") + if not isinstance(arr, xr.DataArray): + raise TypeError("_xr_numpy_conversion expects a xr.DataArray with numpy array as input.") + if not isinstance(arr.data, np.ndarray): + raise TypeError("_xr_numpy_conversion expects a xr.DataArray with numpy array as input.") + if not np.isin(to.lower(), np.char.lower(VALID_FORMAT)): + raise ValueError("Valid _xr_numpy_conversion array formats are {}".format(VALID_FORMAT)) + if to.lower() == 'numpy': + dst_arr = arr.data + elif to.lower() == 'dask': + dst_arr = da.from_array(arr.data) + elif to.lower() == 'dataarray_numpy': + dst_arr = arr + else: # to.lower() == 'dataarray_dask': + dst_arr = xr.DataArray(da.from_array(arr.data), dims=arr.dims) + return dst_arr + + +def _xr_dask_conversion(arr, to): + if not isinstance(to, str): + raise TypeError("'to' must be a string indicating the conversion array format.") + if not isinstance(arr, xr.DataArray): + raise TypeError("_xr_dask_conversion expects a xr.DataArray with dask.Array as input.") + if not isinstance(arr.data, da.Array): + raise TypeError("_xr_dask_conversion expects a xr.DataArray with dask.Array as input.") + if not np.isin(to.lower(), np.char.lower(VALID_FORMAT)): + raise ValueError("Valid _xr_dask_conversion array formats are {}".format(VALID_FORMAT)) + if to.lower() == 'numpy': + dst_arr = arr.data.compute() + elif to.lower() == 'dask': + dst_arr = arr.data + elif to.lower() == 'dataarray_numpy': + dst_arr = arr.compute() + else: # to.lower() == 'dataarray_dask': + dst_arr = arr + return dst_arr + + +def _convert_2D_array(arr, to, dims=None): + """ + Convert a 2D array to a specific format. + + Useful to return swath lons, lats in the same original format after processing. + + Parameters + ---------- + arr : (np.ndarray, da.Array, xr.DataArray) + The 2D array to be converted to another array format. + to : TYPE + The desired array output format. + Accepted formats are: ['Numpy','Dask', 'DataArray_Numpy','DataArray_Dask'] + dims : tuple, optional + Optional argument for the specification of xr.DataArray dimension names + if the input array is Numpy or Dask. + Does not have any impact if the input is already a xr.DataArray + Provide a tuple with (y_dimname, x_dimname). + The default is None --> (dim_0, dim_1) + + + Returns + ------- + dst_arr : (np.ndarray, da.Array, xr.DataArray) + The converted 2D array. + src_format: str + The source format of the 2D array. + + """ + # Checks + if not isinstance(to, str): + raise TypeError("'to' must be a string indicating the conversion array format.") + if not np.isin(to.lower(), np.char.lower(VALID_FORMAT)): + raise ValueError("Valid conversion array formats are {}".format(VALID_FORMAT)) + if not isinstance(arr, (np.ndarray, da.Array, xr.DataArray)): + raise TypeError("The provided array must be either a np.ndarray, a dask.Array or a xr.DataArray.") + # Numpy + if isinstance(arr, np.ndarray): + return _numpy_conversion(arr, to=to, dims=dims), "numpy" + # Dask + elif isinstance(arr, da.Array): + return _dask_conversion(arr, to=to, dims=dims), 'dask' + # DataArray_Numpy + elif isinstance(arr, xr.DataArray) and isinstance(arr.data, np.ndarray): + return _xr_numpy_conversion(arr, to=to), 'DataArray_Numpy' + # DataArray_Dask + else: # isinstance(arr, xr.DataArray) and isinstance(arr.data, da.Array): + return _xr_dask_conversion(arr, to=to), 'DataArray_Dask'