Skip to content

Commit 099101e

Browse files
Add per-array WIP and atomic swap
1 parent 6c96407 commit 099101e

File tree

1 file changed

+51
-42
lines changed

1 file changed

+51
-42
lines changed

bio2zarr/vcf.py

Lines changed: 51 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,15 +1289,16 @@ def __init__(self, path, pcvcf, schema):
12891289
self.path = pathlib.Path(path)
12901290
self.pcvcf = pcvcf
12911291
self.schema = schema
1292-
self.root = None
1292+
store = zarr.DirectoryStore(self.path)
1293+
self.root = zarr.group(store=store)
12931294

1294-
def create_array(self, variable):
1295+
def init_array(self, variable):
12951296
# print("CREATE", variable)
12961297
object_codec = None
12971298
if variable.dtype == "O":
12981299
object_codec = numcodecs.VLenUTF8()
12991300
a = self.root.empty(
1300-
variable.name,
1301+
"wip_" + variable.name,
13011302
shape=variable.shape,
13021303
chunks=variable.chunks,
13031304
dtype=variable.dtype,
@@ -1306,9 +1307,19 @@ def create_array(self, variable):
13061307
)
13071308
a.attrs["_ARRAY_DIMENSIONS"] = variable.dimensions
13081309

1309-
def encode_column_slice(self, column, start, stop):
1310+
def get_array(self, name):
1311+
return self.root["wip_" + name]
1312+
1313+
def finalise_array(self, variable):
1314+
source = self.path / ("wip_" + variable.name)
1315+
dest = self.path / variable.name
1316+
# Atomic swap
1317+
os.rename(source, dest)
1318+
logger.debug(f"Finalised {variable.name}")
1319+
1320+
def encode_array_slice(self, column, start, stop):
13101321
source_col = self.pcvcf.columns[column.vcf_field]
1311-
array = self.root[column.name]
1322+
array = self.get_array(column.name)
13121323
ba = core.BufferedArray(array, start)
13131324
sanitiser = source_col.sanitiser_factory(ba.buff.shape)
13141325

@@ -1322,9 +1333,9 @@ def encode_column_slice(self, column, start, stop):
13221333

13231334
def encode_genotypes_slice(self, start, stop):
13241335
source_col = self.pcvcf.columns["FORMAT/GT"]
1325-
gt = core.BufferedArray(self.root["call_genotype"], start)
1326-
gt_mask = core.BufferedArray(self.root["call_genotype_mask"], start)
1327-
gt_phased = core.BufferedArray(self.root["call_genotype_phased"], start)
1336+
gt = core.BufferedArray(self.get_array("call_genotype"), start)
1337+
gt_mask = core.BufferedArray(self.get_array("call_genotype_mask"), start)
1338+
gt_phased = core.BufferedArray(self.get_array("call_genotype_phased"), start)
13281339

13291340
for value in source_col.iter_values(start, stop):
13301341
j = gt.next_buffer_row()
@@ -1343,7 +1354,7 @@ def encode_genotypes_slice(self, start, stop):
13431354
def encode_alleles_slice(self, start, stop):
13441355
ref_col = self.pcvcf.columns["REF"]
13451356
alt_col = self.pcvcf.columns["ALT"]
1346-
alleles = core.BufferedArray(self.root["variant_allele"], start)
1357+
alleles = core.BufferedArray(self.get_array("variant_allele"), start)
13471358

