Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions bio2zarr/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,16 @@ def cancel_futures(futures):
class BufferedArray:
array: zarr.Array
array_offset: int
name: str
buff: np.ndarray
buffer_row: int
max_buff_size: int = 0

def __init__(self, array, offset):
def __init__(self, array, offset, name="Unknown"):
self.array = array
self.array_offset = offset
assert offset % array.chunks[0] == 0
self.name = name
dims = list(array.shape)
dims[0] = min(array.chunks[0], array.shape[0])
self.buff = np.empty(dims, dtype=array.dtype)
Expand Down Expand Up @@ -171,11 +174,17 @@ def flush(self):
self.buff[: self.buffer_row], self.array, self.array_offset
)
logger.debug(
f"Flushed <{self.array.name} {self.array.shape} "
f"Flushed <{self.name} {self.array.shape} "
f"{self.array.dtype}> "
f"{self.array_offset}:{self.array_offset + self.buffer_row}"
f"{self.buff.nbytes / 2**20: .2f}Mb"
)
# Note this is inaccurate for string data as we're just reporting the
# size of the container. When we switch the numpy 2 StringDtype this
# should improve and we can get more visibility on how memory
# is being used.
# https://github.com/sgkit-dev/bio2zarr/issues/30
self.max_buff_size = max(self.max_buff_size, self.buff.nbytes)
self.array_offset += self.variants_chunk_size
self.buffer_row = 0

Expand Down
101 changes: 42 additions & 59 deletions bio2zarr/vcf2zarr/vcz.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,23 +844,32 @@ def encode_partition(self, partition_index):
os.rename(partition_path, final_path)

def init_partition_array(self, partition_index, name):
field_map = self.schema.field_map()
array_spec = field_map[name]
# Create an empty array like the definition
src = self.arrays_path / name
src = self.arrays_path / array_spec.name
# Overwrite any existing WIP files
wip_path = self.wip_partition_array_path(partition_index, name)
wip_path = self.wip_partition_array_path(partition_index, array_spec.name)
shutil.copytree(src, wip_path, dirs_exist_ok=True)
array = zarr.open_array(store=wip_path, mode="a")
logger.debug(f"Opened empty array {array.name} <{array.dtype}> @ {wip_path}")
return array
partition = self.metadata.partitions[partition_index]
ba = core.BufferedArray(array, partition.start, name)
logger.info(
f"Start partition {partition_index} array {name} <{array.dtype}> "
f"{array.shape} @ {wip_path}"
)
return ba

def finalise_partition_array(self, partition_index, name):
logger.debug(f"Encoded {name} partition {partition_index}")
def finalise_partition_array(self, partition_index, buffered_array):
buffered_array.flush()
logger.info(
f"Completed partition {partition_index} array {buffered_array.name} "
f"max_memory={core.display_size(buffered_array.max_buff_size)}"
)

def encode_array_partition(self, array_spec, partition_index):
array = self.init_partition_array(partition_index, array_spec.name)

partition = self.metadata.partitions[partition_index]
ba = core.BufferedArray(array, partition.start)
ba = self.init_partition_array(partition_index, array_spec.name)
source_field = self.icf.fields[array_spec.vcf_field]
sanitiser = source_field.sanitiser_factory(ba.buff.shape)

Expand All @@ -869,20 +878,16 @@ def encode_array_partition(self, array_spec, partition_index):
# to make it easier to reason about dimension padding
j = ba.next_buffer_row()
sanitiser(ba.buff, j, value)
ba.flush()
self.finalise_partition_array(partition_index, array_spec.name)
self.finalise_partition_array(partition_index, ba)

def encode_genotypes_partition(self, partition_index):
gt_array = self.init_partition_array(partition_index, "call_genotype")
gt_mask_array = self.init_partition_array(partition_index, "call_genotype_mask")
gt_phased_array = self.init_partition_array(
partition_index, "call_genotype_phased"
)
# FIXME we should be doing these one at a time, reading back in the genotypes
# like we do for local alleles
gt = self.init_partition_array(partition_index, "call_genotype")
gt_mask = self.init_partition_array(partition_index, "call_genotype_mask")
gt_phased = self.init_partition_array(partition_index, "call_genotype_phased")

partition = self.metadata.partitions[partition_index]
gt = core.BufferedArray(gt_array, partition.start)
gt_mask = core.BufferedArray(gt_mask_array, partition.start)
gt_phased = core.BufferedArray(gt_phased_array, partition.start)

source_field = self.icf.fields["FORMAT/GT"]
for value in source_field.iter_values(partition.start, partition.stop):
Expand All @@ -898,18 +903,14 @@ def encode_genotypes_partition(self, partition_index):
# with mixed ploidies?
j = gt_mask.next_buffer_row()
gt_mask.buff[j] = gt.buff[j] < 0
gt.flush()
gt_phased.flush()
gt_mask.flush()

self.finalise_partition_array(partition_index, "call_genotype")
self.finalise_partition_array(partition_index, "call_genotype_mask")
self.finalise_partition_array(partition_index, "call_genotype_phased")
self.finalise_partition_array(partition_index, gt)
self.finalise_partition_array(partition_index, gt_phased)
self.finalise_partition_array(partition_index, gt_mask)

