@@ -90,18 +90,38 @@ def next_buffer_row(self):
     def flush(self):
         # TODO just move sync_flush_array in here
         if self.buffer_row != 0:
-            sync_flush_array(
-                self.buff[: self.buffer_row], self.array, self.array_offset
-            )
+            if len(self.array.chunks) <= 1:
+                sync_flush_1d_array(
+                    self.buff[: self.buffer_row], self.array, self.array_offset
+                )
+            else:
+                sync_flush_2d_array(
+                    self.buff[: self.buffer_row], self.array, self.array_offset
+                )
+            logger.debug(
+                f"Flushed chunk {self.array} {self.array_offset} + {self.buffer_row}")
         self.array_offset += self.chunk_length
         self.buffer_row = 0


-# TODO: factor these functions into the BufferedArray class
+def sync_flush_1d_array(np_buffer, zarr_array, offset):
+    zarr_array[offset : offset + np_buffer.shape[0]] = np_buffer
+    update_progress(1)


-def sync_flush_array(np_buffer, zarr_array, offset):
-    zarr_array[offset : offset + np_buffer.shape[0]] = np_buffer
+def sync_flush_2d_array(np_buffer, zarr_array, offset):
+    # Write chunks in the second dimension 1-by-1 to make progress more
+    # incremental, and to avoid large memcopies in the underlying
+    # encoder implementations.
+    s = slice(offset, offset + np_buffer.shape[0])
+    chunk_width = zarr_array.chunks[1]
+    zarr_array_width = zarr_array.shape[1]
+    start = 0
+    while start < zarr_array_width:
+        stop = min(start + chunk_width, zarr_array_width)
+        zarr_array[s, start:stop] = np_buffer[:, start:stop]
+        update_progress(1)
+        start = stop


 @dataclasses.dataclass
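
As a sanity check on the new 2D path: with sync_flush_2d_array as defined above, a minimal standalone sketch of its behaviour (the array shape, chunking, and the update_progress stub here are illustrative assumptions, not part of this PR):

    import numpy as np
    import zarr

    def update_progress(n):
        # Stub for the shared-counter hook used elsewhere in this module.
        print(f"progress += {n}")

    # An 8x10 array stored in 4x3 chunks, so each buffer flush touches
    # four chunk columns: [0:3], [3:6], [6:9], [9:10].
    z = zarr.zeros(shape=(8, 10), chunks=(4, 3), dtype="i4")
    buf = np.arange(40, dtype="i4").reshape(4, 10)
    sync_flush_2d_array(buf, z, 0)  # writes rows 0..3, four progress ticks
    assert (z[:4] == buf).all()
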
@@ -113,6 +133,10 @@ class ProgressConfig:
     poll_interval: float = 0.001


+# NOTE: this approach means that we cannot have more than one
+# progressable thing happening per source process. This is
+# probably fine in practice, but there could be corner cases
+# where it's not. Something to watch out for.
 _progress_counter = multiprocessing.Value("Q", 0)

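
The helpers that read and write this counter are not shown in the hunk; a plausible minimal sketch, assuming they do nothing more than wrap _progress_counter and its built-in lock:

    def update_progress(inc):
        # Atomic add; a multiprocessing.Value carries its own lock.
        with _progress_counter.get_lock():
            _progress_counter.value += inc

    def get_progress():
        with _progress_counter.get_lock():
            return _progress_counter.value

    def set_progress(value):
        with _progress_counter.get_lock():
            _progress_counter.value = value

A single module-level counter like this is also where the one-progressable-thing-per-process restriction in the NOTE above comes from.
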
@@ -146,8 +170,14 @@ def progress_thread_worker(config):
         inc = current - pbar.n
         pbar.update(inc)
         time.sleep(config.poll_interval)
-    # inc = config.total - pbar.n
-    # pbar.update(inc)
+    # TODO figure out why we're sometimes going over total
+    # if get_progress() != config.total:
+    #     print("HOW DID THIS HAPPEN!!")
+    #     print(get_progress())
+    #     print(config)
+    # assert get_progress() == config.total
+    inc = config.total - pbar.n
+    pbar.update(inc)
     pbar.close()
     # print("EXITING PROGRESS THREAD")

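
For context, one way the pieces might be wired together; the thread setup below is an assumption, only progress_thread_worker and the helper names come from this PR:

    import threading

    def run_with_progress(config, work):
        t = threading.Thread(
            target=progress_thread_worker, args=(config,), daemon=True
        )
        t.start()
        work()  # workers call update_progress() as chunks are flushed
        # Top the counter up to the exact total so the worker's loop
        # terminates, mirroring the __exit__ hunk below.
        set_progress(config.total)
        t.join()
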
@@ -187,7 +217,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
             # Note: this doesn't seem to be working correctly. If
             # we set a timeout of None we get deadlocks
             set_progress(self.progress_config.total)
-            timeout = 1
+            timeout = None
         else:
             cancel_futures(self.futures)
             timeout = 0
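
How timeout is consumed is not visible in this hunk; a hedged reading is that it bounds the join on the progress-bar thread, roughly:

    # Hypothetical continuation of __exit__ (not shown in the diff):
    self.bar_thread.join(timeout)
    # With timeout=None on the success path, we block until the thread
    # sees set_progress(total) and exits; on error, timeout=0 gives up
    # immediately after cancelling the futures.

If that reading is right, it would also explain why the note above about None deadlocking is now stale: once set_progress(total) guarantees the worker's loop terminates, an unbounded join should be safe.
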