Skip to content

Commit 39c53f2

Browse files
Merge pull request #11 from jeromekelleher/refactor-basic
Move BufferedArray into core
2 parents 0a72058 + 661c417 commit 39c53f2

File tree

3 files changed

+121
-77
lines changed

3 files changed

+121
-77
lines changed

bio2zarr/core.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import dataclasses
2+
import logging
3+
4+
import zarr
5+
import numpy as np
6+
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
@dataclasses.dataclass
12+
class BufferedArray:
13+
array: zarr.Array
14+
buff: np.ndarray
15+
16+
def __init__(self, array):
17+
self.array = array
18+
dims = list(array.shape)
19+
dims[0] = min(array.chunks[0], array.shape[0])
20+
self.buff = np.zeros(dims, dtype=array.dtype)
21+
22+
def swap_buffers(self):
23+
self.buff = np.zeros_like(self.buff)
24+
25+
def async_flush(self, executor, offset, buff_stop=None):
26+
return async_flush_array(executor, self.buff[:buff_stop], self.array, offset)
27+
28+
29+
def sync_flush_array(np_buffer, zarr_array, offset):
30+
zarr_array[offset : offset + np_buffer.shape[0]] = np_buffer
31+
32+
33+
def async_flush_array(executor, np_buffer, zarr_array, offset):
34+
"""
35+
Flush the specified chunk aligned buffer to the specified zarr array.
36+
"""
37+
logger.debug(f"Schedule flush {zarr_array} @ {offset}")
38+
assert zarr_array.shape[1:] == np_buffer.shape[1:]
39+
# print("sync", zarr_array, np_buffer)
40+
41+
if len(np_buffer.shape) == 1:
42+
futures = [executor.submit(sync_flush_array, np_buffer, zarr_array, offset)]
43+
else:
44+
futures = async_flush_2d_array(executor, np_buffer, zarr_array, offset)
45+
return futures
46+
47+
48+
def async_flush_2d_array(executor, np_buffer, zarr_array, offset):
49+
# Flush each of the chunks in the second dimension separately
50+
s = slice(offset, offset + np_buffer.shape[0])
51+
52+
def flush_chunk(start, stop):
53+
zarr_array[s, start:stop] = np_buffer[:, start:stop]
54+
55+
chunk_width = zarr_array.chunks[1]
56+
zarr_array_width = zarr_array.shape[1]
57+
start = 0
58+
futures = []
59+
while start < zarr_array_width:
60+
stop = min(start + chunk_width, zarr_array_width)
61+
future = executor.submit(flush_chunk, start, stop)
62+
futures.append(future)
63+
start = stop
64+
65+
return futures

bio2zarr/vcf.py

Lines changed: 14 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
import bed_reader
2727

28+
from . import core
29+
2830
logger = logging.getLogger(__name__)
2931

3032
INT_MISSING = -1
@@ -1109,21 +1111,6 @@ def fixed_field_spec(
11091111
)
11101112

11111113

1112-
@dataclasses.dataclass
1113-
class BufferedArray:
1114-
array: Any
1115-
buff: Any
1116-
1117-
def __init__(self, array):
1118-
self.array = array
1119-
dims = list(array.shape)
1120-
dims[0] = min(array.chunks[0], array.shape[0])
1121-
self.buff = np.zeros(dims, dtype=array.dtype)
1122-
1123-
def swap_buffers(self):
1124-
self.buff = np.zeros_like(self.buff)
1125-
1126-
11271114
class SgvcfZarr:
11281115
def __init__(self, path):
11291116
self.path = pathlib.Path(path)
@@ -1143,7 +1130,7 @@ def create_array(self, variable):
11431130
def encode_column(self, pcvcf, column):
11441131
source_col = pcvcf.columns[column.vcf_field]
11451132
array = self.root[column.name]
1146-
ba = BufferedArray(array)
1133+
ba = core.BufferedArray(array)
11471134
sanitiser = source_col.sanitiser_factory(ba.buff.shape)
11481135
chunk_length = array.chunks[0]
11491136

@@ -1157,9 +1144,7 @@ def encode_column(self, pcvcf, column):
11571144
j += 1
11581145
if j == chunk_length:
11591146
flush_futures(futures)
1160-
futures.extend(
1161-
async_flush_array(executor, ba.buff, ba.array, chunk_start)
1162-
)
1147+
futures.extend(ba.async_flush(executor, chunk_start))
11631148
ba.swap_buffers()
11641149
j = 0
11651150
chunk_start += chunk_length
@@ -1170,16 +1155,14 @@ def encode_column(self, pcvcf, column):
11701155

