Skip to content

Commit 6c96407

Browse files
Initial refactor of the encode path
1 parent 8acf0d4 commit 6c96407

File tree

1 file changed

+64
-85
lines changed

1 file changed

+64
-85
lines changed

bio2zarr/vcf.py

Lines changed: 64 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,11 +1284,11 @@ def summary_table(self):
12841284
return data
12851285

12861286

1287-
# TODO refactor this into a VcfZarrWriter class, and get rid of the
1288-
# static methods.
1289-
class SgvcfZarr:
1290-
def __init__(self, path):
1287+
class VcfZarrWriter:
1288+
def __init__(self, path, pcvcf, schema):
12911289
self.path = pathlib.Path(path)
1290+
self.pcvcf = pcvcf
1291+
self.schema = schema
12921292
self.root = None
12931293

12941294
def create_array(self, variable):
@@ -1306,8 +1306,8 @@ def create_array(self, variable):
13061306
)
13071307
a.attrs["_ARRAY_DIMENSIONS"] = variable.dimensions
13081308

1309-
def encode_column_slice(self, pcvcf, column, start, stop):
1310-
source_col = pcvcf.columns[column.vcf_field]
1309+
def encode_column_slice(self, column, start, stop):
1310+
source_col = self.pcvcf.columns[column.vcf_field]
13111311
array = self.root[column.name]
13121312
ba = core.BufferedArray(array, start)
13131313
sanitiser = source_col.sanitiser_factory(ba.buff.shape)
@@ -1320,8 +1320,8 @@ def encode_column_slice(self, pcvcf, column, start, stop):
13201320
ba.flush()
13211321
logger.debug(f"Encoded {column.name} slice {start}:{stop}")
13221322

1323-
def encode_genotypes_slice(self, pcvcf, start, stop):
1324-
source_col = pcvcf.columns["FORMAT/GT"]
1323+
def encode_genotypes_slice(self, start, stop):
1324+
source_col = self.pcvcf.columns["FORMAT/GT"]
13251325
gt = core.BufferedArray(self.root["call_genotype"], start)
13261326
gt_mask = core.BufferedArray(self.root["call_genotype_mask"], start)
13271327
gt_phased = core.BufferedArray(self.root["call_genotype_phased"], start)
@@ -1340,9 +1340,9 @@ def encode_genotypes_slice(self, pcvcf, start, stop):
13401340
gt_mask.flush()
13411341
logger.debug(f"Encoded GT slice {start}:{stop}")
13421342

1343-
def encode_alleles_slice(self, pcvcf, start, stop):
1344-
ref_col = pcvcf.columns["REF"]
1345-
alt_col = pcvcf.columns["ALT"]
1343+
def encode_alleles_slice(self, start, stop):
1344+
ref_col = self.pcvcf.columns["REF"]
1345+
alt_col = self.pcvcf.columns["ALT"]
13461346
alleles = core.BufferedArray(self.root["variant_allele"], start)
13471347

