Skip to content

Commit 2439c70

Browse files
Change encode to a single progress monitor
1 parent 099101e commit 2439c70

File tree

2 files changed

+50
-57
lines changed

2 files changed

+50
-57
lines changed

bio2zarr/core.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -118,7 +118,7 @@ def flush(self):
118118

119119
def sync_flush_1d_array(np_buffer, zarr_array, offset):
120120
zarr_array[offset : offset + np_buffer.shape[0]] = np_buffer
121-
update_progress(1)
121+
update_progress(np_buffer.nbytes)
122122

123123

124124
def sync_flush_2d_array(np_buffer, zarr_array, offset):
@@ -127,12 +127,15 @@ def sync_flush_2d_array(np_buffer, zarr_array, offset):
127127
# encoder implementations.
128128
s = slice(offset, offset + np_buffer.shape[0])
129129
chunk_width = zarr_array.chunks[1]
130+
# TODO use zarr chunks here to support non-uniform chunking later
131+
# and for simplicity
130132
zarr_array_width = zarr_array.shape[1]
131133
start = 0
132134
while start < zarr_array_width:
133135
stop = min(start + chunk_width, zarr_array_width)
134-
zarr_array[s, start:stop] = np_buffer[:, start:stop]
135-
update_progress(1)
136+
chunk_buffer = np_buffer[:, start:stop]
137+
zarr_array[s, start:stop] = chunk_buffer
138+
update_progress(chunk_buffer.nbytes)
136139
start = stop
137140

138141

bio2zarr/vcf.py

Lines changed: 44 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -1284,6 +1284,13 @@ def summary_table(self):
12841284
return data
12851285

12861286

1287+
@dataclasses.dataclass
1288+
class EncodingWork:
1289+
func: callable
1290+
start: int
1291+
stop: int
1292+
1293+
12871294
class VcfZarrWriter:
12881295
def __init__(self, path, pcvcf, schema):
12891296
self.path = pathlib.Path(path)
@@ -1484,66 +1491,49 @@ def encode(
14841491
shape[0] = truncated
14851492
array.resize(shape)
14861493

1487-
chunked_1d = [
1488-
col for col in self.schema.columns.values() if len(col.chunks) <= 1
1489-
]
1490-
progress_config = core.ProgressConfig(
1491-
total=sum(self.get_array(col.name).nchunks for col in chunked_1d),
1492-
title="Encode 1D",
1493-
units="chunks",
1494-
show=show_progress,
1495-
)
1494+
total_bytes = 0
1495+
for col in self.schema.columns.values():
1496+
array = self.get_array(col.name)
1497+
total_bytes += array.nbytes
14961498

1497-
# Do these syncronously for simplicity so we have the mapping
14981499
filter_id_map = self.encode_filter_id()
14991500
contig_id_map = self.encode_contig_id()
15001501

1502+
work = []
1503+
for start, stop in slices:
1504+
for col in self.schema.columns.values():
1505+
if col.vcf_field is not None:
1506+
f = functools.partial(self.encode_array_slice, col)
1507+
work.append(EncodingWork(f, start, stop))
1508+
work.append(EncodingWork(self.encode_alleles_slice, start, stop))
1509+
work.append(EncodingWork(self.encode_id_slice, start, stop))
1510+
work.append(
1511+
EncodingWork(
1512+
functools.partial(self.encode_filters_slice, filter_id_map),
1513+
start,
1514+
stop,
1515+
)
1516+
)
1517+
work.append(
1518+
EncodingWork(
1519+
functools.partial(self.encode_contig_slice, contig_id_map),
1520+
start,
1521+
stop,
1522+
)
1523+
)
1524+
if "call_genotype" in self.schema.columns:
1525+
work.append(EncodingWork(self.encode_genotypes_slice, start, stop))
1526+
1527+
progress_config = core.ProgressConfig(
1528+
total=total_bytes,
1529+
title="Encode",
1530+
units="B",
1531+
show=show_progress,
1532+
)
15011533
with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
15021534
pwm.submit(self.encode_samples)
1503-
for start, stop in slices:
1504-
pwm.submit(self.encode_alleles_slice, start, stop)
1505-
pwm.submit(self.encode_id_slice, start, stop)
1506-
pwm.submit(self.encode_filters_slice, filter_id_map, start, stop)
1507-
pwm.submit(self.encode_contig_slice, contig_id_map, start, stop)
1508-
for col in chunked_1d:
1509-
if col.vcf_field is not None:
1510-
pwm.submit(self.encode_array_slice, col, start, stop)
1511-
1512-
chunked_2d = [
1513-
col for col in self.schema.columns.values() if len(col.chunks) >= 2
1514-
]
1515-
if len(chunked_2d) > 0:
1516-
progress_config = core.ProgressConfig(
1517-
total=sum(self.get_array(col.name).nchunks for col in chunked_2d),
1518-
title="Encode 2D",
1519-
units="chunks",
1520-
show=show_progress,
1521-
)
1522-
with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
1523-
if "call_genotype" in self.schema.columns:
1524-
arrays = [
1525-
self.get_array("call_genotype"),
1526-
self.get_array("call_genotype_phased"),
1527-
self.get_array("call_genotype_mask"),
1528-
]
1529-
min_mem = sum(array.blocks[0].nbytes for array in arrays)
1530-
logger.info(
1531-
f"Submit encode call_genotypes in {len(slices)} slices. "
1532-
f"Min per-worker mem={display_size(min_mem)}"
1533-
)
1534-
for start, stop in slices:
1535-
pwm.submit(self.encode_genotypes_slice, start, stop)
1536-
1537-
for col in chunked_2d:
1538-
if col.vcf_field is not None:
1539-
array = self.get_array(col.name)
1540-
min_mem = array.blocks[0].nbytes
1541-
logger.info(
1542-
f"Submit encode {col.name} in {len(slices)} slices. "
1543-
f"Min per-worker mem={display_size(min_mem)}"
1544-
)
1545-
for start, stop in slices:
1546-
pwm.submit(self.encode_array_slice, col, start, stop)
1535+
for wp in work:
1536+
pwm.submit(wp.func, wp.start, wp.stop)
15471537

15481538

15491539
def mkschema(if_path, out):

0 commit comments

Comments (0)