Skip to content

Commit 14174b7

Browse files
Initial pass at computing LAA values
1 parent 5e52f6d commit 14174b7

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

bio2zarr/vcf2zarr/vcz.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,8 @@ def convert_local_allele_field_types(fields):
197197
gt = fields_by_name["call_genotype"]
198198
if gt.shape[-1] != 2:
199199
raise ValueError("Local alleles only supported on diploid data")
200+
# TODO check if LAA is already in here
201+
200202
shape = gt.shape[:-1]
201203
chunks = gt.chunks[:-1]
202204

@@ -214,6 +216,7 @@ def convert_local_allele_field_types(fields):
214216
)
215217
pl = fields_by_name.get("call_PL", None)
216218
if pl is not None:
219+
# TODO check if call_LPL is in the list already
217220
pl.name = "call_LPL"
218221
pl.vcf_field = None
219222
pl.shape = (*shape, 3)
@@ -510,6 +513,46 @@ def fromdict(d):
510513
return ret
511514

512515

516+
def compute_laa_field(genotypes, alleles) -> np.ndarray:
517+
"""
518+
Computes the value of the LAA field for each sample given the genotypes
519+
for a variant.
520+
521+
The LAA field is a list of one-based indices into the ALT alleles
522+
that indicates which alternate alleles are observed in the sample.
523+
"""
524+
alt_allele_count = len(alleles) - 1
525+
allele_counts = np.zeros((genotypes.shape[0], len(alleles)), dtype=int)
526+
527+
genotypes = genotypes.clip(0, None)
528+
genotype_allele_counts = np.apply_along_axis(
529+
np.bincount, axis=1, arr=genotypes, minlength=len(alleles)
530+
)
531+
allele_counts += genotype_allele_counts
532+
533+
allele_counts[:, 0] = 0 # We don't count the reference allele
534+
max_row_length = 1
535+
536+
def nonzero_pad(arr: np.ndarray, *, length: int):
537+
nonlocal max_row_length
538+
alleles = arr.nonzero()[0]
539+
max_row_length = max(max_row_length, len(alleles))
540+
pad_length = length - len(alleles)
541+
return np.pad(
542+
alleles,
543+
(0, pad_length),
544+
mode="constant",
545+
constant_values=constants.INT_FILL,
546+
)
547+
548+
alleles = np.apply_along_axis(
549+
nonzero_pad, axis=1, arr=allele_counts, length=max(1, alt_allele_count)
550+
)
551+
alleles = alleles[:, :max_row_length]
552+
553+
return alleles
554+
555+
513556
@dataclasses.dataclass
514557
class VcfZarrWriteSummary(core.JsonDataclass):
515558
num_partitions: int
@@ -542,6 +585,12 @@ def has_genotypes(self):
542585
return True
543586
return False
544587

588+
def has_local_alleles(self):
589+
for field in self.schema.fields:
590+
if field.name == "call_LAA" and field.vcf_field is None:
591+
return True
592+
return False
593+
545594
#######################
546595
# init
547596
#######################
@@ -734,6 +783,8 @@ def encode_partition(self, partition_index):
734783
self.encode_array_partition(array_spec, partition_index)
735784
if self.has_genotypes():
736785
self.encode_genotypes_partition(partition_index)
786+
if self.has_local_alleles():
787+
self.encode_local_alleles_partition(partition_index)
737788

738789
final_path = self.partition_path(partition_index)
739790
logger.info(f"Finalising {partition_index} at {final_path}")
@@ -805,6 +856,30 @@ def encode_genotypes_partition(self, partition_index):
805856
self.finalise_partition_array(partition_index, "call_genotype_mask")
806857
self.finalise_partition_array(partition_index, "call_genotype_phased")
807858

859+
def encode_local_alleles_partition(self, partition_index):
860+
call_LAA_array = self.init_partition_array(partition_index, "call_LAA")
861+
partition = self.metadata.partitions[partition_index]
862+
call_LAA = core.BufferedArray(call_LAA_array, partition.start)
863+
864+
gt_array = zarr.open_array(
865+
store=self.wip_partition_array_path(partition_index, "call_genotype"),
866+
mode="r",
867+
)
868+
alleles_array = zarr.open_array(
869+
store=self.wip_partition_array_path(partition_index, "variant_allele"),
870+
mode="r",
871+
)
872+
for chunk_index in range(gt_array.cdata_shape[0]):
873+
A = alleles_array.blocks[chunk_index]
874+
G = gt_array.blocks[chunk_index]
875+
for alleles, var in zip(A, G):
876+
j = call_LAA.next_buffer_row()
877+
# TODO we should probably compute LAAs by chunk for efficiency
878+
call_LAA.buff[j] = compute_laa_field(var, alleles)
879+
880+
call_LAA.flush()
881+
self.finalise_partition_array(partition_index, "call_LAA")
882+
808883
def encode_alleles_partition(self, partition_index):
809884
array_name = "variant_allele"
810885
alleles_array = self.init_partition_array(partition_index, array_name)

0 commit comments

Comments
 (0)