13481348
for ref, alt in zip(
@@ -1355,8 +1355,8 @@ def encode_alleles_slice(self, pcvcf, start, stop):
13551355
alleles.flush()
13561356
logger.debug(f"Encoded alleles slice {start}:{stop}")
13571357

1358-
def encode_id_slice(self, pcvcf, start, stop):
1359-
col = pcvcf.columns["ID"]
1358+
def encode_id_slice(self, start, stop):
1359+
col = self.pcvcf.columns["ID"]
13601360
vid = core.BufferedArray(self.root["variant_id"], start)
13611361
vid_mask = core.BufferedArray(self.root["variant_id_mask"], start)
13621362

@@ -1374,8 +1374,8 @@ def encode_id_slice(self, pcvcf, start, stop):
13741374
vid_mask.flush()
13751375
logger.debug(f"Encoded ID slice {start}:{stop}")
13761376

1377-
def encode_filters_slice(self, pcvcf, lookup, start, stop):
1378-
col = pcvcf.columns["FILTERS"]
1377+
def encode_filters_slice(self, lookup, start, stop):
1378+
col = self.pcvcf.columns["FILTERS"]
13791379
var_filter = core.BufferedArray(self.root["variant_filter"], start)
13801380

13811381
for value in col.iter_values(start, stop):
@@ -1389,8 +1389,8 @@ def encode_filters_slice(self, pcvcf, lookup, start, stop):
13891389
var_filter.flush()
13901390
logger.debug(f"Encoded FILTERS slice {start}:{stop}")
13911391

1392-
def encode_contig_slice(self, pcvcf, lookup, start, stop):
1393-
col = pcvcf.columns["CHROM"]
1392+
def encode_contig_slice(self, lookup, start, stop):
1393+
col = self.pcvcf.columns["CHROM"]
13941394
contig = core.BufferedArray(self.root["variant_contig"], start)
13951395

13961396
for value in col.iter_values(start, stop):
@@ -1403,162 +1403,144 @@ def encode_contig_slice(self, pcvcf, lookup, start, stop):
14031403
contig.flush()
14041404
logger.debug(f"Encoded CHROM slice {start}:{stop}")
14051405

1406-
def encode_samples(self, pcvcf, sample_id, chunk_width):
1407-
if not np.array_equal(sample_id, pcvcf.metadata.samples):
1406+
def encode_samples(self):
1407+
if not np.array_equal(self.schema.sample_id, self.pcvcf.metadata.samples):
14081408
raise ValueError("Subsetting or reordering samples not supported currently")
14091409
array = self.root.array(
14101410
"sample_id",
1411-
sample_id,
1411+
self.schema.sample_id,
14121412
dtype="str",
14131413
compressor=core.default_compressor,
1414-
chunks=(chunk_width,),
1414+
chunks=(self.schema.chunk_width,),
14151415
)
14161416
array.attrs["_ARRAY_DIMENSIONS"] = ["samples"]
14171417
logger.debug("Samples done")
14181418

1419-
def encode_contig_id(self, pcvcf, contig_names, contig_lengths):
1419+
def encode_contig_id(self):
14201420
array = self.root.array(
14211421
"contig_id",
1422-
contig_names,
1422+
self.schema.contig_id,
14231423
dtype="str",
14241424
compressor=core.default_compressor,
14251425
)
14261426
array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
1427-
if contig_lengths is not None:
1427+
if self.schema.contig_length is not None:
14281428
array = self.root.array(
14291429
"contig_length",
1430-
contig_lengths,
1430+
self.schema.contig_length,
14311431
dtype=np.int64,
14321432
)
14331433
array.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
1434-
return {v: j for j, v in enumerate(contig_names)}
1434+
return {v: j for j, v in enumerate(self.schema.contig_id)}
14351435

1436-
def encode_filter_id(self, pcvcf, filter_names):
1436+
def encode_filter_id(self):
14371437
array = self.root.array(
14381438
"filter_id",
1439-
filter_names,
1439+
self.schema.filter_id,
14401440
dtype="str",
14411441
compressor=core.default_compressor,
14421442
)
14431443
array.attrs["_ARRAY_DIMENSIONS"] = ["filters"]
1444-
return {v: j for j, v in enumerate(filter_names)}
1444+
return {v: j for j, v in enumerate(self.schema.filter_id)}
14451445

1446-
@staticmethod
14471446
def encode(
1448-
pcvcf,
1449-
path,
1450-
conversion_spec,
1451-
*,
1447+
self,
14521448
worker_processes=1,
14531449
max_v_chunks=None,
14541450
show_progress=False,
14551451
):
1456-
path = pathlib.Path(path)
14571452
# TODO: we should do this as a future to avoid blocking
1458-
if path.exists():
1453+
if self.path.exists():
14591454
logger.warning(f"Deleting existing {path}")
1460-
shutil.rmtree(path)
1461-
write_path = path.with_suffix(path.suffix + f".{os.getpid()}.build")
1455+
shutil.rmtree(self.path)
1456+
write_path = self.path.with_suffix(self.path.suffix + f".{os.getpid()}.build")
14621457
store = zarr.DirectoryStore(write_path)
1463-
# FIXME, duplicating logic about the store
14641458
logger.info(f"Create zarr at {write_path}")
1465-
sgvcf = SgvcfZarr(write_path)
1466-
sgvcf.root = zarr.group(store=store, overwrite=True)
1467-
for column in conversion_spec.columns.values():
1468-
sgvcf.create_array(column)
1459+
self.root = zarr.group(store=store, overwrite=True)
1460+
for column in self.schema.columns.values():
1461+
self.create_array(column)
14691462

1470-
sgvcf.root.attrs["vcf_zarr_version"] = "0.2"
1471-
sgvcf.root.attrs["vcf_header"] = pcvcf.vcf_header
1472-
sgvcf.root.attrs["source"] = f"bio2zarr-{provenance.__version__}"
1463+
self.root.attrs["vcf_zarr_version"] = "0.2"
1464+
self.root.attrs["vcf_header"] = self.pcvcf.vcf_header
1465+
self.root.attrs["source"] = f"bio2zarr-{provenance.__version__}"
14731466

14741467
num_slices = max(1, worker_processes * 4)
14751468
# Using POS arbitrarily to get the array slices
14761469
slices = core.chunk_aligned_slices(
1477-
sgvcf.root["variant_position"], num_slices, max_chunks=max_v_chunks
1470+
self.root["variant_position"], num_slices, max_chunks=max_v_chunks
14781471
)
14791472
truncated = slices[-1][-1]
1480-
for array in sgvcf.root.values():
1473+
for array in self.root.values():
14811474
if array.attrs["_ARRAY_DIMENSIONS"][0] == "variants":
14821475
shape = list(array.shape)
14831476
shape[0] = truncated
14841477
array.resize(shape)
14851478

14861479
chunked_1d = [
1487-
col for col in conversion_spec.columns.values() if len(col.chunks) <= 1
1480+
col for col in self.schema.columns.values() if len(col.chunks) <= 1
14881481
]
14891482
progress_config = core.ProgressConfig(
1490-
total=sum(sgvcf.root[col.name].nchunks for col in chunked_1d),
1483+
total=sum(self.root[col.name].nchunks for col in chunked_1d),
14911484
title="Encode 1D",
14921485
units="chunks",
14931486
show=show_progress,
14941487
)
14951488

14961489
# Do these synchronously for simplicity so we have the mapping
1497-
filter_id_map = sgvcf.encode_filter_id(pcvcf, conversion_spec.filter_id)
1498-
contig_id_map = sgvcf.encode_contig_id(
1499-
pcvcf, conversion_spec.contig_id, conversion_spec.contig_length
1500-
)
1490+
filter_id_map = self.encode_filter_id()
1491+
contig_id_map = self.encode_contig_id()
15011492

15021493
with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
1503-
pwm.submit(
1504-
sgvcf.encode_samples,
1505-
pcvcf,
1506-
conversion_spec.sample_id,
1507-
conversion_spec.chunk_width,
1508-
)
1494+
pwm.submit(self.encode_samples)
15091495
for start, stop in slices:
1510-
pwm.submit(sgvcf.encode_alleles_slice, pcvcf, start, stop)
1511-
pwm.submit(sgvcf.encode_id_slice, pcvcf, start, stop)
1512-
pwm.submit(
1513-
sgvcf.encode_filters_slice, pcvcf, filter_id_map, start, stop
1514-
)
1515-
pwm.submit(sgvcf.encode_contig_slice, pcvcf, contig_id_map, start, stop)
1496+
pwm.submit(self.encode_alleles_slice, start, stop)
1497+
pwm.submit(self.encode_id_slice, start, stop)
1498+
pwm.submit(self.encode_filters_slice, filter_id_map, start, stop)
1499+
pwm.submit(self.encode_contig_slice, contig_id_map, start, stop)
15161500
for col in chunked_1d:
15171501
if col.vcf_field is not None:
1518-
pwm.submit(sgvcf.encode_column_slice, pcvcf, col, start, stop)
1502+
pwm.submit(self.encode_column_slice, col, start, stop)
15191503

15201504
chunked_2d = [
1521-
col for col in conversion_spec.columns.values() if len(col.chunks) >= 2
1505+
col for col in self.schema.columns.values() if len(col.chunks) >= 2
15221506
]
15231507
if len(chunked_2d) > 0:
15241508
progress_config = core.ProgressConfig(
1525-
total=sum(sgvcf.root[col.name].nchunks for col in chunked_2d),
1509+
total=sum(self.root[col.name].nchunks for col in chunked_2d),
15261510
title="Encode 2D",
15271511
units="chunks",
15281512
show=show_progress,
15291513
)
15301514
with core.ParallelWorkManager(worker_processes, progress_config) as pwm:
1531-
if "call_genotype" in conversion_spec.columns:
1515+
if "call_genotype" in self.schema.columns:
15321516
arrays = [
1533-
sgvcf.root["call_genotype"],
1534-
sgvcf.root["call_genotype_phased"],
1535-
sgvcf.root["call_genotype_mask"],
1517+
self.root["call_genotype"],
1518+
self.root["call_genotype_phased"],
1519+
self.root["call_genotype_mask"],
15361520
]
15371521
min_mem = sum(array.blocks[0].nbytes for array in arrays)
15381522
logger.info(
15391523
f"Submit encode call_genotypes in {len(slices)} slices. "
15401524
f"Min per-worker mem={display_size(min_mem)}"
15411525
)
15421526
for start, stop in slices:
1543-
pwm.submit(sgvcf.encode_genotypes_slice, pcvcf, start, stop)
1527+
pwm.submit(self.encode_genotypes_slice, start, stop)
15441528

15451529
for col in chunked_2d:
15461530
if col.vcf_field is not None:
1547-
array = sgvcf.root[col.name]
1531+
array = self.root[col.name]
15481532
min_mem = array.blocks[0].nbytes
15491533
logger.info(
15501534
f"Submit encode {col.name} in {len(slices)} slices. "
15511535
f"Min per-worker mem={display_size(min_mem)}"
15521536
)
15531537
for start, stop in slices:
1554-
pwm.submit(
1555-
sgvcf.encode_column_slice, pcvcf, col, start, stop
1556-
)
1538+
pwm.submit(self.encode_column_slice, col, start, stop)
15571539

15581540
zarr.consolidate_metadata(write_path)
15591541
# Atomic swap, now we've completely finished.
1560-
logger.info(f"Moving to final path {path}")
1561-
os.rename(write_path, path)
1542+
logger.info(f"Moving to final path {self.path}")
1543+
os.rename(write_path, self.path)
15621544

15631545

15641546
def mkschema(if_path, out):
@@ -1590,11 +1572,8 @@ def encode(
15901572
raise ValueError("Cannot specify schema along with chunk sizes")
15911573
with open(schema_path, "r") as f:
15921574
schema = ZarrConversionSpec.fromjson(f.read())
1593-
1594-
SgvcfZarr.encode(
1595-
pcvcf,
1596-
zarr_path,
1597-
conversion_spec=schema,
1575+
vzw = VcfZarrWriter(zarr_path, pcvcf, schema)
1576+
vzw.encode(
15981577
max_v_chunks=max_v_chunks,
15991578
worker_processes=worker_processes,
16001579
show_progress=show_progress,

0 commit comments

Comments
 (0)