Skip to content

Commit 8c28169

Browse files
Initial pass at computing LAA values
1 parent 9b5e967 commit 8c28169

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

bio2zarr/vcf2zarr/vcz.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,8 @@ def convert_local_allele_field_types(fields):
197197
gt = fields_by_name["call_genotype"]
198198
if gt.shape[-1] != 2:
199199
raise ValueError("Local alleles only supported on diploid data")
200+
# TODO check if LAA is already in here
201+
200202
shape = gt.shape[:-1]
201203
chunks = gt.chunks[:-1]
202204

@@ -214,6 +216,7 @@ def convert_local_allele_field_types(fields):
214216
)
215217
pl = fields_by_name.get("call_PL", None)
216218
if pl is not None:
219+
# TODO check if call_LPL is in the list already
217220
pl.name = "call_LPL"
218221
pl.vcf_field = None
219222
pl.shape = (*shape, 3)
@@ -511,6 +514,46 @@ def fromdict(d):
511514
return ret
512515

513516

517+
def compute_laa_field(genotypes, alleles) -> np.ndarray:
518+
"""
519+
Computes the value of the LAA field for each sample given the genotypes
520+
for a variant.
521+
522+
The LAA field is a list of one-based indices into the ALT alleles
523+
that indicates which alternate alleles are observed in the sample.
524+
"""
525+
alt_allele_count = len(alleles) - 1
526+
allele_counts = np.zeros((genotypes.shape[0], len(alleles)), dtype=int)
527+
528+
genotypes = genotypes.clip(0, None)
529+
genotype_allele_counts = np.apply_along_axis(
530+
np.bincount, axis=1, arr=genotypes, minlength=len(alleles)
531+
)
532+
allele_counts += genotype_allele_counts
533+
534+
allele_counts[:, 0] = 0 # We don't count the reference allele
535+
max_row_length = 1
536+
537+
def nonzero_pad(arr: np.ndarray, *, length: int):
538+
nonlocal max_row_length
539+
alleles = arr.nonzero()[0]
540+
max_row_length = max(max_row_length, len(alleles))
541+
pad_length = length - len(alleles)
542+
return np.pad(
543+
alleles,
544+
(0, pad_length),
545+
mode="constant",
546+
constant_values=constants.INT_FILL,
547+
)
548+
549+
alleles = np.apply_along_axis(
550+
nonzero_pad, axis=1, arr=allele_counts, length=max(1, alt_allele_count)
551+
)
552+
alleles = alleles[:, :max_row_length]
553+
554+
return alleles
555+
556+
514557
@dataclasses.dataclass
515558
class VcfZarrWriteSummary(core.JsonDataclass):
516559
num_partitions: int
@@ -543,6 +586,12 @@ def has_genotypes(self):
543586
return True
544587
return False
545588

589+
def has_local_alleles(self):
590+
for field in self.schema.fields:
591+
if field.name == "call_LAA" and field.vcf_field is None:
592+
return True
593+
return False
594+
546595
#######################
547596
# init
548597
#######################
@@ -729,6 +778,8 @@ def encode_partition(self, partition_index):
729778
self.encode_array_partition(array_spec, partition_index)
730779
if self.has_genotypes():
731780
self.encode_genotypes_partition(partition_index)
781+
if self.has_local_alleles():
782+
self.encode_local_alleles_partition(partition_index)
732783

733784
final_path = self.partition_path(partition_index)
734785
logger.info(f"Finalising {partition_index} at {final_path}")
@@ -800,6 +851,30 @@ def encode_genotypes_partition(self, partition_index):
800851
self.finalise_partition_array(partition_index, "call_genotype_mask")
801852
self.finalise_partition_array(partition_index, "call_genotype_phased")
802853

854+
def encode_local_alleles_partition(self, partition_index):
855+
call_LAA_array = self.init_partition_array(partition_index, "call_LAA")
856+
partition = self.metadata.partitions[partition_index]
857+
call_LAA = core.BufferedArray(call_LAA_array, partition.start)
858+
859+
gt_array = zarr.open_array(
860+
store=self.wip_partition_array_path(partition_index, "call_genotype"),
861+
mode="r",
862+
)
863+
alleles_array = zarr.open_array(
864+
store=self.wip_partition_array_path(partition_index, "variant_allele"),
865+
mode="r",
866+
)
867+
for chunk_index in range(gt_array.cdata_shape[0]):
868+
A = alleles_array.blocks[chunk_index]
869+
G = gt_array.blocks[chunk_index]
870+
for alleles, var in zip(A, G):
871+
j = call_LAA.next_buffer_row()
872+
# TODO we should probably compute LAAs by chunk for efficiency
873+
call_LAA.buff[j] = compute_laa_field(var, alleles)
874+
875+
call_LAA.flush()
876+
self.finalise_partition_array(partition_index, "call_LAA")
877+
803878
def encode_alleles_partition(self, partition_index):
804879
array_name = "variant_allele"
805880
alleles_array = self.init_partition_array(partition_index, array_name)

0 commit comments

Comments
 (0)