def encode_local_alleles_partition(self, partition_index):
partition = self.metadata.partitions[partition_index]
call_LA_array = self.init_partition_array(partition_index, "call_LA")
call_LA = core.BufferedArray(call_LA_array, partition.start)
call_LA = self.init_partition_array(partition_index, "call_LA")

gt_array = zarr.open_array(
store=self.wip_partition_array_path(partition_index, "call_genotype"),
Expand All @@ -921,26 +922,23 @@ def encode_local_alleles_partition(self, partition_index):
la = compute_la_field(genotypes)
j = call_LA.next_buffer_row()
call_LA.buff[j] = la

call_LA.flush()
self.finalise_partition_array(partition_index, "call_LA")
self.finalise_partition_array(partition_index, call_LA)

def encode_local_allele_fields_partition(self, partition_index):
partition = self.metadata.partitions[partition_index]
la_array = zarr.open_array(
store=self.wip_partition_array_path(partition_index, "call_LA"),
mode="r",
)
field_map = self.schema.field_map()
# We got through the localisable fields one-by-one so that we don't need to
# keep several large arrays in memory at once for each partition.
field_map = self.schema.field_map()
for descriptor in localisable_fields:
if descriptor.array_name not in field_map:
continue
assert field_map[descriptor.array_name].vcf_field is None

array = self.init_partition_array(partition_index, descriptor.array_name)
buff = core.BufferedArray(array, partition.start)
buff = self.init_partition_array(partition_index, descriptor.array_name)
source = self.icf.fields[descriptor.vcf_field].iter_values(
partition.start, partition.stop
)
Expand All @@ -951,14 +949,11 @@ def encode_local_allele_fields_partition(self, partition_index):
value = descriptor.sanitise(raw_value, 2, raw_value.dtype)
j = buff.next_buffer_row()
buff.buff[j] = descriptor.convert(value, la)
buff.flush()
self.finalise_partition_array(partition_index, "array_name")
self.finalise_partition_array(partition_index, buff)

def encode_alleles_partition(self, partition_index):
array_name = "variant_allele"
alleles_array = self.init_partition_array(partition_index, array_name)
alleles = self.init_partition_array(partition_index, "variant_allele")
partition = self.metadata.partitions[partition_index]
alleles = core.BufferedArray(alleles_array, partition.start)
ref_field = self.icf.fields["REF"]
alt_field = self.icf.fields["ALT"]

Expand All @@ -970,16 +965,12 @@ def encode_alleles_partition(self, partition_index):
alleles.buff[j, :] = constants.STR_FILL
alleles.buff[j, 0] = ref[0]
alleles.buff[j, 1 : 1 + len(alt)] = alt
alleles.flush()

self.finalise_partition_array(partition_index, array_name)
self.finalise_partition_array(partition_index, alleles)

def encode_id_partition(self, partition_index):
vid_array = self.init_partition_array(partition_index, "variant_id")
vid_mask_array = self.init_partition_array(partition_index, "variant_id_mask")
vid = self.init_partition_array(partition_index, "variant_id")
vid_mask = self.init_partition_array(partition_index, "variant_id_mask")
partition = self.metadata.partitions[partition_index]
vid = core.BufferedArray(vid_array, partition.start)
vid_mask = core.BufferedArray(vid_mask_array, partition.start)
field = self.icf.fields["ID"]

for value in field.iter_values(partition.start, partition.stop):
Expand All @@ -992,18 +983,14 @@ def encode_id_partition(self, partition_index):
else:
vid.buff[j] = constants.STR_MISSING
vid_mask.buff[j] = True
vid.flush()
vid_mask.flush()

self.finalise_partition_array(partition_index, "variant_id")
self.finalise_partition_array(partition_index, "variant_id_mask")
self.finalise_partition_array(partition_index, vid)
self.finalise_partition_array(partition_index, vid_mask)

def encode_filters_partition(self, partition_index):
lookup = {filt.id: index for index, filt in enumerate(self.schema.filters)}
array_name = "variant_filter"
array = self.init_partition_array(partition_index, array_name)
var_filter = self.init_partition_array(partition_index, "variant_filter")
partition = self.metadata.partitions[partition_index]
var_filter = core.BufferedArray(array, partition.start)

field = self.icf.fields["FILTERS"]
for value in field.iter_values(partition.start, partition.stop):
Expand All @@ -1016,16 +1003,13 @@ def encode_filters_partition(self, partition_index):
raise ValueError(
f"Filter '{f}' was not defined in the header."
) from None
var_filter.flush()

self.finalise_partition_array(partition_index, array_name)
self.finalise_partition_array(partition_index, var_filter)

def encode_contig_partition(self, partition_index):
lookup = {contig.id: index for index, contig in enumerate(self.schema.contigs)}
array_name = "variant_contig"
array = self.init_partition_array(partition_index, array_name)
contig = self.init_partition_array(partition_index, "variant_contig")
partition = self.metadata.partitions[partition_index]
contig = core.BufferedArray(array, partition.start)
field = self.icf.fields["CHROM"]

for value in field.iter_values(partition.start, partition.stop):
Expand All @@ -1035,9 +1019,8 @@ def encode_contig_partition(self, partition_index):
# will always succeed. However, if anyone ever does hit a KeyError
# here, please do open an issue with a reproducible example!
contig.buff[j] = lookup[value[0]]
contig.flush()

self.finalise_partition_array(partition_index, array_name)
self.finalise_partition_array(partition_index, contig)

#######################
# finalise
Expand Down
Loading