Skip to content

Commit 95a265f

Browse files
Merge pull request #55 from jeromekelleher/slice-pcvcf
Slice pcvcf
2 parents caa3398 + 6cdbe8d commit 95a265f

File tree

6 files changed

+508
-337
lines changed

6 files changed

+508
-337
lines changed

bio2zarr/cli.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
"-p", "--worker-processes", type=int, default=1, help="Number of worker processes"
1515
)
1616

17+
# TODO help text
18+
chunk_length = click.option("-l", "--chunk-length", type=int, default=None)
19+
20+
chunk_width = click.option("-w", "--chunk-width", type=int, default=None)
21+
1722
version = click.version_option(version=provenance.__version__)
1823

1924

@@ -79,8 +84,13 @@ def mkschema(if_path):
7984
@click.argument("zarr_path", type=click.Path())
8085
@verbose
8186
@click.option("-s", "--schema", default=None)
87+
# TODO: these are mutually exclusive with schema, tell click this
88+
@chunk_length
89+
@chunk_width
8290
@worker_processes
83-
def encode(if_path, zarr_path, verbose, schema, worker_processes):
91+
def encode(
92+
if_path, zarr_path, verbose, schema, chunk_length, chunk_width, worker_processes
93+
):
8494
"""
8595
Encode intermediate format (see explode) to vcfzarr
8696
"""
@@ -89,6 +99,8 @@ def encode(if_path, zarr_path, verbose, schema, worker_processes):
8999
if_path,
90100
zarr_path,
91101
schema,
102+
chunk_length=chunk_length,
103+
chunk_width=chunk_width,
92104
worker_processes=worker_processes,
93105
show_progress=True,
94106
)
@@ -97,14 +109,23 @@ def encode(if_path, zarr_path, verbose, schema, worker_processes):
97109
@click.command(name="convert")
98110
@click.argument("vcfs", nargs=-1, required=True)
99111
@click.argument("out_path", type=click.Path())
112+
@chunk_length
113+
@chunk_width
100114
@verbose
101115
@worker_processes
102-
def convert_vcf(vcfs, out_path, verbose, worker_processes):
116+
def convert_vcf(vcfs, out_path, chunk_length, chunk_width, verbose, worker_processes):
103117
"""
104118
Convert input VCF(s) directly to vcfzarr (not recommended for large files)
105119
"""
106120
setup_logging(verbose)
107-
vcf.convert(vcfs, out_path, show_progress=True, worker_processes=worker_processes)
121+
vcf.convert(
122+
vcfs,
123+
out_path,
124+
chunk_length=chunk_length,
125+
chunk_width=chunk_width,
126+
show_progress=True,
127+
worker_processes=worker_processes,
128+
)
108129

109130

110131
@click.command

bio2zarr/core.py

Lines changed: 67 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,23 @@
2424
)
2525

2626

27+
def chunk_aligned_slices(z, n):
28+
"""
29+
Returns at n slices in the specified zarr array, aligned
30+
with its chunks
31+
"""
32+
chunk_size = z.chunks[0]
33+
num_chunks = int(np.ceil(z.shape[0] / chunk_size))
34+
slices = []
35+
splits = np.array_split(np.arange(num_chunks), min(n, num_chunks))
36+
for split in splits:
37+
start = split[0] * chunk_size
38+
stop = (split[-1] + 1) * chunk_size
39+
stop = min(stop, z.shape[0])
40+
slices.append((start, stop))
41+
return slices
42+
43+
2744
class SynchronousExecutor(cf.Executor):
2845
def submit(self, fn, /, *args, **kwargs):
2946
future = cf.Future()
@@ -46,107 +63,66 @@ def cancel_futures(futures):
4663
@dataclasses.dataclass
4764
class BufferedArray:
4865
array: zarr.Array
66+
array_offset: int
4967
buff: np.ndarray
68+
buffer_row: int
5069

51-
def __init__(self, array):
70+
def __init__(self, array, offset):
5271
self.array = array
72+
self.array_offset = offset
73+
assert offset % array.chunks[0] == 0
5374
dims = list(array.shape)
5475
dims[0] = min(array.chunks[0], array.shape[0])
5576
self.buff = np.zeros(dims, dtype=array.dtype)
77+
self.buffer_row = 0
5678

5779
@property
5880
def chunk_length(self):
5981
return self.buff.shape[0]
6082

61-
def swap_buffers(self):
62-
self.buff = np.zeros_like(self.buff)
63-
64-
def async_flush(self, executor, offset, buff_stop=None):
65-
return async_flush_array(executor, self.buff[:buff_stop], self.array, offset)
66-
67-
68-
# TODO: factor these functions into the BufferedArray class
83+
def next_buffer_row(self):
84+
if self.buffer_row == self.chunk_length:
85+
self.flush()
86+
row = self.buffer_row
87+
self.buffer_row += 1
88+
return row
89+
90+
def flush(self):
91+
# TODO just move sync_flush_array in here
92+
if self.buffer_row != 0:
93+
if len(self.array.chunks) <= 1:
94+
sync_flush_1d_array(
95+
self.buff[: self.buffer_row], self.array, self.array_offset
96+
)
97+
else:
98+
sync_flush_2d_array(
99+
self.buff[: self.buffer_row], self.array, self.array_offset
100+
)
101+
logger.debug(
102+
f"Flushed chunk {self.array} {self.array_offset} + {self.buffer_row}")
103+
self.array_offset += self.chunk_length
104+
self.buffer_row = 0
69105