11711156
if j != 0:
11721157
flush_futures(futures)
1173-
futures.extend(
1174-
async_flush_array(executor, ba.buff[:j], ba.array, chunk_start)
1175-
)
1158+
futures.extend(ba.async_flush(executor, chunk_start, j))
11761159
flush_futures(futures)
11771160

11781161
def encode_genotypes(self, pcvcf):
11791162
source_col = pcvcf.columns["FORMAT/GT"]
1180-
gt = BufferedArray(self.root["call_genotype"])
1181-
gt_mask = BufferedArray(self.root["call_genotype_mask"])
1182-
gt_phased = BufferedArray(self.root["call_genotype_phased"])
1163+
gt = core.BufferedArray(self.root["call_genotype"])
1164+
gt_mask = core.BufferedArray(self.root["call_genotype_mask"])
1165+
gt_phased = core.BufferedArray(self.root["call_genotype_phased"])
11831166
chunk_length = gt.array.chunks[0]
11841167

11851168
buffered_arrays = [gt, gt_phased, gt_mask]
@@ -1200,9 +1183,7 @@ def encode_genotypes(self, pcvcf):
12001183
if j == chunk_length:
12011184
flush_futures(futures)
12021185
for ba in buffered_arrays:
1203-
futures.extend(
1204-
async_flush_array(executor, ba.buff, ba.array, chunk_start)
1205-
)
1186+
futures.extend(ba.async_flush(executor, chunk_start))
12061187
ba.swap_buffers()
12071188
j = 0
12081189
chunk_start += chunk_length
@@ -1214,9 +1195,7 @@ def encode_genotypes(self, pcvcf):
12141195
if j != 0:
12151196
flush_futures(futures)
12161197
for ba in buffered_arrays:
1217-
futures.extend(
1218-
async_flush_array(executor, ba.buff[:j], ba.array, chunk_start)
1219-
)
1198+
futures.extend(ba.async_flush(executor, chunk_start, j))
12201199
flush_futures(futures)
12211200

