Commit 1420b07

Change progress units to chunks

1 parent b10479e

4 files changed: +140 -54 lines changed


bio2zarr/cli.py

Lines changed: 24 additions & 3 deletions
@@ -14,6 +14,11 @@
     "-p", "--worker-processes", type=int, default=1, help="Number of worker processes"
 )

+# TODO help text
+chunk_length = click.option("-l", "--chunk-length", type=int, default=None)
+
+chunk_width = click.option("-w", "--chunk-width", type=int, default=None)
+
 version = click.version_option(version=provenance.__version__)


@@ -79,8 +84,13 @@ def mkschema(if_path):
 @click.argument("zarr_path", type=click.Path())
 @verbose
 @click.option("-s", "--schema", default=None)
+# TODO: these are mutually exclusive with schema, tell click this
+@chunk_length
+@chunk_width
 @worker_processes
-def encode(if_path, zarr_path, verbose, schema, worker_processes):
+def encode(
+    if_path, zarr_path, verbose, schema, chunk_length, chunk_width, worker_processes
+):
     """
     Encode intermediate format (see explode) to vcfzarr
     """
@@ -89,6 +99,8 @@ def encode(if_path, zarr_path, verbose, schema, worker_processes):
         if_path,
         zarr_path,
         schema,
+        chunk_length=chunk_length,
+        chunk_width=chunk_width,
         worker_processes=worker_processes,
         show_progress=True,
     )
@@ -97,14 +109,23 @@ def encode(if_path, zarr_path, verbose, schema, worker_processes):
 @click.command(name="convert")
 @click.argument("vcfs", nargs=-1, required=True)
 @click.argument("out_path", type=click.Path())
+@chunk_length
+@chunk_width
 @verbose
 @worker_processes
-def convert_vcf(vcfs, out_path, verbose, worker_processes):
+def convert_vcf(vcfs, out_path, chunk_length, chunk_width, verbose, worker_processes):
     """
     Convert input VCF(s) directly to vcfzarr (not recommended for large files)
     """
     setup_logging(verbose)
-    vcf.convert(vcfs, out_path, show_progress=True, worker_processes=worker_processes)
+    vcf.convert(
+        vcfs,
+        out_path,
+        chunk_length=chunk_length,
+        chunk_width=chunk_width,
+        show_progress=True,
+        worker_processes=worker_processes,
+    )


 @click.command
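
The new options are ordinary click options threaded through to vcf.encode and vcf.convert as keyword arguments. One way to address the "mutually exclusive with schema" TODO above, sketched here as a hypothetical helper rather than anything this commit adds, is an explicit check at the top of the command body:

import click

def check_chunk_args(schema, chunk_length, chunk_width):
    # Hypothetical validation helper (not part of this commit): a schema
    # already fixes the chunk sizes, so explicit overrides are rejected.
    if schema is not None and not (chunk_length is None and chunk_width is None):
        raise click.UsageError(
            "--chunk-length/--chunk-width cannot be combined with --schema"
        )

click can also enforce this declaratively (for example via an option callback), but a plain check like this is the simplest starting point.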

bio2zarr/core.py

Lines changed: 39 additions & 9 deletions
@@ -90,18 +90,38 @@ def next_buffer_row(self):
     def flush(self):
         # TODO just move sync_flush_array in here
         if self.buffer_row != 0:
-            sync_flush_array(
-                self.buff[: self.buffer_row], self.array, self.array_offset
-            )
+            if len(self.array.chunks) <= 1:
+                sync_flush_1d_array(
+                    self.buff[: self.buffer_row], self.array, self.array_offset
+                )
+            else:
+                sync_flush_2d_array(
+                    self.buff[: self.buffer_row], self.array, self.array_offset
+                )
+            logger.debug(
+                f"Flushed chunk {self.array} {self.array_offset} + {self.buffer_row}")
         self.array_offset += self.chunk_length
         self.buffer_row = 0


-# TODO: factor these functions into the BufferedArray class
+def sync_flush_1d_array(np_buffer, zarr_array, offset):
+    zarr_array[offset : offset + np_buffer.shape[0]] = np_buffer
+    update_progress(1)


-def sync_flush_array(np_buffer, zarr_array, offset):
-    zarr_array[offset : offset + np_buffer.shape[0]] = np_buffer
+def sync_flush_2d_array(np_buffer, zarr_array, offset):
+    # Write chunks in the second dimension 1-by-1 to make progress more
+    # incremental, and to avoid large memcopies in the underlying
+    # encoder implementations.
+    s = slice(offset, offset + np_buffer.shape[0])
+    chunk_width = zarr_array.chunks[1]
+    zarr_array_width = zarr_array.shape[1]
+    start = 0
+    while start < zarr_array_width:
+        stop = min(start + chunk_width, zarr_array_width)
+        zarr_array[s, start:stop] = np_buffer[:, start:stop]
+        update_progress(1)
+        start = stop


 @dataclasses.dataclass
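
With the flush path split this way, one progress unit now corresponds to one chunk written: a 1D flush counts once, and a 2D flush counts once per column chunk. A minimal sketch of how many units a full pass over an array should therefore produce, assuming each row chunk is flushed exactly once (the helper name is illustrative, not from the codebase):

import math

def expected_chunk_units(shape, chunks):
    # Illustrative only: expected number of update_progress(1) calls for
    # one full pass over an array, given the flush logic above.
    row_chunks = math.ceil(shape[0] / chunks[0])
    if len(chunks) <= 1:
        return row_chunks  # sync_flush_1d_array: one unit per flush
    col_chunks = math.ceil(shape[1] / chunks[1])
    return row_chunks * col_chunks  # sync_flush_2d_array: one unit per column chunk

# e.g. shape=(10_000, 2_500) with chunks=(1_000, 100) -> 10 * 25 = 250 units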
@@ -113,6 +133,10 @@ class ProgressConfig:
     poll_interval: float = 0.001


+# NOTE: this approach means that we cannot have more than one
+# progressable thing happening per source process. This is
+# probably fine in practise, but there could be corner cases
+# where it's not. Something to watch out for.
 _progress_counter = multiprocessing.Value("Q", 0)

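
The counter is a single process-shared multiprocessing.Value, which is why only one progress-tracked operation can run per source process, as the NOTE says. The update_progress, get_progress, and set_progress helpers used throughout this diff are defined elsewhere in core.py; a reconstruction of the standard pattern they follow, for illustration only:

import multiprocessing

_progress_counter = multiprocessing.Value("Q", 0)  # as in the diff above

def update_progress(inc):
    # Bump the shared counter under its lock so increments from
    # multiple worker processes are not lost.
    with _progress_counter.get_lock():
        _progress_counter.value += inc

def get_progress():
    return _progress_counter.value

def set_progress(value):
    with _progress_counter.get_lock():
        _progress_counter.value = value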

@@ -146,8 +170,14 @@ def progress_thread_worker(config):
         inc = current - pbar.n
         pbar.update(inc)
         time.sleep(config.poll_interval)
-    # inc = config.total - pbar.n
-    # pbar.update(inc)
+    # TODO figure out why we're sometimes going over total
+    # if get_progress() != config.total:
+    #     print("HOW DID THIS HAPPEN!!")
+    #     print(get_progress())
+    #     print(config)
+    # assert get_progress() == config.total
+    inc = config.total - pbar.n
+    pbar.update(inc)
     pbar.close()
     # print("EXITING PROGRESS THREAD")
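
The progress thread polls the shared counter and mirrors it into a tqdm bar; the new unconditional update at the end snaps the bar to the configured total on shutdown, which papers over the "going over total" TODO rather than resolving it. A standalone sketch of the pattern, with names of my own choosing:

import time
from tqdm import tqdm

def poll_progress(get_progress, total, poll_interval, done):
    # `done` is e.g. a threading.Event set once all work has been flushed.
    pbar = tqdm(total=total)
    while not done.is_set():
        pbar.update(get_progress() - pbar.n)  # mirror the shared counter
        time.sleep(poll_interval)
    pbar.update(total - pbar.n)  # snap to 100% on exit, as the diff now does
    pbar.close()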

@@ -187,7 +217,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
             # Note: this doesn't seem to be working correctly. If
             # we set a timeout of None we get deadlocks
             set_progress(self.progress_config.total)
-            timeout = 1
+            timeout = None
         else:
             cancel_futures(self.futures)
             timeout = 0