70106

71-
def sync_flush_array(np_buffer, zarr_array, offset):
107+
def sync_flush_1d_array(np_buffer, zarr_array, offset):
72108
zarr_array[offset : offset + np_buffer.shape[0]] = np_buffer
109+
update_progress(1)
73110

74111

75-
def async_flush_array(executor, np_buffer, zarr_array, offset):
76-
"""
77-
Flush the specified chunk aligned buffer to the specified zarr array.
78-
"""
79-
logger.debug(f"Schedule flush {zarr_array} @ {offset}")
80-
assert zarr_array.shape[1:] == np_buffer.shape[1:]
81-
# print("sync", zarr_array, np_buffer)
82-
83-
if len(np_buffer.shape) == 1:
84-
futures = [executor.submit(sync_flush_array, np_buffer, zarr_array, offset)]
85-
else:
86-
futures = async_flush_2d_array(executor, np_buffer, zarr_array, offset)
87-
return futures
88-
89-
90-
def async_flush_2d_array(executor, np_buffer, zarr_array, offset):
91-
# Flush each of the chunks in the second dimension separately
112+
def sync_flush_2d_array(np_buffer, zarr_array, offset):
113+
# Write chunks in the second dimension 1-by-1 to make progress more
114+
# incremental, and to avoid large memcopies in the underlying
115+
# encoder implementations.
92116
s = slice(offset, offset + np_buffer.shape[0])
93-
94-
def flush_chunk(start, stop):
95-
zarr_array[s, start:stop] = np_buffer[:, start:stop]
96-
97117
chunk_width = zarr_array.chunks[1]
98118
zarr_array_width = zarr_array.shape[1]
99119
start = 0
100-
futures = []
101120
while start < zarr_array_width:
102121
stop = min(start + chunk_width, zarr_array_width)
103-
future = executor.submit(flush_chunk, start, stop)
104-
futures.append(future)
122+
zarr_array[s, start:stop] = np_buffer[:, start:stop]
123+
update_progress(1)
105124
start = stop
106125

107-
return futures
108-
109-
110-
class ThreadedZarrEncoder(contextlib.AbstractContextManager):
111-
# TODO (maybe) add option with encoder_threads=None to run synchronously for
112-
# debugging using a mock Executor
113-
def __init__(self, buffered_arrays, encoder_threads=1):
114-
self.buffered_arrays = buffered_arrays
115-
self.executor = cf.ThreadPoolExecutor(max_workers=encoder_threads)
116-
self.chunk_length = buffered_arrays[0].chunk_length
117-
assert all(ba.chunk_length == self.chunk_length for ba in self.buffered_arrays)
118-
self.futures = []
119-
self.array_offset = 0
120-
self.next_row = -1
121-
122-
def next_buffer_row(self):
123-
self.next_row += 1
124-
if self.next_row == self.chunk_length:
125-
self.swap_buffers()
126-
self.array_offset += self.chunk_length
127-
self.next_row = 0
128-
return self.next_row
129-
130-
def swap_buffers(self):
131-
wait_on_futures(self.futures)
132-
self.futures = []
133-
for ba in self.buffered_arrays:
134-
self.futures.extend(
135-
ba.async_flush(self.executor, self.array_offset, self.next_row)
136-
)
137-
ba.swap_buffers()
138-
139-
def __exit__(self, exc_type, exc_val, exc_tb):
140-
if exc_type is None:
141-
# Normal exit condition
142-
self.next_row += 1
143-
self.swap_buffers()
144-
wait_on_futures(self.futures)
145-
else:
146-
cancel_futures(self.futures)
147-
self.executor.shutdown()
148-
return False
149-
150126

151127
@dataclasses.dataclass
152128
class ProgressConfig:
@@ -157,6 +133,10 @@ class ProgressConfig:
157133
poll_interval: float = 0.001
158134

159135

136+
# NOTE: this approach means that we cannot have more than one
137+
# progressable thing happening per source process. This is
138+
# probably fine in practice, but there could be corner cases
139+
# where it's not. Something to watch out for.
160140
_progress_counter = multiprocessing.Value("Q", 0)
161141

162142

@@ -190,7 +170,16 @@ def progress_thread_worker(config):
190170
inc = current - pbar.n
191171
pbar.update(inc)
192172
time.sleep(config.poll_interval)
173+
# TODO figure out why we're sometimes going over total
174+
# if get_progress() != config.total:
175+
# print("HOW DID THIS HAPPEN!!")
176+
# print(get_progress())
177+
# print(config)
178+
# assert get_progress() == config.total
179+
inc = config.total - pbar.n
180+
pbar.update(inc)
193181
pbar.close()
182+
# print("EXITING PROGRESS THREAD")
194183

195184

196185
class ParallelWorkManager(contextlib.AbstractContextManager):
@@ -228,7 +217,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
228217
# Note: this doesn't seem to be working correctly. If
229218
# we set a timeout of None we get deadlocks
230219
set_progress(self.progress_config.total)
231-
timeout = 1
220+
timeout = None
232221
else:
233222
cancel_futures(self.futures)
234223
timeout = 0

0 commit comments

Comments
 (0)