
Commit 75e8039

Merge pull request #21 from jeromekelleher/more-refactoring
Abstract threaded zarr encoding to class
2 parents: 39c53f2 + 9e3eb74


3 files changed: +67 -50 lines


bio2zarr/core.py

Lines changed: 55 additions & 0 deletions
@@ -1,4 +1,6 @@
 import dataclasses
+import contextlib
+import concurrent.futures as cf
 import logging
 
 import zarr
@@ -19,6 +21,10 @@ def __init__(self, array):
         dims[0] = min(array.chunks[0], array.shape[0])
         self.buff = np.zeros(dims, dtype=array.dtype)
 
+    @property
+    def chunk_length(self):
+        return self.buff.shape[0]
+
     def swap_buffers(self):
         self.buff = np.zeros_like(self.buff)
 
@@ -63,3 +69,52 @@ def flush_chunk(start, stop):
         start = stop
 
     return futures
+
+
+class ThreadedZarrEncoder(contextlib.AbstractContextManager):
+    def __init__(self, buffered_arrays, encoder_threads):
+        self.buffered_arrays = buffered_arrays
+        self.executor = cf.ThreadPoolExecutor(max_workers=encoder_threads)
+        self.chunk_length = buffered_arrays[0].chunk_length
+        assert all(ba.chunk_length == self.chunk_length for ba in self.buffered_arrays)
+        self.futures = []
+        self.array_offset = 0
+        self.next_row = -1
+
+    def next_buffer_row(self):
+        self.next_row += 1
+        if self.next_row == self.chunk_length:
+            self.swap_buffers()
+            self.array_offset += self.chunk_length
+            self.next_row = 0
+        return self.next_row
+
+    def wait_on_futures(self):
+        for future in cf.as_completed(self.futures):
+            exception = future.exception()
+            if exception is not None:
+                raise exception
+
+    def swap_buffers(self):
+        self.wait_on_futures()
+        self.futures = []
+        for ba in self.buffered_arrays:
+            # TODO add debug log
+            # print("Scheduling", ba.array, offset, buff_stop)
+            self.futures.extend(
+                ba.async_flush(self.executor, self.array_offset, self.next_row)
+            )
+            ba.swap_buffers()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type is None:
+            # Normal exit condition
+            self.next_row += 1
+            self.swap_buffers()
+            self.wait_on_futures()
+        # TODO add arguments to wait and cancel_futures appropriate
+        # for an error condition occurring here. Generally need
+        # to think about the error exit condition here (like running
+        # out of disk space) to see what the right behaviour is.
+        self.executor.shutdown()
+        return False
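
As a usage illustration only (not part of this commit): a minimal sketch of how ThreadedZarrEncoder is intended to be driven, mirroring the loop shape in the vcf.py changes below. The toy zarr array, its chunking, and the values written are invented for the example; BufferedArray and ThreadedZarrEncoder are assumed to behave as defined in bio2zarr/core.py above.

    import numpy as np
    import zarr

    from bio2zarr import core

    # Toy in-memory zarr array with 10-row chunks (invented for illustration).
    array = zarr.zeros((25, 4), chunks=(10, 4), dtype=np.int32)
    ba = core.BufferedArray(array)

    with core.ThreadedZarrEncoder([ba], encoder_threads=2) as te:
        for i in range(array.shape[0]):
            # next_buffer_row() returns the buffer row to fill; once a full
            # chunk of rows has been written it schedules an async flush on
            # the thread pool and swaps in a fresh buffer.
            j = te.next_buffer_row()
            ba.buff[j] = i

    # On normal exit the final partial chunk is flushed and the pool shut down.
    print(array[:])
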

bio2zarr/vcf.py

Lines changed: 10 additions & 50 deletions
@@ -1127,77 +1127,44 @@ def create_array(self, variable):
         )
         a.attrs["_ARRAY_DIMENSIONS"] = variable.dimensions
 
-    def encode_column(self, pcvcf, column):
+    def encode_column(self, pcvcf, column, encoder_threads=4):
         source_col = pcvcf.columns[column.vcf_field]
         array = self.root[column.name]
         ba = core.BufferedArray(array)
         sanitiser = source_col.sanitiser_factory(ba.buff.shape)
-        chunk_length = array.chunks[0]
 
-        with cf.ThreadPoolExecutor(max_workers=4) as executor:
-            futures = []
-            chunk_start = 0
-            j = 0
+        with core.ThreadedZarrEncoder([ba], encoder_threads) as te:
             last_bytes_read = 0
             for value, bytes_read in source_col.iter_values_bytes():
+                j = te.next_buffer_row()
                 sanitiser(ba.buff, j, value)
-                j += 1
-                if j == chunk_length:
-                    flush_futures(futures)
-                    futures.extend(ba.async_flush(executor, chunk_start))
-                    ba.swap_buffers()
-                    j = 0
-                    chunk_start += chunk_length
                 if last_bytes_read != bytes_read:
                     with progress_counter.get_lock():
                         progress_counter.value += bytes_read - last_bytes_read
                     last_bytes_read = bytes_read
 
-            if j != 0:
-                flush_futures(futures)
-                futures.extend(ba.async_flush(executor, chunk_start, j))
-            flush_futures(futures)
-
-    def encode_genotypes(self, pcvcf):
+    def encode_genotypes(self, pcvcf, encoder_threads=4):
         source_col = pcvcf.columns["FORMAT/GT"]
         gt = core.BufferedArray(self.root["call_genotype"])
         gt_mask = core.BufferedArray(self.root["call_genotype_mask"])
         gt_phased = core.BufferedArray(self.root["call_genotype_phased"])
-        chunk_length = gt.array.chunks[0]
-
         buffered_arrays = [gt, gt_phased, gt_mask]
 
-        with cf.ThreadPoolExecutor(max_workers=4) as executor:
-            futures = []
-            chunk_start = 0
-            j = 0
+        with core.ThreadedZarrEncoder(buffered_arrays, encoder_threads) as te:
             last_bytes_read = 0
             for value, bytes_read in source_col.iter_values_bytes():
+                j = te.next_buffer_row()
                 sanitise_value_int_2d(gt.buff, j, value[:, :-1])
                 sanitise_value_int_1d(gt_phased.buff, j, value[:, -1])
                 # TODO check is this the correct semantics when we are padding
                 # with mixed ploidies?
                 gt_mask.buff[j] = gt.buff[j] < 0
 
-                j += 1
-                if j == chunk_length:
-                    flush_futures(futures)
-                    for ba in buffered_arrays:
-                        futures.extend(ba.async_flush(executor, chunk_start))
-                        ba.swap_buffers()
-                    j = 0
-                    chunk_start += chunk_length
                 if last_bytes_read != bytes_read:
                     with progress_counter.get_lock():
                         progress_counter.value += bytes_read - last_bytes_read
                     last_bytes_read = bytes_read
 
-            if j != 0:
-                flush_futures(futures)
-                for ba in buffered_arrays:
-                    futures.extend(ba.async_flush(executor, chunk_start, j))
-            flush_futures(futures)
-
     def encode_alleles(self, pcvcf):
         ref_col = pcvcf.columns["REF"]
         alt_col = pcvcf.columns["ALT"]
@@ -1451,7 +1418,7 @@ def convert_vcf(
     )
 
 
-def encode_bed_partition_genotypes(bed_path, zarr_path, start_variant, end_variant):
+def encode_bed_partition_genotypes(bed_path, zarr_path, start_variant, end_variant, encoder_threads=8):
     bed = bed_reader.open_bed(bed_path, num_threads=1)
 
     store = zarr.DirectoryStore(zarr_path)
@@ -1464,8 +1431,7 @@ def encode_bed_partition_genotypes(bed_path, zarr_path, start_variant, end_varia
 
     buffered_arrays = [gt, gt_phased, gt_mask]
 
-    with cf.ThreadPoolExecutor(max_workers=8) as executor:
-        futures = []
+    with core.ThreadedZarrEncoder(buffered_arrays, encoder_threads) as te:
 
         start = start_variant
        while start < end_variant:
@@ -1474,7 +1440,8 @@ def encode_bed_partition_genotypes(bed_path, zarr_path, start_variant, end_varia
            # Note could do this without iterating over rows, but it's a bit
            # simpler and the bottleneck is in the encoding step anyway. It's
            # also nice to have updates on the progress monitor.
-            for j, values in enumerate(bed_chunk):
+            for values in bed_chunk:
+                j = te.next_buffer_row()
                 dest = gt.buff[j]
                 dest[values == -127] = -1
                 dest[values == 2] = 1
@@ -1483,14 +1450,7 @@ def encode_bed_partition_genotypes(bed_path, zarr_path, start_variant, end_varia
                 gt_mask.buff[j] = dest == -1
                 with progress_counter.get_lock():
                     progress_counter.value += 1
-
-            assert j <= chunk_length
-            flush_futures(futures)
-            for ba in buffered_arrays:
-                ba.async_flush(extend, start, j)
-                ba.swap_buffers()
             start = stop
-        flush_futures(futures)
 
 
 def validate(vcf_path, zarr_path, show_progress=False):

tests/test_vcf.py

Lines changed: 2 additions & 0 deletions
@@ -247,6 +247,8 @@ def test_chunk_size(
         out = tmp_path / "example.vcf.zarr"
         vcf.convert_vcf([path], out, chunk_length=chunk_length, chunk_width=chunk_width)
         ds2 = sg.load_dataset(out)
+        # print(ds2.call_genotype.values)
+        # print(ds.call_genotype.values)
         xt.assert_equal(ds, ds2)
         assert ds2.call_DP.chunks == (y_chunks, x_chunks)
         assert ds2.call_GQ.chunks == (y_chunks, x_chunks)
