Skip to content

Commit 2c31ba2

Browse files
Add max_v_chunks option to encode
Closes #57
1 parent f77642e commit 2c31ba2

File tree

6 files changed

+96
-12
lines changed

6 files changed

+96
-12
lines changed

bio2zarr/cli.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,21 @@
1515
)
1616

1717
# TODO help text
18-
chunk_length = click.option("-l", "--chunk-length", type=int, default=None)
18+
chunk_length = click.option(
19+
"-l",
20+
"--chunk-length",
21+
type=int,
22+
default=None,
23+
help="Chunk size in the variants dimension",
24+
)
1925

20-
chunk_width = click.option("-w", "--chunk-width", type=int, default=None)
26+
chunk_width = click.option(
27+
"-w",
28+
"--chunk-width",
29+
type=int,
30+
default=None,
31+
help="Chunk size in the samples dimension",
32+
)
2133

2234
version = click.version_option(version=provenance.__version__)
2335

@@ -83,13 +95,30 @@ def mkschema(if_path):
8395
@click.argument("if_path", type=click.Path())
8496
@click.argument("zarr_path", type=click.Path())
8597
@verbose
86-
@click.option("-s", "--schema", default=None)
87-
# TODO: these are mutually exclusive with schema, tell click this
98+
@click.option("-s", "--schema", default=None, type=click.Path(exists=True))
8899
@chunk_length
89100
@chunk_width
101+
@click.option(
102+
"-V",
103+
"--max-variant-chunks",
104+
type=int,
105+
default=None,
106+
help=(
107+
"Truncate the output in the variants dimension to have "
108+
"this number of chunks. Mainly intended to help with "
109+
"schema tuning."
110+
),
111+
)
90112
@worker_processes
91113
def encode(
92-
if_path, zarr_path, verbose, schema, chunk_length, chunk_width, worker_processes
114+
if_path,
115+
zarr_path,
116+
verbose,
117+
schema,
118+
chunk_length,
119+
chunk_width,
120+
max_variant_chunks,
121+
worker_processes,
93122
):
94123
"""
95124
Encode intermediate format (see explode) to vcfzarr
@@ -101,6 +130,7 @@ def encode(
101130
schema,
102131
chunk_length=chunk_length,
103132
chunk_width=chunk_width,
133+
max_v_chunks=max_variant_chunks,
104134
worker_processes=worker_processes,
105135
show_progress=True,
106136
)
@@ -132,6 +162,9 @@ def convert_vcf(vcfs, out_path, chunk_length, chunk_width, verbose, worker_proce
132162
@click.argument("vcfs", nargs=-1, required=True)
133163
@click.argument("out_path", type=click.Path())
134164
def validate(vcfs, out_path):
165+
"""
166+
Development only, do not use. Will be removed before release.
167+
"""
135168
# FIXME! Will silently not look at remaining VCFs
136169
vcf.validate(vcfs[0], out_path, show_progress=True)
137170

@@ -158,7 +191,9 @@ def vcf2zarr():
158191
@verbose
159192
@chunk_length
160193
@chunk_width
161-
def convert_plink(in_path, out_path, verbose, worker_processes, chunk_length, chunk_width):
194+
def convert_plink(
195+
in_path, out_path, verbose, worker_processes, chunk_length, chunk_width
196+
):
162197
"""
163198
In development; DO NOT USE!
164199
"""

bio2zarr/core.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,15 @@
2323
)
2424

2525

26-
def chunk_aligned_slices(z, n):
26+
def chunk_aligned_slices(z, n, max_chunks=None):
2727
"""
2828
Returns at most n slices in the specified zarr array, aligned
2929
with its chunks
3030
"""
3131
chunk_size = z.chunks[0]
3232
num_chunks = int(np.ceil(z.shape[0] / chunk_size))
33+
if max_chunks is not None:
34+
num_chunks = min(num_chunks, max_chunks)
3335
slices = []
3436
splits = np.array_split(np.arange(num_chunks), min(n, num_chunks))
3537
for split in splits:
@@ -132,7 +134,7 @@ class ProgressConfig:
132134
units: str = ""
133135
title: str = ""
134136
show: bool = False
135-
poll_interval: float = 0.001
137+
poll_interval: float = 0.01
136138

137139

138140
# NOTE: this approach means that we cannot have more than one
@@ -186,6 +188,7 @@ def __init__(self, worker_processes=1, progress_config=None):
186188
self.progress_thread = threading.Thread(
187189
target=self._update_progress_worker,
188190
name="progress-update",
191+
daemon=True, # Avoids deadlock on exit in awkward error conditions
189192
)
190193
self.progress_thread.start()
191194

bio2zarr/vcf.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,7 +1091,7 @@ def generate(pcvcf, chunk_length=None, chunk_width=None):
10911091
chunk_width = 1000
10921092
if chunk_length is None:
10931093
chunk_length = 10_000
1094-
1094+
logger.info(f"Generating schema with chunks={chunk_length, chunk_width}")
10951095
compressor = core.default_compressor.get_config()
10961096

10971097
def fixed_field_spec(
@@ -1403,7 +1403,13 @@ def encode_filter_id(self, pcvcf, filter_names):
14031403

14041404
@staticmethod
14051405
def encode(
1406-
pcvcf, path, conversion_spec, *, worker_processes=1, show_progress=False
1406+
pcvcf,
1407+
path,
1408+
conversion_spec,
1409+
*,
1410+
worker_processes=1,
1411+
max_v_chunks=None,
1412+
show_progress=False,
14071413
):
14081414
path = pathlib.Path(path)
14091415
# TODO: we should do this as a future to avoid blocking
@@ -1425,7 +1431,15 @@ def encode(
14251431

14261432
num_slices = max(1, worker_processes * 4)
14271433
# Using POS arbitrarily to get the array slices
1428-
slices = core.chunk_aligned_slices(sgvcf.root["variant_position"], num_slices)
1434+
slices = core.chunk_aligned_slices(
1435+
sgvcf.root["variant_position"], num_slices, max_chunks=max_v_chunks
1436+
)
1437+
truncated = slices[-1][-1]
1438+
for array in sgvcf.root.values():
1439+
if array.attrs["_ARRAY_DIMENSIONS"][0] == "variants":
1440+
shape = list(array.shape)
1441+
shape[0] = truncated
1442+
array.resize(shape)
14291443

14301444
chunked_1d = [
14311445
col for col in conversion_spec.columns.values() if len(col.chunks) <= 1
@@ -1503,6 +1517,7 @@ def encode(
15031517
schema_path=None,
15041518
chunk_length=None,
15051519
chunk_width=None,
1520+
max_v_chunks=None,
15061521
worker_processes=1,
15071522
show_progress=False,
15081523
):
@@ -1514,13 +1529,17 @@ def encode(
15141529
chunk_width=chunk_width,
15151530
)
15161531
else:
1517-
# TODO checking that chunk_width and chunk_length are None
1532+
logger.info(f"Reading schema from {schema_path}")
1533+
if chunk_length is not None or chunk_width is not None:
1534+
raise ValueError("Cannot specify schema along with chunk sizes")
15181535
with open(schema_path, "r") as f:
15191536
schema = ZarrConversionSpec.fromjson(f.read())
1537+
15201538
SgvcfZarr.encode(
15211539
pcvcf,
15221540
zarr_path,
15231541
conversion_spec=schema,
1542+
max_v_chunks=max_v_chunks,
15241543
worker_processes=worker_processes,
15251544
show_progress=show_progress,
15261545
)

tests/test_cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def test_encode(self):
6666
None,
6767
chunk_length=None,
6868
chunk_width=None,
69+
max_v_chunks=None,
6970
worker_processes=1,
7071
show_progress=True,
7172
)

tests/test_core.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,22 @@ def test_20_chunk_5(self, n, expected):
7272
result = core.chunk_aligned_slices(z, n)
7373
assert result == expected
7474

75+
@pytest.mark.parametrize(
76+
["n", "max_chunks", "expected"],
77+
[
78+
(1, 5, [(0, 20)]),
79+
(1, 1, [(0, 5)]),
80+
(2, 1, [(0, 5)]),
81+
(3, 1, [(0, 5)]),
82+
(2, 3, [(0, 10), (10, 15)]),
83+
(2, 4, [(0, 10), (10, 20)]),
84+
],
85+
)
86+
def test_20_chunk_5_max_chunks(self, n, max_chunks, expected):
87+
z = zarr.array(np.arange(20), chunks=5, dtype=int)
88+
result = core.chunk_aligned_slices(z, n, max_chunks=max_chunks)
89+
assert result == expected
90+
7591
@pytest.mark.parametrize(
7692
["n", "expected"],
7793
[

tests/test_vcf_examples.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,16 @@ def test_full_pipeline(self, ds, tmp_path, worker_processes):
315315
ds2 = sg.load_dataset(out)
316316
xt.assert_equal(ds, ds2)
317317

318+
@pytest.mark.parametrize("max_v_chunks", [1, 2, 3])
319+
@pytest.mark.parametrize("chunk_length", [1, 2, 3])
320+
def test_max_v_chunks(self, ds, tmp_path, max_v_chunks, chunk_length):
321+
exploded = tmp_path / "example.exploded"
322+
vcf.explode([self.data_path], exploded)
323+
out = tmp_path / "example.zarr"
324+
vcf.encode(exploded, out, chunk_length=chunk_length, max_v_chunks=max_v_chunks)
325+
ds2 = sg.load_dataset(out)
326+
xt.assert_equal(ds.isel(variants=slice(None, chunk_length * max_v_chunks)), ds2)
327+
318328
@pytest.mark.parametrize("worker_processes", [0, 1, 2])
319329
def test_worker_processes(self, ds, tmp_path, worker_processes):
320330
out = tmp_path / "example.vcf.zarr"

0 commit comments

Comments
 (0)