12221201
def encode_alleles(self, pcvcf):
@@ -1417,45 +1396,6 @@ def convert(
14171396
os.rename(write_path, path)
14181397

14191398

1420-
def sync_flush_array(np_buffer, zarr_array, offset):
1421-
zarr_array[offset : offset + np_buffer.shape[0]] = np_buffer
1422-
1423-
1424-
def async_flush_array(executor, np_buffer, zarr_array, offset):
1425-
"""
1426-
Flush the specified chunk aligned buffer to the specified zarr array.
1427-
"""
1428-
logger.debug(f"Schedule flush {zarr_array} @ {offset}")
1429-
assert zarr_array.shape[1:] == np_buffer.shape[1:]
1430-
# print("sync", zarr_array, np_buffer)
1431-
1432-
if len(np_buffer.shape) == 1:
1433-
futures = [executor.submit(sync_flush_array, np_buffer, zarr_array, offset)]
1434-
else:
1435-
futures = async_flush_2d_array(executor, np_buffer, zarr_array, offset)
1436-
return futures
1437-
1438-
1439-
def async_flush_2d_array(executor, np_buffer, zarr_array, offset):
1440-
# Flush each of the chunks in the second dimension separately
1441-
s = slice(offset, offset + np_buffer.shape[0])
1442-
1443-
def flush_chunk(start, stop):
1444-
zarr_array[s, start:stop] = np_buffer[:, start:stop]
1445-
1446-
chunk_width = zarr_array.chunks[1]
1447-
zarr_array_width = zarr_array.shape[1]
1448-
start = 0
1449-
futures = []
1450-
while start < zarr_array_width:
1451-
stop = min(start + chunk_width, zarr_array_width)
1452-
future = executor.submit(flush_chunk, start, stop)
1453-
futures.append(future)
1454-
start = stop
1455-
1456-
return futures
1457-
1458-
14591399
def generate_spec(columnarised, out):
14601400
pcvcf = PickleChunkedVcf.load(columnarised)
14611401
spec = ZarrConversionSpec.generate(pcvcf)
@@ -1516,9 +1456,9 @@ def encode_bed_partition_genotypes(bed_path, zarr_path, start_variant, end_varia
15161456

15171457
store = zarr.DirectoryStore(zarr_path)
15181458
root = zarr.group(store=store)
1519-
gt = BufferedArray(root["call_genotype"])
1520-
gt_mask = BufferedArray(root["call_genotype_mask"])
1521-
gt_phased = BufferedArray(root["call_genotype_phased"])
1459+
gt = core.BufferedArray(root["call_genotype"])
1460+
gt_mask = core.BufferedArray(root["call_genotype_mask"])
1461+
gt_phased = core.BufferedArray(root["call_genotype_phased"])
15221462
chunk_length = gt.array.chunks[0]
15231463
assert start_variant % chunk_length == 0
15241464

@@ -1547,9 +1487,7 @@ def encode_bed_partition_genotypes(bed_path, zarr_path, start_variant, end_varia
15471487
assert j <= chunk_length
15481488
flush_futures(futures)
15491489
for ba in buffered_arrays:
1550-
futures.extend(
1551-
async_flush_array(executor, ba.buff[:j], ba.array, start)
1552-
)
1490+
ba.async_flush(extend, start, j)
15531491
ba.swap_buffers()
15541492
start = stop
15551493
flush_futures(futures)

tests/test_vcf.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import xarray.testing as xt
44
import pytest
55
import sgkit as sg
6+
import zarr
67

78
from bio2zarr import vcf
89

910

10-
class TestSmallExampleValues:
11+
class TestSmallExample:
1112
@pytest.fixture(scope="class")
1213
def ds(self, tmp_path_factory):
1314
path = "tests/data/vcf/sample.vcf.gz"
@@ -230,6 +231,46 @@ def test_no_genotypes(self, ds, tmp_path):
230231
if col != "sample_id" and not col.startswith("call_"):
231232
xt.assert_equal(ds[col], ds2[col])
232233

234+
@pytest.mark.parametrize(
235+
["chunk_length", "chunk_width", "y_chunks", "x_chunks"],
236+
[
237+
(1, 1, (1, 1, 1, 1, 1, 1, 1, 1, 1), (1, 1, 1)),
238+
(2, 2, (2, 2, 2, 2, 1), (2, 1)),
239+
(3, 3, (3, 3, 3), (3,)),
240+
(4, 3, (4, 4, 1), (3,)),
241+
],
242+
)
243+
def test_chunk_size(
244+
self, ds, tmp_path, chunk_length, chunk_width, y_chunks, x_chunks
245+
):
246+
path = "tests/data/vcf/sample.vcf.gz"
247+
out = tmp_path / "example.vcf.zarr"
248+
vcf.convert_vcf([path], out, chunk_length=chunk_length, chunk_width=chunk_width)
249+
ds2 = sg.load_dataset(out)
250+
xt.assert_equal(ds, ds2)
251+
assert ds2.call_DP.chunks == (y_chunks, x_chunks)
252+
assert ds2.call_GQ.chunks == (y_chunks, x_chunks)
253+
assert ds2.call_HQ.chunks == (y_chunks, x_chunks, (2,))
254+
assert ds2.call_genotype.chunks == (y_chunks, x_chunks, (2,))
255+
assert ds2.call_genotype_mask.chunks == (y_chunks, x_chunks, (2,))
256+
assert ds2.call_genotype_phased.chunks == (y_chunks, x_chunks)
257+
assert ds2.variant_AA.chunks == (y_chunks,)
258+
assert ds2.variant_AC.chunks == (y_chunks, (2,))
259+
assert ds2.variant_AF.chunks == (y_chunks, (2,))
260+
assert ds2.variant_DB.chunks == (y_chunks,)
261+
assert ds2.variant_DP.chunks == (y_chunks,)
262+
assert ds2.variant_NS.chunks == (y_chunks,)
263+
assert ds2.variant_allele.chunks == (y_chunks, (4,))
264+
assert ds2.variant_contig.chunks == (y_chunks,)
265+
assert ds2.variant_filter.chunks == (y_chunks, (3,))
266+
assert ds2.variant_id.chunks == (y_chunks,)
267+
assert ds2.variant_id_mask.chunks == (y_chunks,)
268+
assert ds2.variant_position.chunks == (y_chunks,)
269+
assert ds2.variant_quality.chunks == (y_chunks,)
270+
assert ds2.contig_id.chunks == ((3,),)
271+
assert ds2.filter_id.chunks == ((3,),)
272+
assert ds2.sample_id.chunks == (x_chunks,)
273+
233274

234275
@pytest.mark.parametrize(
235276
"name",

0 commit comments

Comments
 (0)