@@ -1289,15 +1289,16 @@ def __init__(self, path, pcvcf, schema):
1289
1289
self .path = pathlib .Path (path )
1290
1290
self .pcvcf = pcvcf
1291
1291
self .schema = schema
1292
- self .root = None
1292
+ store = zarr .DirectoryStore (self .path )
1293
+ self .root = zarr .group (store = store )
1293
1294
1294
- def create_array (self , variable ):
1295
+ def init_array (self , variable ):
1295
1296
# print("CREATE", variable)
1296
1297
object_codec = None
1297
1298
if variable .dtype == "O" :
1298
1299
object_codec = numcodecs .VLenUTF8 ()
1299
1300
a = self .root .empty (
1300
- variable .name ,
1301
+ "wip_" + variable .name ,
1301
1302
shape = variable .shape ,
1302
1303
chunks = variable .chunks ,
1303
1304
dtype = variable .dtype ,
@@ -1306,9 +1307,19 @@ def create_array(self, variable):
1306
1307
)
1307
1308
a .attrs ["_ARRAY_DIMENSIONS" ] = variable .dimensions
1308
1309
1309
- def encode_column_slice (self , column , start , stop ):
1310
+ def get_array (self , name ):
1311
+ return self .root ["wip_" + name ]
1312
+
1313
+ def finalise_array (self , variable ):
1314
+ source = self .path / ("wip_" + variable .name )
1315
+ dest = self .path / variable .name
1316
+ # Atomic swap
1317
+ os .rename (source , dest )
1318
+ logger .debug (f"Finalised { variable .name } " )
1319
+
1320
+ def encode_array_slice (self , column , start , stop ):
1310
1321
source_col = self .pcvcf .columns [column .vcf_field ]
1311
- array = self .root [ column .name ]
1322
+ array = self .get_array ( column .name )
1312
1323
ba = core .BufferedArray (array , start )
1313
1324
sanitiser = source_col .sanitiser_factory (ba .buff .shape )
1314
1325
@@ -1322,9 +1333,9 @@ def encode_column_slice(self, column, start, stop):
1322
1333
1323
1334
def encode_genotypes_slice (self , start , stop ):
1324
1335
source_col = self .pcvcf .columns ["FORMAT/GT" ]
1325
- gt = core .BufferedArray (self .root [ "call_genotype" ] , start )
1326
- gt_mask = core .BufferedArray (self .root [ "call_genotype_mask" ] , start )
1327
- gt_phased = core .BufferedArray (self .root [ "call_genotype_phased" ] , start )
1336
+ gt = core .BufferedArray (self .get_array ( "call_genotype" ) , start )
1337
+ gt_mask = core .BufferedArray (self .get_array ( "call_genotype_mask" ) , start )
1338
+ gt_phased = core .BufferedArray (self .get_array ( "call_genotype_phased" ) , start )
1328
1339
1329
1340
for value in source_col .iter_values (start , stop ):
1330
1341
j = gt .next_buffer_row ()
@@ -1343,7 +1354,7 @@ def encode_genotypes_slice(self, start, stop):
1343
1354
def encode_alleles_slice (self , start , stop ):
1344
1355
ref_col = self .pcvcf .columns ["REF" ]
1345
1356
alt_col = self .pcvcf .columns ["ALT" ]
1346
- alleles = core .BufferedArray (self .root [ "variant_allele" ] , start )
1357
+ alleles = core .BufferedArray (self .get_array ( "variant_allele" ) , start )
1347
1358
1348
1359
for ref , alt in zip (
1349
1360
ref_col .iter_values (start , stop ), alt_col .iter_values (start , stop )
@@ -1357,8 +1368,8 @@ def encode_alleles_slice(self, start, stop):
1357
1368
1358
1369
def encode_id_slice (self , start , stop ):
1359
1370
col = self .pcvcf .columns ["ID" ]
1360
- vid = core .BufferedArray (self .root [ "variant_id" ] , start )
1361
- vid_mask = core .BufferedArray (self .root [ "variant_id_mask" ] , start )
1371
+ vid = core .BufferedArray (self .get_array ( "variant_id" ) , start )
1372
+ vid_mask = core .BufferedArray (self .get_array ( "variant_id_mask" ) , start )
1362
1373
1363
1374
for value in col .iter_values (start , stop ):
1364
1375
j = vid .next_buffer_row ()
@@ -1376,7 +1387,7 @@ def encode_id_slice(self, start, stop):
1376
1387
1377
1388
def encode_filters_slice (self , lookup , start , stop ):
1378
1389
col = self .pcvcf .columns ["FILTERS" ]
1379
- var_filter = core .BufferedArray (self .root [ "variant_filter" ] , start )
1390
+ var_filter = core .BufferedArray (self .get_array ( "variant_filter" ) , start )
1380
1391
1381
1392
for value in col .iter_values (start , stop ):
1382
1393
j = var_filter .next_buffer_row ()
@@ -1391,7 +1402,7 @@ def encode_filters_slice(self, lookup, start, stop):
1391
1402
1392
1403
def encode_contig_slice (self , lookup , start , stop ):
1393
1404
col = self .pcvcf .columns ["CHROM" ]
1394
- contig = core .BufferedArray (self .root [ "variant_contig" ] , start )
1405
+ contig = core .BufferedArray (self .get_array ( "variant_contig" ) , start )
1395
1406
1396
1407
for value in col .iter_values (start , stop ):
1397
1408
j = contig .next_buffer_row ()
@@ -1443,31 +1454,28 @@ def encode_filter_id(self):
1443
1454
array .attrs ["_ARRAY_DIMENSIONS" ] = ["filters" ]
1444
1455
return {v : j for j , v in enumerate (self .schema .filter_id )}
1445
1456
1457
+ def init (self ):
1458
+ self .root .attrs ["vcf_zarr_version" ] = "0.2"
1459
+ self .root .attrs ["vcf_header" ] = self .pcvcf .vcf_header
1460
+ self .root .attrs ["source" ] = f"bio2zarr-{ provenance .__version__ } "
1461
+ for column in self .schema .columns .values ():
1462
+ self .init_array (column )
1463
+
1464
+ def finalise (self ):
1465
+ for column in self .schema .columns .values ():
1466
+ self .finalise_array (column )
1467
+ zarr .consolidate_metadata (self .path )
1468
+
1446
1469
def encode (
1447
1470
self ,
1448
1471
worker_processes = 1 ,
1449
1472
max_v_chunks = None ,
1450
1473
show_progress = False ,
1451
1474
):
1452
- # TODO: we should do this as a future to avoid blocking
1453
- if self .path .exists ():
1454
- logger .warning (f"Deleting existing { path } " )
1455
- shutil .rmtree (self .path )
1456
- write_path = self .path .with_suffix (self .path .suffix + f".{ os .getpid ()} .build" )
1457
- store = zarr .DirectoryStore (write_path )
1458
- logger .info (f"Create zarr at { write_path } " )
1459
- self .root = zarr .group (store = store , overwrite = True )
1460
- for column in self .schema .columns .values ():
1461
- self .create_array (column )
1462
-
1463
- self .root .attrs ["vcf_zarr_version" ] = "0.2"
1464
- self .root .attrs ["vcf_header" ] = self .pcvcf .vcf_header
1465
- self .root .attrs ["source" ] = f"bio2zarr-{ provenance .__version__ } "
1466
-
1467
1475
num_slices = max (1 , worker_processes * 4 )
1468
1476
# Using POS arbitrarily to get the array slices
1469
1477
slices = core .chunk_aligned_slices (
1470
- self .root [ "variant_position" ] , num_slices , max_chunks = max_v_chunks
1478
+ self .get_array ( "variant_position" ) , num_slices , max_chunks = max_v_chunks
1471
1479
)
1472
1480
truncated = slices [- 1 ][- 1 ]
1473
1481
for array in self .root .values ():
@@ -1480,7 +1488,7 @@ def encode(
1480
1488
col for col in self .schema .columns .values () if len (col .chunks ) <= 1
1481
1489
]
1482
1490
progress_config = core .ProgressConfig (
1483
- total = sum (self .root [ col .name ] .nchunks for col in chunked_1d ),
1491
+ total = sum (self .get_array ( col .name ) .nchunks for col in chunked_1d ),
1484
1492
title = "Encode 1D" ,
1485
1493
units = "chunks" ,
1486
1494
show = show_progress ,
@@ -1499,24 +1507,24 @@ def encode(
1499
1507
pwm .submit (self .encode_contig_slice , contig_id_map , start , stop )
1500
1508
for col in chunked_1d :
1501
1509
if col .vcf_field is not None :
1502
- pwm .submit (self .encode_column_slice , col , start , stop )
1510
+ pwm .submit (self .encode_array_slice , col , start , stop )
1503
1511
1504
1512
chunked_2d = [
1505
1513
col for col in self .schema .columns .values () if len (col .chunks ) >= 2
1506
1514
]
1507
1515
if len (chunked_2d ) > 0 :
1508
1516
progress_config = core .ProgressConfig (
1509
- total = sum (self .root [ col .name ] .nchunks for col in chunked_2d ),
1517
+ total = sum (self .get_array ( col .name ) .nchunks for col in chunked_2d ),
1510
1518
title = "Encode 2D" ,
1511
1519
units = "chunks" ,
1512
1520
show = show_progress ,
1513
1521
)
1514
1522
with core .ParallelWorkManager (worker_processes , progress_config ) as pwm :
1515
1523
if "call_genotype" in self .schema .columns :
1516
1524
arrays = [
1517
- self .root [ "call_genotype" ] ,
1518
- self .root [ "call_genotype_phased" ] ,
1519
- self .root [ "call_genotype_mask" ] ,
1525
+ self .get_array ( "call_genotype" ) ,
1526
+ self .get_array ( "call_genotype_phased" ) ,
1527
+ self .get_array ( "call_genotype_mask" ) ,
1520
1528
]
1521
1529
min_mem = sum (array .blocks [0 ].nbytes for array in arrays )
1522
1530
logger .info (
@@ -1528,19 +1536,14 @@ def encode(
1528
1536
1529
1537
for col in chunked_2d :
1530
1538
if col .vcf_field is not None :
1531
- array = self .root [ col .name ]
1539
+ array = self .get_array ( col .name )
1532
1540
min_mem = array .blocks [0 ].nbytes
1533
1541
logger .info (
1534
1542
f"Submit encode { col .name } in { len (slices )} slices. "
1535
1543
f"Min per-worker mem={ display_size (min_mem )} "
1536
1544
)
1537
1545
for start , stop in slices :
1538
- pwm .submit (self .encode_column_slice , col , start , stop )
1539
-
1540
- zarr .consolidate_metadata (write_path )
1541
- # Atomic swap, now we've completely finished.
1542
- logger .info (f"Moving to final path { self .path } " )
1543
- os .rename (write_path , self .path )
1546
+ pwm .submit (self .encode_array_slice , col , start , stop )
1544
1547
1545
1548
1546
1549
def mkschema (if_path , out ):
@@ -1572,12 +1575,18 @@ def encode(
1572
1575
raise ValueError ("Cannot specify schema along with chunk sizes" )
1573
1576
with open (schema_path , "r" ) as f :
1574
1577
schema = ZarrConversionSpec .fromjson (f .read ())
1578
+ zarr_path = pathlib .Path (zarr_path )
1579
+ if zarr_path .exists ():
1580
+ logger .warning (f"Deleting existing { zarr_path } " )
1581
+ shutil .rmtree (zarr_path )
1575
1582
vzw = VcfZarrWriter (zarr_path , pcvcf , schema )
1583
+ vzw .init ()
1576
1584
vzw .encode (
1577
1585
max_v_chunks = max_v_chunks ,
1578
1586
worker_processes = worker_processes ,
1579
1587
show_progress = show_progress ,
1580
1588
)
1589
+ vzw .finalise ()
1581
1590
1582
1591
1583
1592
def convert (
0 commit comments