 )


+def chunk_aligned_slices(z, n):
+    """
+    Returns at most n slices of the specified zarr array, aligned
+    with its chunks.
+    """
+    chunk_size = z.chunks[0]
+    num_chunks = int(np.ceil(z.shape[0] / chunk_size))
+    slices = []
+    splits = np.array_split(np.arange(num_chunks), min(n, num_chunks))
+    for split in splits:
+        start = split[0] * chunk_size
+        stop = (split[-1] + 1) * chunk_size
+        stop = min(stop, z.shape[0])
+        slices.append((start, stop))
+    return slices
+
+
 class SynchronousExecutor(cf.Executor):
     def submit(self, fn, /, *args, **kwargs):
         future = cf.Future()
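
For illustration, here is how the new `chunk_aligned_slices` helper behaves on a small array; the shape and chunking below are made-up values, not taken from this patch:

```python
import numpy as np
import zarr

# Hypothetical 1-D zarr array: 10 rows stored in chunks of 3.
z = zarr.array(np.arange(10), chunks=3)

# Slices are (start, stop) pairs on chunk boundaries that together
# cover the first dimension; the last stop is clamped to the length.
print(chunk_aligned_slices(z, 3))
# [(0, 6), (6, 9), (9, 10)]

# Asking for more slices than there are chunks caps at one per chunk.
print(chunk_aligned_slices(z, 10))
# [(0, 3), (3, 6), (6, 9), (9, 10)]
```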
@@ -46,107 +63,66 @@ def cancel_futures(futures):
 @dataclasses.dataclass
 class BufferedArray:
     array: zarr.Array
+    array_offset: int
     buff: np.ndarray
+    buffer_row: int

-    def __init__(self, array):
+    def __init__(self, array, offset):
         self.array = array
+        self.array_offset = offset
+        assert offset % array.chunks[0] == 0
         dims = list(array.shape)
         dims[0] = min(array.chunks[0], array.shape[0])
         self.buff = np.zeros(dims, dtype=array.dtype)
+        self.buffer_row = 0

     @property
     def chunk_length(self):
         return self.buff.shape[0]

-    def swap_buffers(self):
-        self.buff = np.zeros_like(self.buff)
-
-    def async_flush(self, executor, offset, buff_stop=None):
-        return async_flush_array(executor, self.buff[:buff_stop], self.array, offset)
-
-
-# TODO: factor these functions into the BufferedArray class
+    def next_buffer_row(self):
+        if self.buffer_row == self.chunk_length:
+            self.flush()
+        row = self.buffer_row
+        self.buffer_row += 1
+        return row
+
+    def flush(self):
+        # TODO just move sync_flush_array in here
+        if self.buffer_row != 0:
+            if len(self.array.chunks) <= 1:
+                sync_flush_1d_array(
+                    self.buff[: self.buffer_row], self.array, self.array_offset
+                )
+            else:
+                sync_flush_2d_array(
+                    self.buff[: self.buffer_row], self.array, self.array_offset
+                )
+            logger.debug(
+                f"Flushed chunk {self.array} {self.array_offset} + {self.buffer_row}")
+        self.array_offset += self.chunk_length
+        self.buffer_row = 0


-def sync_flush_array(np_buffer, zarr_array, offset):
+def sync_flush_1d_array(np_buffer, zarr_array, offset):
     zarr_array[offset : offset + np_buffer.shape[0]] = np_buffer
+    update_progress(1)


-def async_flush_array(executor, np_buffer, zarr_array, offset):
-    """
-    Flush the specified chunk aligned buffer to the specified zarr array.
-    """
-    logger.debug(f"Schedule flush {zarr_array} @ {offset}")
-    assert zarr_array.shape[1:] == np_buffer.shape[1:]
-    # print("sync", zarr_array, np_buffer)
-
-    if len(np_buffer.shape) == 1:
-        futures = [executor.submit(sync_flush_array, np_buffer, zarr_array, offset)]
-    else:
-        futures = async_flush_2d_array(executor, np_buffer, zarr_array, offset)
-    return futures
-
-
-def async_flush_2d_array(executor, np_buffer, zarr_array, offset):
-    # Flush each of the chunks in the second dimension separately
+def sync_flush_2d_array(np_buffer, zarr_array, offset):
+    # Write chunks in the second dimension 1-by-1 to make progress more
+    # incremental, and to avoid large memcopies in the underlying
+    # encoder implementations.
     s = slice(offset, offset + np_buffer.shape[0])
-
-    def flush_chunk(start, stop):
-        zarr_array[s, start:stop] = np_buffer[:, start:stop]
-
     chunk_width = zarr_array.chunks[1]
     zarr_array_width = zarr_array.shape[1]
     start = 0
-    futures = []
     while start < zarr_array_width:
         stop = min(start + chunk_width, zarr_array_width)
-        future = executor.submit(flush_chunk, start, stop)
-        futures.append(future)
+        zarr_array[s, start:stop] = np_buffer[:, start:stop]
+        update_progress(1)
         start = stop

-    return futures
-
-
-class ThreadedZarrEncoder(contextlib.AbstractContextManager):
-    # TODO (maybe) add option with encoder_threads=None to run synchronously for
-    # debugging using a mock Executor
-    def __init__(self, buffered_arrays, encoder_threads=1):
-        self.buffered_arrays = buffered_arrays
-        self.executor = cf.ThreadPoolExecutor(max_workers=encoder_threads)
-        self.chunk_length = buffered_arrays[0].chunk_length
-        assert all(ba.chunk_length == self.chunk_length for ba in self.buffered_arrays)
-        self.futures = []
-        self.array_offset = 0
-        self.next_row = -1
-
-    def next_buffer_row(self):
-        self.next_row += 1
-        if self.next_row == self.chunk_length:
-            self.swap_buffers()
-            self.array_offset += self.chunk_length
-            self.next_row = 0
-        return self.next_row
-
-    def swap_buffers(self):
-        wait_on_futures(self.futures)
-        self.futures = []
-        for ba in self.buffered_arrays:
-            self.futures.extend(
-                ba.async_flush(self.executor, self.array_offset, self.next_row)
-            )
-            ba.swap_buffers()
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        if exc_type is None:
-            # Normal exit condition
-            self.next_row += 1
-            self.swap_buffers()
-            wait_on_futures(self.futures)
-        else:
-            cancel_futures(self.futures)
-        self.executor.shutdown()
-        return False
-

 @dataclasses.dataclass
 class ProgressConfig:
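
To make the new buffering protocol concrete, here is a minimal sketch of the intended write loop; the array shape, `out.zarr` path, and `produce_rows` source are hypothetical, not part of this patch. Callers ask for a buffer row, fill it, and let `next_buffer_row()` flush a full chunk transparently; a final `flush()` drains the partial last chunk:

```python
# Hypothetical usage of BufferedArray (sketch only).
z = zarr.open("out.zarr", mode="w", shape=(100, 8), chunks=(10, 4),
              dtype=np.int32)
ba = BufferedArray(z, offset=0)
for values in produce_rows():   # hypothetical per-row data source
    j = ba.next_buffer_row()    # flushes a full chunk first if needed
    ba.buff[j] = values
ba.flush()                      # drain the final, possibly partial chunk
```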
@@ -157,6 +133,10 @@ class ProgressConfig:
     poll_interval: float = 0.001


+# NOTE: this approach means that we cannot have more than one
+# progressable thing happening per source process. This is
+# probably fine in practice, but there could be corner cases
+# where it's not. Something to watch out for.
 _progress_counter = multiprocessing.Value("Q", 0)

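
The `update_progress`, `get_progress`, and `set_progress` helpers this counter backs are defined outside this hunk; as a rough sketch of the pattern (not the module's actual definitions), a `"Q"` (unsigned 64-bit) `multiprocessing.Value` carries its own lock, so workers in different processes can increment it safely:

```python
# Sketch only: the real helpers are not shown in this diff.
def update_progress(inc):
    with _progress_counter.get_lock():
        _progress_counter.value += inc

def get_progress():
    with _progress_counter.get_lock():
        return _progress_counter.value

def set_progress(value):
    with _progress_counter.get_lock():
        _progress_counter.value = value
```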
@@ -190,7 +170,16 @@ def progress_thread_worker(config):
         inc = current - pbar.n
         pbar.update(inc)
         time.sleep(config.poll_interval)
+    # TODO figure out why we're sometimes going over total
+    # if get_progress() != config.total:
+    #     print("HOW DID THIS HAPPEN!!")
+    #     print(get_progress())
+    #     print(config)
+    # assert get_progress() == config.total
+    inc = config.total - pbar.n
+    pbar.update(inc)
     pbar.close()
+    # print("EXITING PROGRESS THREAD")


 class ParallelWorkManager(contextlib.AbstractContextManager):
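
The two lines added before `pbar.close()` snap the bar to exactly `config.total`, since the polling loop may exit before the last increments are displayed. A standalone sketch of the same tqdm pattern, with illustrative values only:

```python
from tqdm import tqdm

pbar = tqdm(total=100)
pbar.update(97)             # suppose polling stopped slightly short
pbar.update(100 - pbar.n)   # force the displayed count to the total
pbar.close()
```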
@@ -228,7 +217,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
             # Note: this doesn't seem to be working correctly. If
             # we set a timeout of None we get deadlocks
             set_progress(self.progress_config.total)
-            timeout = 1
+            timeout = None
         else:
             cancel_futures(self.futures)
             timeout = 0