pytest.importorskip("dask")
pytest.importorskip("cftime")

- import cftime
import dask
- import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
import numpy as np
from hypothesis import assume, given, note

import flox
from flox.core import groupby_reduce, groupby_scan

- from . import ALL_FUNCS, SCIPY_STATS_FUNCS, assert_equal
+ from . import assert_equal
+ from .strategies import all_arrays, by_arrays, chunked_arrays, func_st, numeric_arrays

dask.config.set(scheduler="sync")

@@ -32,94 +31,13 @@ def bfill(array, axis, dtype=None):
    )[::-1]


- NON_NUMPY_FUNCS = ["first", "last", "nanfirst", "nanlast", "count", "any", "all"] + list(
-     SCIPY_STATS_FUNCS
- )
- SKIPPED_FUNCS = ["var", "std", "nanvar", "nanstd"]
NUMPY_SCAN_FUNCS = {
    "nancumsum": np.nancumsum,
    "ffill": ffill,
    "bfill": bfill,
}  # "cumsum": np.cumsum,


- def supported_dtypes() -> st.SearchStrategy[np.dtype]:
-     return (
-         npst.integer_dtypes(endianness="=")
-         | npst.unsigned_integer_dtypes(endianness="=")
-         | npst.floating_dtypes(endianness="=", sizes=(32, 64))
-         | npst.complex_number_dtypes(endianness="=")
-         | npst.datetime64_dtypes(endianness="=")
-         | npst.timedelta64_dtypes(endianness="=")
-         | npst.unicode_string_dtypes(endianness="=")
-     )
-
-
- # TODO: stop excluding everything but U
- array_dtype_st = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
- by_dtype_st = supported_dtypes()
- func_st = st.sampled_from(
-     [f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS]
- )
- numeric_arrays = npst.arrays(
-     elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st
- )
- all_arrays = npst.arrays(
-     elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=supported_dtypes()
- )
-
- calendars = st.sampled_from(
-     [
-         "standard",
-         "gregorian",
-         "proleptic_gregorian",
-         "noleap",
-         "365_day",
-         "360_day",
-         "julian",
-         "all_leap",
-         "366_day",
-     ]
- )
-
-
- @st.composite
- def units(draw, *, calendar: str):
-     choices = ["days", "hours", "minutes", "seconds", "milliseconds", "microseconds"]
-     if calendar == "360_day":
-         choices += ["months"]
-     elif calendar == "noleap":
-         choices += ["common_years"]
-     time_units = draw(st.sampled_from(choices))
-
-     dt = draw(st.datetimes())
-     year, month, day = dt.year, dt.month, dt.day
-     if calendar == "360_day":
-         month %= 30
-     return f"{time_units} since {year}-{month}-{day}"
-
-
- @st.composite
- def cftime_arrays(draw, *, shape, calendars=calendars, elements=None):
-     if elements is None:
-         elements = {"min_value": -10_000, "max_value": 10_000}
-     cal = draw(calendars)
-     values = draw(npst.arrays(dtype=np.int64, shape=shape, elements=elements))
-     unit = draw(units(calendar=cal))
-     return cftime.num2date(values, units=unit, calendar=cal)
-
-
- def by_arrays(shape, *, elements=None):
-     return st.one_of(
-         npst.arrays(
-             dtype=npst.integer_dtypes(endianness="=") | npst.unicode_string_dtypes(endianness="="),
-             shape=shape,
-             elements=elements,
-         ),
-         cftime_arrays(shape=shape, elements=elements),
-     )
-
-
def not_overflowing_array(array) -> bool:
    if array.dtype.kind == "f":
        info = np.finfo(array.dtype)
@@ -133,40 +51,6 @@ def not_overflowing_array(array) -> bool:
    return result


- @st.composite
- def chunks(draw, *, shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
-     chunks = []
-     for size in shape:
-         if size > 1:
-             nchunks = draw(st.integers(min_value=1, max_value=size - 1))
-             dividers = sorted(
-                 set(draw(st.integers(min_value=1, max_value=size - 1)) for _ in range(nchunks - 1))
-             )
-             chunks.append(tuple(a - b for a, b in zip(dividers + [size], [0] + dividers)))
-         else:
-             chunks.append((1,))
-     return tuple(chunks)
-
-
- @st.composite
- def chunked_arrays(draw, *, chunks=chunks, arrays=numeric_arrays, from_array=dask.array.from_array):
-     array = draw(arrays)
-     chunks = draw(chunks(shape=array.shape))
-
-     if array.dtype.kind in "cf":
-         nan_idx = draw(
-             st.lists(
-                 st.integers(min_value=0, max_value=array.shape[-1] - 1),
-                 max_size=array.shape[-1] - 1,
-                 unique=True,
-             )
-         )
-         if nan_idx:
-             array[..., nan_idx] = np.nan
-
-     return from_array(array, chunks=chunks)
-
-
# TODO: migrate to by_arrays but with constant value
@given(data=st.data(), array=numeric_arrays, func=func_st)
def test_groupby_reduce(data, array, func):
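Note on the new imports: the names pulled in from .strategies above are the strategies removed from this file. The following is a rough, illustrative sketch only, not part of the diff. It assumes the relocated strategies keep the signatures and filters of the removed definitions (for example by_arrays(shape, *, elements=None) and func_st excluding the non-numpy and skipped functions); the test name, the q=0.5 choice, and the shape-only assertion are hypothetical.

import numpy as np
import hypothesis.strategies as st
from hypothesis import assume, given

from flox.core import groupby_reduce

from .strategies import by_arrays, func_st, numeric_arrays


@given(data=st.data(), array=numeric_arrays, func=func_st)
def test_reduce_ngroups_sketch(data, array, func):
    # Keep the sketch away from NaN-specific corner cases (e.g. arg* reductions).
    assume(not np.isnan(array).any())
    # Draw group labels aligned with the trailing axis, which flox reduces
    # over by default when `by` is 1-D.
    by = data.draw(by_arrays(shape=(array.shape[-1],)))
    # Quantile reductions need `q`; pass it through finalize_kwargs.
    finalize_kwargs = {"q": 0.5} if "quantile" in func else {}
    result, groups = groupby_reduce(
        array, by, func=func, engine="numpy", finalize_kwargs=finalize_kwargs
    )
    # The reduced axis has one entry per unique label.
    assert result.shape[-1] == len(groups)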
0 commit comments