Merged
10 changes: 3 additions & 7 deletions bio2zarr/cli.py
@@ -221,7 +221,6 @@ def show_work_summary(work_summary, json):
@compressor
@progress
@worker_processes
@local_alleles
def explode(
vcfs,
icf_path,
@@ -231,7 +230,6 @@ def explode(
compressor,
progress,
worker_processes,
local_alleles,
):
"""
Convert VCF(s) to intermediate columnar format
@@ -245,7 +243,6 @@ def explode(
column_chunk_size=column_chunk_size,
compressor=get_compressor(compressor),
show_progress=progress,
local_alleles=local_alleles,
)


@@ -260,7 +257,6 @@ def explode(
@verbose
@progress
@worker_processes
@local_alleles
def dexplode_init(
vcfs,
icf_path,
@@ -272,7 +268,6 @@ def dexplode_init(
verbose,
progress,
worker_processes,
local_alleles,
):
"""
Initial step for distributed conversion of VCF(s) to intermediate columnar format
@@ -289,7 +284,6 @@ def dexplode_init(
worker_processes=worker_processes,
compressor=get_compressor(compressor),
show_progress=progress,
local_alleles=local_alleles,
)
show_work_summary(work_summary, json)

@@ -340,7 +334,8 @@ def inspect(path, verbose):
@icf_path
@variants_chunk_size
@samples_chunk_size
def mkschema(icf_path, variants_chunk_size, samples_chunk_size):
@local_alleles
def mkschema(icf_path, variants_chunk_size, samples_chunk_size, local_alleles):
"""
Generate a schema for zarr encoding
"""
@@ -350,6 +345,7 @@ def mkschema(icf_path, variants_chunk_size, samples_chunk_size):
stream,
variants_chunk_size=variants_chunk_size,
samples_chunk_size=samples_chunk_size,
local_alleles=local_alleles,
)


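For context, the `local_alleles` option now applies to schema generation rather than to `explode`/`dexplode-init`. A rough sketch of the equivalent library call, assuming the CLI wiring above delegates to a `vcf2zarr.mkschema` function whose keyword arguments mirror the options shown (the exact entry point and its defaults are defined in the vcf2zarr package, not in this diff):

```python
import sys

from bio2zarr import vcf2zarr

# Hypothetical call mirroring the updated CLI wiring: local_alleles is now a
# schema-generation option rather than an explode option.
vcf2zarr.mkschema(
    "sample.icf",
    sys.stdout,
    variants_chunk_size=None,
    samples_chunk_size=None,
    local_alleles=True,
)
```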
21 changes: 21 additions & 0 deletions bio2zarr/core.py
@@ -63,6 +63,27 @@ def chunk_aligned_slices(z, n, max_chunks=None):
return slices


def first_dim_slice_iter(z, start, stop):
"""
Efficiently iterate over the specified slice of the first dimension of the zarr
array z.
"""
chunk_size = z.chunks[0]
first_chunk = start // chunk_size
last_chunk = (stop // chunk_size) + (stop % chunk_size != 0)
for chunk in range(first_chunk, last_chunk):
Z = z.blocks[chunk]
chunk_start = chunk * chunk_size
chunk_stop = chunk_start + chunk_size
slice_start = None
if start > chunk_start:
slice_start = start - chunk_start
slice_stop = None
if stop < chunk_stop:
slice_stop = stop - chunk_start
yield from Z[slice_start:slice_stop]


def du(path):
"""
Return the total bytes stored at this path.
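The new `first_dim_slice_iter` helper above reads only the chunks that overlap the requested range instead of materialising the whole slice at once. A minimal usage sketch, assuming a zarr version that supports block indexing via `z.blocks` (which the helper itself relies on); the toy array and values are illustrative only:

```python
import numpy as np
import zarr

from bio2zarr import core

# Toy 1-D array with chunk size 2 along the first dimension.
z = zarr.array(np.arange(10), chunks=(2,))

# Iterates over elements 3, 4, 5, 6, touching only the chunks overlapping [3, 7).
values = list(core.first_dim_slice_iter(z, 3, 7))
assert values == [3, 4, 5, 6]
```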
175 changes: 3 additions & 172 deletions bio2zarr/vcf2zarr/icf.py
@@ -217,7 +217,7 @@ def make_field_def(name, vcf_type, vcf_number):
return fields


def scan_vcf(path, target_num_partitions, *, local_alleles):
def scan_vcf(path, target_num_partitions):
with vcf_utils.IndexedVcf(path) as indexed_vcf:
vcf = indexed_vcf.vcf
filters = []
@@ -237,10 +237,6 @@ def scan_vcf(path, target_num_partitions, *, local_alleles):
pass_filter = filters.pop(pass_index)
filters.insert(0, pass_filter)

# Indicates whether vcf2zarr can introduce local alleles
can_localize = False
should_add_laa_field = True
should_add_lpl_field = True
fields = fixed_vcf_field_definitions()
for h in vcf.header_iter():
if h["HeaderType"] in ["INFO", "FORMAT"]:
@@ -249,36 +245,6 @@ def scan_vcf(path, target_num_partitions, *, local_alleles):
field.vcf_type = "Integer"
field.vcf_number = "."
fields.append(field)
if field.category == "FORMAT":
if field.name == "PL":
can_localize = True
if field.name == "LAA":
should_add_laa_field = False
if field.name == "LPL":
should_add_lpl_field = False

if local_alleles and can_localize:
if should_add_laa_field:
laa_field = VcfField(
category="FORMAT",
name="LAA",
vcf_type="Integer",
vcf_number=".",
description="1-based indices into ALT, indicating which alleles"
" are relevant (local) for the current sample",
summary=VcfFieldSummary(),
)
fields.append(laa_field)
if should_add_lpl_field:
lpl_field = VcfField(
category="FORMAT",
name="LPL",
vcf_type="Integer",
vcf_number="LG",
description="Local-allele representation of PL",
summary=VcfFieldSummary(),
)
fields.append(lpl_field)

try:
contig_lengths = vcf.seqlens
@@ -315,14 +281,7 @@ def scan_vcf(path, target_num_partitions, *, local_alleles):
return metadata, vcf.raw_header


def scan_vcfs(
paths,
show_progress,
target_num_partitions,
worker_processes=1,
*,
local_alleles,
):
def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
logger.info(
f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}"
f" partitions."
@@ -346,7 +305,6 @@ def scan_vcfs(
scan_vcf,
path,
max(1, target_num_partitions // len(paths)),
local_alleles=local_alleles,
)
results = list(pwm.results_as_completed())

@@ -505,104 +463,6 @@ def sanitise_value_int_2d(buff, j, value):
buff[j, :, : value.shape[1]] = value


def compute_laa_field(variant) -> np.ndarray:
"""
Computes the value of the LAA field for each sample given a variant.

The LAA field is a list of one-based indices into the ALT alleles
that indicates which alternate alleles are observed in the sample.

This method infers which alleles are observed from the GT field.
"""
sample_count = variant.num_called + variant.num_unknown
alt_allele_count = len(variant.ALT)
allele_count = alt_allele_count + 1
allele_counts = np.zeros((sample_count, allele_count), dtype=int)

if "GT" in variant.FORMAT:
# The last element of each sample's genotype indicates the phasing
# and is not an allele.
genotypes = variant.genotype.array()[:, :-1]
genotypes.clip(0, None, out=genotypes)
genotype_allele_counts = np.apply_along_axis(
np.bincount, axis=1, arr=genotypes, minlength=allele_count
)
allele_counts += genotype_allele_counts

allele_counts[:, 0] = 0 # We don't count the reference allele
max_row_length = 1

def nonzero_pad(arr: np.ndarray, *, length: int):
nonlocal max_row_length
alleles = arr.nonzero()[0]
max_row_length = max(max_row_length, len(alleles))
pad_length = length - len(alleles)
return np.pad(
alleles,
(0, pad_length),
mode="constant",
constant_values=constants.INT_FILL,
)

alleles = np.apply_along_axis(
nonzero_pad, axis=1, arr=allele_counts, length=max(1, alt_allele_count)
)
alleles = alleles[:, :max_row_length]

return alleles


def compute_lpl_field(variant, laa_val: np.ndarray) -> np.ndarray:
assert laa_val is not None

la_val = np.zeros((laa_val.shape[0], laa_val.shape[1] + 1), dtype=laa_val.dtype)
la_val[:, 1:] = laa_val
ploidy = variant.ploidy

if "PL" not in variant.FORMAT:
sample_count = variant.num_called + variant.num_unknown
local_allele_count = la_val.shape[1]

if ploidy == 1:
local_genotype_count = local_allele_count
elif ploidy == 2:
local_genotype_count = local_allele_count * (local_allele_count + 1) // 2
else:
raise ValueError(f"Cannot handle ploidy = {ploidy}")

return np.full((sample_count, local_genotype_count), constants.INT_MISSING)

# Compute a and b
if ploidy == 1:
a = la_val
b = np.zeros_like(la_val)
elif ploidy == 2:
repeats = np.arange(1, la_val.shape[1] + 1)
b = np.repeat(la_val, repeats, axis=1)
arange_tile = np.tile(np.arange(la_val.shape[1]), (la_val.shape[1], 1))
tril_indices = np.tril_indices_from(arange_tile)
a_index = np.tile(arange_tile[tril_indices], (b.shape[0], 1))
row_index = np.arange(la_val.shape[0]).reshape(-1, 1)
a = la_val[row_index, a_index]
else:
raise ValueError(f"Cannot handle ploidy = {ploidy}")

# Compute n, the local indices of the PL field
n = (b * (b + 1) / 2 + a).astype(int)

pl_val = variant.format("PL")
pl_val[pl_val == constants.VCF_INT_MISSING] = constants.INT_MISSING
# When the PL value is missing in all samples, pl_val has shape (sample_count, 1).
# In that case, we need to broadcast the PL value.
if pl_val.shape[1] < n.shape[1]:
pl_val = np.broadcast_to(pl_val, n.shape)
row_index = np.arange(pl_val.shape[0]).reshape(-1, 1)
lpl_val = pl_val[row_index, n]
lpl_val[b == constants.INT_FILL] = constants.INT_FILL

return lpl_val


missing_value_map = {
"Integer": constants.INT_MISSING,
"Float": constants.FLOAT32_MISSING,
@@ -1107,14 +967,11 @@ def init(
target_num_partitions=None,
show_progress=False,
compressor=None,
local_alleles=None,
):
if self.path.exists():
raise ValueError(f"ICF path already exists: {self.path}")
if compressor is None:
compressor = ICF_DEFAULT_COMPRESSOR
if local_alleles is None:
local_alleles = False
vcfs = [pathlib.Path(vcf) for vcf in vcfs]
target_num_partitions = max(target_num_partitions, len(vcfs))

@@ -1124,7 +981,6 @@ def init(
worker_processes=worker_processes,
show_progress=show_progress,
target_num_partitions=target_num_partitions,
local_alleles=local_alleles,
)
check_field_clobbering(icf_metadata)
self.metadata = icf_metadata
@@ -1207,17 +1063,6 @@ def process_partition(self, partition_index):
else:
format_fields.append(field)

format_field_names = [format_field.name for format_field in format_fields]
if "LAA" in format_field_names and "LPL" in format_field_names:
laa_index = format_field_names.index("LAA")
lpl_index = format_field_names.index("LPL")
# LAA needs to come before LPL
if lpl_index < laa_index:
format_fields[laa_index], format_fields[lpl_index] = (
format_fields[lpl_index],
format_fields[laa_index],
)

last_position = None
with IcfPartitionWriter(
self.metadata,
@@ -1245,18 +1090,8 @@ def process_partition(self, partition_index):
else:
val = variant.genotype.array()
tcw.append("FORMAT/GT", val)
laa_val = None
for field in format_fields:
if field.name == "LAA":
if "LAA" not in variant.FORMAT:
laa_val = compute_laa_field(variant)
else:
laa_val = variant.format("LAA")
val = laa_val
elif field.name == "LPL" and "LPL" not in variant.FORMAT:
val = compute_lpl_field(variant, laa_val)
else:
val = variant.format(field.name)
val = variant.format(field.name)
tcw.append(field.full_name, val)

# Note: an issue with updating the progress per variant here like
@@ -1352,7 +1187,6 @@ def explode(
worker_processes=1,
show_progress=False,
compressor=None,
local_alleles=None,
):
writer = IntermediateColumnarFormatWriter(icf_path)
writer.init(
Expand All @@ -1363,7 +1197,6 @@ def explode(
show_progress=show_progress,
column_chunk_size=column_chunk_size,
compressor=compressor,
local_alleles=local_alleles,
)
writer.explode(worker_processes=worker_processes, show_progress=show_progress)
writer.finalise()
@@ -1379,7 +1212,6 @@ def explode_init(
worker_processes=1,
show_progress=False,
compressor=None,
local_alleles=None,
):
writer = IntermediateColumnarFormatWriter(icf_path)
return writer.init(
Expand All @@ -1389,7 +1221,6 @@ def explode_init(
show_progress=show_progress,
column_chunk_size=column_chunk_size,
compressor=compressor,
local_alleles=local_alleles,
)


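For reference, the removed `compute_lpl_field` relied on the VCF specification's ordering of diploid genotype likelihoods, in which the PL entry for the unordered genotype a/b (with a <= b) sits at index b*(b+1)/2 + a; that is the `n = (b * (b + 1) / 2 + a)` step above. A small, self-contained illustration of that ordering, separate from the bio2zarr API:

```python
import numpy as np


def diploid_pl_index(a, b):
    # Index of genotype a/b in the VCF PL ordering: 0/0, 0/1, 1/1, 0/2, 1/2, 2/2, ...
    a, b = np.minimum(a, b), np.maximum(a, b)
    return b * (b + 1) // 2 + a


# With one REF and two ALT alleles, the six PL entries are laid out as
# 0/0, 0/1, 1/1, 0/2, 1/2, 2/2.
assert diploid_pl_index(0, 0) == 0
assert diploid_pl_index(1, 0) == 1
assert diploid_pl_index(1, 2) == 4
assert diploid_pl_index(2, 2) == 5
```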