+# standard library
 from math import sqrt
 from typing import Optional, Callable, Union, Dict, List

+# 3rd-party
+import dask.array as da
+import dask.dataframe as dd
+from dask import delayed
+
 import numpy as np
 import pandas as pd
 import xarray as xr
 from xarray import DataArray

-import dask.array as da
-import dask.dataframe as dd
-from dask import delayed
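+# cupy is optional: if the import fails, define a stub so that the name
+# cupy and the attribute cupy.ndarray still exist at module level.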
+try:
+    import cupy
+except ImportError:
+    class cupy(object):
+        ndarray = False

-from xrspatial.utils import ngjit, validate_arrays
+# local modules
+from xrspatial.utils import ngjit
+from xrspatial.utils import validate_arrays
 from xrspatial.utils import ArrayTypeFunctionMapping
 from xrspatial.utils import not_implemented_func

-
 TOTAL_COUNT = '_total_count'


 def _stats_count(data):
     if isinstance(data, np.ndarray):
         # numpy case
         stats_count = np.ma.count(data)
+    elif isinstance(data, cupy.ndarray):
+        # cupy case
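+        # (cupy has no masked-array support, so every element is counted)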
+        stats_count = np.prod(data.shape)
     else:
         # dask case
         stats_count = data.size - da.ma.getmaskarray(data).sum()
@@ -56,9 +68,9 @@ def _stats_count(data):
     sum_squares=lambda block_sum_squares: np.nansum(block_sum_squares, axis=0),
     squared_sum=lambda block_sums: np.nansum(block_sums, axis=0)**2,
 )
-_dask_mean = lambda sums, counts: sums / counts  # noqa
-_dask_std = lambda sum_squares, squared_sum, n: np.sqrt((sum_squares - squared_sum / n) / n)  # noqa
-_dask_var = lambda sum_squares, squared_sum, n: (sum_squares - squared_sum / n) / n  # noqa
+def _dask_mean(sums, counts): return sums / counts  # noqa
+def _dask_std(sum_squares, squared_sum, n): return np.sqrt((sum_squares - squared_sum / n) / n)  # noqa
+def _dask_var(sum_squares, squared_sum, n): return (sum_squares - squared_sum / n) / n  # noqa


 @ngjit
@@ -282,6 +294,81 @@ def _stats_numpy(
     return stats_df


+def _stats_cupy(
+    orig_zones: xr.DataArray,
+    orig_values: xr.DataArray,
+    zone_ids: List[Union[int, float]],
+    stats_funcs: Dict,
+    nodata_values: Union[int, float],
+) -> pd.DataFrame:
+
+    # TODO: add support for 3D input
+    if len(orig_values.shape) > 2:
+        raise TypeError('3D inputs not supported for cupy backend')
+
+    zones = cupy.ravel(orig_zones)
+    values = cupy.ravel(orig_values)
+
+    sorted_indices = cupy.argsort(zones)
+
+    sorted_zones = zones[sorted_indices]
+    values_by_zone = values[sorted_indices]
+
+    # filter out values that are non-finite or equal to nodata_values
+    if nodata_values:
+        filter_values = cupy.isfinite(values_by_zone) & (
+            values_by_zone != nodata_values)
+    else:
+        filter_values = cupy.isfinite(values_by_zone)
+    values_by_zone = values_by_zone[filter_values]
+    sorted_zones = sorted_zones[filter_values]
+
+    # find the unique zones and the zone break indices
+    unique_zones, unique_index = cupy.unique(sorted_zones, return_index=True)
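+    # since sorted_zones is sorted, unique_index[i] is the offset where the
+    # i-th zone starts; consecutive offsets delimit each zone's slice of values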
+
+    # transfer results back to the host
+    unique_index = unique_index.get()
+    if zone_ids is None:
+        unique_zones = unique_zones.get()
+    else:
+        unique_zones = zone_ids
+    # unique_zones = list(map(_to_int, unique_zones))
+    unique_zones = np.asarray(unique_zones)
+
+    # stats columns
+    stats_dict = {'zone': []}
+    for stats in stats_funcs:
+        stats_dict[stats] = []
+
+    for i in range(len(unique_zones)):
+        zone_id = unique_zones[i]
+        # skip non-finite zone ids
+        if not np.isfinite(zone_id):
+            continue
+
+        stats_dict['zone'].append(zone_id)
+        # extract the values of this zone
+        if i < len(unique_zones) - 1:
+            zone_values = values_by_zone[unique_index[i]:unique_index[i + 1]]
+        else:
+            zone_values = values_by_zone[unique_index[i]:]
+
+        # apply each stats function to the zone data
+        for stats in stats_funcs:
+            stats_func = stats_funcs.get(stats)
+            if not callable(stats_func):
+                raise ValueError(stats)
+            result = stats_func(zone_values)
+
+            assert len(result.shape) == 0
+
+            stats_dict[stats].append(float(result))
+
+    stats_df = pd.DataFrame(stats_dict)
+    stats_df = stats_df.set_index("zone")
+    return stats_df
+
+
 def stats(
     zones: xr.DataArray,
     values: xr.DataArray,
@@ -461,13 +548,11 @@ def stats(
     if isinstance(stats_funcs, list):
         # create a dict of stats
         stats_funcs_dict = {}
-
         for stats in stats_funcs:
             func = _DEFAULT_STATS.get(stats, None)
             if func is None:
                 err_str = f"Invalid stat name. {stats} option not supported."
                 raise ValueError(err_str)
-
             stats_funcs_dict[stats] = func

     elif isinstance(stats_funcs, dict):
@@ -476,9 +561,7 @@ def stats(
     mapper = ArrayTypeFunctionMapping(
         numpy_func=_stats_numpy,
         dask_func=_stats_dask_numpy,
-        cupy_func=lambda *args: not_implemented_func(
-            *args, messages='stats() does not support cupy backed DataArray'
-        ),
+        cupy_func=_stats_cupy,
         dask_cupy_func=lambda *args: not_implemented_func(
             *args, messages='stats() does not support dask with cupy backed DataArray'  # noqa
         ),
@@ -841,13 +924,13 @@ def crosstab(
     >>> df = crosstab(zones=zones_dask, values=values_dask)
     >>> print(df)
     Dask DataFrame Structure:
-                     zone    0.0   10.0   20.0   30.0   40.0   50.0
+                    zone    0.0   10.0   20.0   30.0   40.0   50.0
     npartitions=5
-    0             float64  int64  int64  int64  int64  int64  int64
-    1                 ...    ...    ...    ...    ...    ...    ...
-    ...               ...    ...    ...    ...    ...    ...    ...
-    4                 ...    ...    ...    ...    ...    ...    ...
-    5                 ...    ...    ...    ...    ...    ...    ...
+    0              float64  int64  int64  int64  int64  int64  int64
+    1                  ...    ...    ...    ...    ...    ...    ...
+    ...                ...    ...    ...    ...    ...    ...    ...
+    4                  ...    ...    ...    ...    ...    ...    ...
+    5                  ...    ...    ...    ...    ...    ...    ...
     Dask Name: astype, 1186 tasks
     >>> print(df.compute())
     zone  0.0  10.0  20.0  30.0  40.0  50.0
@@ -1214,7 +1297,7 @@ def _area_connectivity(data, n=4):
                 src_window[4] = data[min(y + 1, rows - 1), x]
                 src_window[5] = data[max(y - 1, 0), min(x + 1, cols - 1)]
                 src_window[6] = data[y, min(x + 1, cols - 1)]
-                src_window[7] = data[min(y + 1, rows - 1), min(x + 1, cols - 1)] # noqa
+                src_window[7] = data[min(y + 1, rows - 1), min(x + 1, cols - 1)]  # noqa

                 area_window[0] = out[max(y - 1, 0), max(x - 1, 0)]
                 area_window[1] = out[y, max(x - 1, 0)]
@@ -1223,7 +1306,7 @@ def _area_connectivity(data, n=4):
                 area_window[4] = out[min(y + 1, rows - 1), x]
                 area_window[5] = out[max(y - 1, 0), min(x + 1, cols - 1)]
                 area_window[6] = out[y, min(x + 1, cols - 1)]
-                area_window[7] = out[min(y + 1, rows - 1), min(x + 1, cols - 1)] # noqa
+                area_window[7] = out[min(y + 1, rows - 1), min(x + 1, cols - 1)]  # noqa

             else:
                 src_window[0] = data[y, max(x - 1, 0)]
@@ -1272,7 +1355,7 @@ def _area_connectivity(data, n=4):
                 src_window[4] = data[min(y + 1, rows - 1), x]
                 src_window[5] = data[max(y - 1, 0), min(x + 1, cols - 1)]
                 src_window[6] = data[y, min(x + 1, cols - 1)]
-                src_window[7] = data[min(y + 1, rows - 1), min(x + 1, cols - 1)] # noqa
+                src_window[7] = data[min(y + 1, rows - 1), min(x + 1, cols - 1)]  # noqa

                 area_window[0] = out[max(y - 1, 0), max(x - 1, 0)]
                 area_window[1] = out[y, max(x - 1, 0)]
@@ -1281,7 +1364,7 @@ def _area_connectivity(data, n=4):
                 area_window[4] = out[min(y + 1, rows - 1), x]
                 area_window[5] = out[max(y - 1, 0), min(x + 1, cols - 1)]
                 area_window[6] = out[y, min(x + 1, cols - 1)]
-                area_window[7] = out[min(y + 1, rows - 1), min(x + 1, cols - 1)] # noqa
+                area_window[7] = out[min(y + 1, rows - 1), min(x + 1, cols - 1)]  # noqa

             else:
                 src_window[0] = data[y, max(x - 1, 0)]
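
A minimal usage sketch of the new cupy path wired into stats() above. The data is illustrative; it assumes cupy and a CUDA-capable device are available, so that ArrayTypeFunctionMapping dispatches the cupy-backed DataArray to _stats_cupy:

>>> import cupy
>>> import numpy as np
>>> import xarray as xr
>>> from xrspatial.zonal import stats
>>> zones = xr.DataArray(cupy.asarray(np.array([[0., 0., 1., 1.],
...                                             [0., 0., 1., 1.]])))
>>> values = xr.DataArray(cupy.asarray(np.arange(8, dtype='float64').reshape(2, 4)))
>>> # per-zone mean/max/min computed on the GPU, returned as a pandas DataFrame
>>> df = stats(zones=zones, values=values, stats_funcs=['mean', 'max', 'min'])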