Commit 01559ce

Refactored encode write path
1 parent f184967 commit 01559ce

3 files changed: +139 -223 lines changed


bio2zarr/core.py

Lines changed: 38 additions & 82 deletions
@@ -24,6 +24,23 @@
 )
 
 
+def chunk_aligned_slices(z, n):
+    """
+    Returns at most n slices in the specified zarr array, aligned
+    with its chunks.
+    """
+    chunk_size = z.chunks[0]
+    num_chunks = int(np.ceil(z.shape[0] / chunk_size))
+    slices = []
+    splits = np.array_split(np.arange(num_chunks), min(n, num_chunks))
+    for split in splits:
+        start = split[0] * chunk_size
+        stop = (split[-1] + 1) * chunk_size
+        stop = min(stop, z.shape[0])
+        slices.append((start, stop))
+    return slices
+
+
 class SynchronousExecutor(cf.Executor):
     def submit(self, fn, /, *args, **kwargs):
         future = cf.Future()
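To make the slicing concrete, here is a quick sketch of what the new helper returns (assuming zarr and numpy are installed, plus bio2zarr at this commit; the array shape and chunk size are made up for the example):

    import zarr
    from bio2zarr import core

    # A 10-element array in chunks of 3 has 4 chunks. Asking for 3 slices
    # groups the chunks as [0, 1], [2], [3], so every boundary except the
    # ragged final one falls on a multiple of the chunk size.
    z = zarr.zeros(10, chunks=3)
    print(core.chunk_aligned_slices(z, 3))
    # slices covering (0, 6), (6, 9), (9, 10)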
@@ -46,23 +63,38 @@ def cancel_futures(futures):
 @dataclasses.dataclass
 class BufferedArray:
     array: zarr.Array
+    array_offset: int
     buff: np.ndarray
+    buffer_row: int
 
-    def __init__(self, array):
+    def __init__(self, array, offset):
         self.array = array
+        self.array_offset = offset
+        assert offset % array.chunks[0] == 0
         dims = list(array.shape)
         dims[0] = min(array.chunks[0], array.shape[0])
         self.buff = np.zeros(dims, dtype=array.dtype)
+        self.buffer_row = 0
 
     @property
     def chunk_length(self):
         return self.buff.shape[0]
 
-    def swap_buffers(self):
-        self.buff = np.zeros_like(self.buff)
-
-    def async_flush(self, executor, offset, buff_stop=None):
-        return async_flush_array(executor, self.buff[:buff_stop], self.array, offset)
+    def next_buffer_row(self):
+        if self.buffer_row == self.chunk_length:
+            self.flush()
+        row = self.buffer_row
+        self.buffer_row += 1
+        return row
+
+    def flush(self):
+        # TODO just move sync_flush_array in here
+        if self.buffer_row != 0:
+            sync_flush_array(
+                self.buff[: self.buffer_row], self.array, self.array_offset
+            )
+        self.array_offset += self.chunk_length
+        self.buffer_row = 0
 
 
 # TODO: factor these functions into the BufferedArray class
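The new fill/flush cycle is easiest to see in isolation. A minimal usage sketch (assuming bio2zarr at this commit; the array is a stand-in): next_buffer_row hands out the next free row of the in-memory buffer, flushing a completed chunk to the zarr array first, and a final flush writes out whatever partial chunk remains.

    import numpy as np
    import zarr
    from bio2zarr import core

    z = zarr.zeros(10, chunks=3, dtype=np.int32)
    ba = core.BufferedArray(z, 0)  # offset 0 is trivially chunk-aligned
    for value in range(10):
        j = ba.next_buffer_row()  # flushes automatically when the buffer fills
        ba.buff[j] = value
    ba.flush()  # write out the final, partially filled chunk
    print(z[:])  # -> [0 1 2 3 4 5 6 7 8 9]

Because the constructor asserts offset % array.chunks[0] == 0 and the buffer is exactly one chunk long, each BufferedArray only ever writes to its own chunk-aligned range and needs no locking against other writers.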
@@ -72,82 +104,6 @@ def sync_flush_array(np_buffer, zarr_array, offset):
     zarr_array[offset : offset + np_buffer.shape[0]] = np_buffer
 
 
-def async_flush_array(executor, np_buffer, zarr_array, offset):
-    """
-    Flush the specified chunk aligned buffer to the specified zarr array.
-    """
-    logger.debug(f"Schedule flush {zarr_array} @ {offset}")
-    assert zarr_array.shape[1:] == np_buffer.shape[1:]
-    # print("sync", zarr_array, np_buffer)
-
-    if len(np_buffer.shape) == 1:
-        futures = [executor.submit(sync_flush_array, np_buffer, zarr_array, offset)]
-    else:
-        futures = async_flush_2d_array(executor, np_buffer, zarr_array, offset)
-    return futures
-
-
-def async_flush_2d_array(executor, np_buffer, zarr_array, offset):
-    # Flush each of the chunks in the second dimension separately
-    s = slice(offset, offset + np_buffer.shape[0])
-
-    def flush_chunk(start, stop):
-        zarr_array[s, start:stop] = np_buffer[:, start:stop]
-
-    chunk_width = zarr_array.chunks[1]
-    zarr_array_width = zarr_array.shape[1]
-    start = 0
-    futures = []
-    while start < zarr_array_width:
-        stop = min(start + chunk_width, zarr_array_width)
-        future = executor.submit(flush_chunk, start, stop)
-        futures.append(future)
-        start = stop
-
-    return futures
-
-
-class ThreadedZarrEncoder(contextlib.AbstractContextManager):
-    # TODO (maybe) add option with encoder_threads=None to run synchronously for
-    # debugging using a mock Executor
-    def __init__(self, buffered_arrays, encoder_threads=1):
-        self.buffered_arrays = buffered_arrays
-        self.executor = cf.ThreadPoolExecutor(max_workers=encoder_threads)
-        self.chunk_length = buffered_arrays[0].chunk_length
-        assert all(ba.chunk_length == self.chunk_length for ba in self.buffered_arrays)
-        self.futures = []
-        self.array_offset = 0
-        self.next_row = -1
-
-    def next_buffer_row(self):
-        self.next_row += 1
-        if self.next_row == self.chunk_length:
-            self.swap_buffers()
-            self.array_offset += self.chunk_length
-            self.next_row = 0
-        return self.next_row
-
-    def swap_buffers(self):
-        wait_on_futures(self.futures)
-        self.futures = []
-        for ba in self.buffered_arrays:
-            self.futures.extend(
-                ba.async_flush(self.executor, self.array_offset, self.next_row)
-            )
-            ba.swap_buffers()
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        if exc_type is None:
-            # Normal exit condition
-            self.next_row += 1
-            self.swap_buffers()
-            wait_on_futures(self.futures)
-        else:
-            cancel_futures(self.futures)
-        self.executor.shutdown()
-        return False
-
-
 @dataclasses.dataclass
 class ProgressConfig:
     total: int = 0

bio2zarr/vcf.py

Lines changed: 39 additions & 52 deletions
@@ -1262,54 +1262,39 @@ def create_array(self, variable):
         )
         a.attrs["_ARRAY_DIMENSIONS"] = variable.dimensions
 
-    def encode_column(self, pcvcf, column, encoder_threads=4):
-        # TODO we're doing this the wrong way at the moment, overcomplicating
-        # things by having the ThreadedZarrEncoder. It would be simpler if
-        # we split the columns into vertical chunks, and just pushed a bunch
-        # of futures for encoding start:end slices of each column. The
-        # complicating factor here is that we need to get these slices
-        # out of the pcvcf, which takes a little bit of doing (but fine,
-        # because we know the number of records in each partition).
-        # An annoying factor then is how to update the progress meter,
-        # because the "bytes read" approach becomes problematic
-        # when we might access the same chunk several times.
-        # Would perhaps be better to call sys.getsizeof() on the stored
-        # value each time.
-
+    def encode_column_slice(self, pcvcf, column, start, stop):
         source_col = pcvcf.columns[column.vcf_field]
         array = self.root[column.name]
-        ba = core.BufferedArray(array)
+        ba = core.BufferedArray(array, start)
         sanitiser = source_col.sanitiser_factory(ba.buff.shape)
 
-        with core.ThreadedZarrEncoder([ba], encoder_threads) as te:
-            last_bytes_read = 0
-            for value, bytes_read in source_col.iter_values_bytes():
-                j = te.next_buffer_row()
-                sanitiser(ba.buff, j, value)
-                # print(bytes_read, last_bytes_read, value)
-                if last_bytes_read != bytes_read:
-                    core.update_progress(bytes_read - last_bytes_read)
-                    last_bytes_read = bytes_read
-
-    def encode_genotypes(self, pcvcf, encoder_threads=4):
+        for value in source_col.iter_values(start, stop):
+            # We write directly into the buffer in the sanitiser function
+            # to make it easier to reason about dimension padding
+            j = ba.next_buffer_row()
+            sanitiser(ba.buff, j, value)
+            core.update_progress(sys.getsizeof(value))
+        ba.flush()
+
+    def encode_genotypes_slice(self, pcvcf, start, stop):
         source_col = pcvcf.columns["FORMAT/GT"]
-        gt = core.BufferedArray(self.root["call_genotype"])
-        gt_mask = core.BufferedArray(self.root["call_genotype_mask"])
-        gt_phased = core.BufferedArray(self.root["call_genotype_phased"])
-        buffered_arrays = [gt, gt_phased, gt_mask]
-
-        with core.ThreadedZarrEncoder(buffered_arrays, encoder_threads) as te:
-            last_bytes_read = 0
-            for value, bytes_read in source_col.iter_values_bytes():
-                j = te.next_buffer_row()
-                sanitise_value_int_2d(gt.buff, j, value[:, :-1])
-                sanitise_value_int_1d(gt_phased.buff, j, value[:, -1])
-                # TODO check is this the correct semantics when we are padding
-                # with mixed ploidies?
-                gt_mask.buff[j] = gt.buff[j] < 0
-                if last_bytes_read != bytes_read:
-                    core.update_progress(bytes_read - last_bytes_read)
-                    last_bytes_read = bytes_read
+        gt = core.BufferedArray(self.root["call_genotype"], start)
+        gt_mask = core.BufferedArray(self.root["call_genotype_mask"], start)
+        gt_phased = core.BufferedArray(self.root["call_genotype_phased"], start)
+
+        for value in source_col.iter_values(start, stop):
+            j = gt.next_buffer_row()
+            sanitise_value_int_2d(gt.buff, j, value[:, :-1])
+            j = gt_phased.next_buffer_row()
+            sanitise_value_int_1d(gt_phased.buff, j, value[:, -1])
+            # TODO check is this the correct semantics when we are padding
+            # with mixed ploidies?
+            j = gt_mask.next_buffer_row()
+            gt_mask.buff[j] = gt.buff[j] < 0
+            core.update_progress(sys.getsizeof(value))
+        gt.flush()
+        gt_phased.flush()
+        gt_mask.flush()
 
     def encode_alleles(self, pcvcf):
         ref_col = pcvcf.columns["REF"]
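Two things change in these methods beyond the renaming. First, each slice gets its own BufferedArray anchored at a chunk-aligned start, so independent workers write disjoint chunk ranges of the same arrays without coordination (a minimal sketch with stand-in data; in the real code the values come from pcvcf's columns):

    import numpy as np
    import zarr
    from bio2zarr import core

    z = zarr.zeros(20, chunks=5, dtype=np.int32)
    # Two chunk-aligned slices, as produced by core.chunk_aligned_slices(z, 2)
    for start, stop in [(0, 10), (10, 20)]:
        ba = core.BufferedArray(z, start)  # each writer owns disjoint chunks
        for value in range(start, stop):
            ba.buff[ba.next_buffer_row()] = value
        ba.flush()
    print(z[:])  # -> [0 1 2 ... 19]

Second, the progress metric changes: tracking bytes read from the intermediate columnar store would double-count when several slices touch the same stored chunk, so each worker now reports sys.getsizeof(value) for the values it encodes, as the deleted TODO comment anticipated.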
@@ -1449,6 +1434,7 @@ def convert(
             units="b",
             show=show_progress,
         )
+        num_slices = max(1, worker_processes * 4)
         with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
             pwm.submit(
                 sgvcf.encode_samples,
@@ -1465,22 +1451,23 @@ def convert(
                 conversion_spec.contig_length,
             )
             pwm.submit(sgvcf.encode_filters, pcvcf, conversion_spec.filter_id)
+            # Using POS arbitrarily to get the array slices
+            slices = core.chunk_aligned_slices(
+                sgvcf.root["variant_position"], num_slices
+            )
             has_gt = False
             for variable in conversion_spec.columns.values():
                 if variable.vcf_field is not None:
-                    # print("Encode", variable.name)
-                    # TODO for large columns it's probably worth splitting up
-                    # these into vertical chunks. Otherwise we tend to get a
-                    # long wait for the largest GT columns to finish.
-                    # Straightforward to do because we can chunk-align the work
-                    # packages.
-                    pwm.submit(sgvcf.encode_column, pcvcf, variable)
+                    for start, stop in slices:
+                        pwm.submit(
+                            sgvcf.encode_column_slice, pcvcf, variable, start, stop
+                        )
                 else:
                     if variable.name == "call_genotype":
                         has_gt = True
             if has_gt:
-                # TODO add mixed ploidy
-                pwm.executor.submit(sgvcf.encode_genotypes, pcvcf)
+                for start, stop in slices:
+                    pwm.submit(sgvcf.encode_genotypes_slice, pcvcf, start, stop)
 
     zarr.consolidate_metadata(write_path)
     # Atomic swap, now we've completely finished.
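The resulting work layout in convert is a plain fan-out: every (column, slice) pair becomes one independent future, so a single large column such as the genotypes no longer serialises the tail of the run, and num_slices = max(1, worker_processes * 4) gives each worker roughly four packages to balance over. A sketch of the same pattern using concurrent.futures directly (the function and names here are stand-ins, not the bio2zarr API, which wraps this in ParallelWorkManager):

    import concurrent.futures as cf

    def encode_column_slice(column, start, stop):
        # Stand-in for sgvcf.encode_column_slice: each call writes only
        # the chunk-aligned rows in [start, stop), so calls never overlap.
        print(f"encoding {column}[{start}:{stop}]")

    if __name__ == "__main__":
        columns = ["variant_position", "call_genotype"]
        slices = [(0, 6), (6, 9), (9, 10)]  # e.g. from core.chunk_aligned_slices
        with cf.ProcessPoolExecutor(max_workers=4) as executor:
            futures = [
                executor.submit(encode_column_slice, col, start, stop)
                for col in columns
                for start, stop in slices
            ]
            for future in cf.as_completed(futures):
                future.result()  # propagate any worker exceptions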