13481359
for ref, alt in zip(
13491360
ref_col.iter_values(start, stop), alt_col.iter_values(start, stop)
@@ -1357,8 +1368,8 @@ def encode_alleles_slice(self, start, stop):
13571368

13581369
def encode_id_slice(self, start, stop):
13591370
col = self.pcvcf.columns["ID"]
1360-
vid = core.BufferedArray(self.root["variant_id"], start)
1361-
vid_mask = core.BufferedArray(self.root["variant_id_mask"], start)
1371+
vid = core.BufferedArray(self.get_array("variant_id"), start)
1372+
vid_mask = core.BufferedArray(self.get_array("variant_id_mask"), start)
13621373

13631374
for value in col.iter_values(start, stop):
13641375
j = vid.next_buffer_row()
@@ -1376,7 +1387,7 @@ def encode_id_slice(self, start, stop):
13761387

13771388
def encode_filters_slice(self, lookup, start, stop):
13781389
col = self.pcvcf.columns["FILTERS"]
1379-
var_filter = core.BufferedArray(self.root["variant_filter"], start)
1390+
var_filter = core.BufferedArray(self.get_array("variant_filter"), start)
13801391

13811392
for value in col.iter_values(start, stop):
13821393
j = var_filter.next_buffer_row()
@@ -1391,7 +1402,7 @@ def encode_filters_slice(self, lookup, start, stop):
13911402

13921403
def encode_contig_slice(self, lookup, start, stop):
13931404
col = self.pcvcf.columns["CHROM"]
1394-
contig = core.BufferedArray(self.root["variant_contig"], start)
1405+
contig = core.BufferedArray(self.get_array("variant_contig"), start)
13951406

13961407
for value in col.iter_values(start, stop):
13971408
j = contig.next_buffer_row()
@@ -1443,31 +1454,28 @@ def encode_filter_id(self):
14431454
array.attrs["_ARRAY_DIMENSIONS"] = ["filters"]
14441455
return {v: j for j, v in enumerate(self.schema.filter_id)}
14451456

1457+
def init(self):
1458+
self.root.attrs["vcf_zarr_version"] = "0.2"
1459+
self.root.attrs["vcf_header"] = self.pcvcf.vcf_header
1460+
self.root.attrs["source"] = f"bio2zarr-{provenance.__version__}"
1461+
for column in self.schema.columns.values():
1462+
self.init_array(column)
1463+
1464+
def finalise(self):
1465+
for column in self.schema.columns.values():
1466+
self.finalise_array(column)
1467+
zarr.consolidate_metadata(self.path)
1468+
14461469
def encode(
14471470
self,
14481471
worker_processes=1,
14491472
max_v_chunks=None,
14501473
show_progress=False,
14511474
):
1452-
# TODO: we should do this as a future to avoid blocking
1453-
if self.path.exists():
1454-
logger.warning(f"Deleting existing {path}")
1455-
shutil.rmtree(self.path)
1456-
write_path = self.path.with_suffix(self.path.suffix + f".{os.getpid()}.build")
1457-
store = zarr.DirectoryStore(write_path)
1458-
logger.info(f"Create zarr at {write_path}")
1459-
self.root = zarr.group(store=store, overwrite=True)
1460-
for column in self.schema.columns.values():
1461-
self.create_array(column)
1462-
1463-
self.root.attrs["vcf_zarr_version"] = "0.2"
1464-
self.root.attrs["vcf_header"] = self.pcvcf.vcf_header
1465-
self.root.attrs["source"] = f"bio2zarr-{provenance.__version__}"
1466-
14671475
num_slices = max(1, worker_processes * 4)
14681476
# Using POS arbitrarily to get the array slices
14691477
slices = core.chunk_aligned_slices(
1470-
self.root["variant_position"], num_slices, max_chunks=max_v_chunks
1478+
self.get_array("variant_position"), num_slices, max_chunks=max_v_chunks
14711479
)
14721480
truncated = slices[-1][-1]
14731481
for array in self.root.values():
@@ -1480,7 +1488,7 @@ def encode(
14801488
col for col in self.schema.columns.values() if len(col.chunks) <= 1
14811489
]
14821490
progress_config = core.ProgressConfig(
1483-
total=sum(self.root[col.name].nchunks for col in chunked_1d),
1491+
total=sum(self.get_array(col.name).nchunks for col in chunked_1d),
14841492
title="Encode 1D",
14851493
units="chunks",
14861494
show=show_progress,
@@ -1499,24 +1507,24 @@ def encode(
14991507
pwm.submit(self.encode_contig_slice, contig_id_map, start, stop)
15001508
for col in chunked_1d:
15011509
if col.vcf_field is not None:
1502-
pwm.submit(self.encode_column_slice, col, start, stop)
1510+
pwm.submit(self.encode_array_slice, col, start, stop)
15031511

15041512
chunked_2d = [
15051513
col for col in self.schema.columns.values() if len(col.chunks) >= 2
15061514
]
15071515
if len(chunked_2d) > 0:
15081516
progress_config = core.ProgressConfig(
1509-
total=sum(self.root[col.name].nchunks for col in chunked_2d),
1517+
total=sum(self.get_array(col.name).nchunks for col in chunked_2d),
15101518
title="Encode 2D",
15111519
units="chunks",
15121520
show=show_progress,
15131521
)
15141522
with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
15151523
if "call_genotype" in self.schema.columns:
15161524
arrays = [
1517-
self.root["call_genotype"],
1518-
self.root["call_genotype_phased"],
1519-
self.root["call_genotype_mask"],
1525+
self.get_array("call_genotype"),
1526+
self.get_array("call_genotype_phased"),
1527+
self.get_array("call_genotype_mask"),
15201528
]
15211529
min_mem = sum(array.blocks[0].nbytes for array in arrays)
15221530
logger.info(
@@ -1528,19 +1536,14 @@ def encode(
15281536

15291537
for col in chunked_2d:
15301538
if col.vcf_field is not None:
1531-
array = self.root[col.name]
1539+
array = self.get_array(col.name)
15321540
min_mem = array.blocks[0].nbytes
15331541
logger.info(
15341542
f"Submit encode {col.name} in {len(slices)} slices. "
15351543
f"Min per-worker mem={display_size(min_mem)}"
15361544
)
15371545
for start, stop in slices:
1538-
pwm.submit(self.encode_column_slice, col, start, stop)
1539-
1540-
zarr.consolidate_metadata(write_path)
1541-
# Atomic swap, now we've completely finished.
1542-
logger.info(f"Moving to final path {self.path}")
1543-
os.rename(write_path, self.path)
1546+
pwm.submit(self.encode_array_slice, col, start, stop)
15441547

15451548

15461549
def mkschema(if_path, out):
@@ -1572,12 +1575,18 @@ def encode(
15721575
raise ValueError("Cannot specify schema along with chunk sizes")
15731576
with open(schema_path, "r") as f:
15741577
schema = ZarrConversionSpec.fromjson(f.read())
1578+
zarr_path = pathlib.Path(zarr_path)
1579+
if zarr_path.exists():
1580+
logger.warning(f"Deleting existing {zarr_path}")
1581+
shutil.rmtree(zarr_path)
15751582
vzw = VcfZarrWriter(zarr_path, pcvcf, schema)
1583+
vzw.init()
15761584
vzw.encode(
15771585
max_v_chunks=max_v_chunks,
15781586
worker_processes=worker_processes,
15791587
show_progress=show_progress,
15801588
)
1589+
vzw.finalise()
15811590

15821591

15831592
def convert(

0 commit comments

Comments
 (0)