Skip to content

Commit 0415869

Browse files
Will-Tylerjeromekelleher
authored andcommitted
Use NumPy API to compute LAA
1 parent 4969b0c commit 0415869

File tree

2 files changed

+47
-53
lines changed

2 files changed

+47
-53
lines changed

bio2zarr/vcf2zarr/icf.py

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -486,25 +486,54 @@ def sanitise_value_int_2d(buff, j, value):
486486
buff[j, :, : value.shape[1]] = value
487487

488488

489-
def compute_laa_field(variant) -> list[list[int]]:
489+
def compute_laa_field(variant) -> np.ndarray:
490490
sample_count = variant.num_called + variant.num_unknown
491-
laa_val = [set() for _ in range(sample_count)]
491+
alleles = np.empty((sample_count, 0), dtype=int)
492492

493493
if "GT" in variant.FORMAT:
494-
for sample_index, genotype in enumerate(variant.genotypes):
495-
# The last element in the genotype is not an allele.
496-
for allele in genotype[:-1]:
497-
if allele > 0:
498-
laa_val[sample_index].add(allele)
494+
# The last element of each sample's genotype indicates the phasing
495+
# and is not an allele.
496+
genotypes = variant.genotype.array()[:, :-1]
497+
genotypes[genotypes < 0] = 0
498+
alleles = np.concatenate((alleles, genotypes), axis=1)
499499
if "AD" in variant.FORMAT:
500-
for sample_index, ad in enumerate(variant.format("AD")):
501-
# The first depth in AD is for the reference allele.
502-
alt_alleles = set(
503-
allele_index for allele_index, depth in enumerate(ad[1:]) if depth > 0
504-
)
505-
laa_val[sample_index] |= alt_alleles
506500

507-
return [sorted(laa) for laa in laa_val]
501+
def positive_pad(arr, *, length):
502+
positive = np.where(arr > 0)[0]
503+
pad_length = length - len(positive)
504+
return np.pad(positive, (0, pad_length), mode="constant", constant_values=0)
505+
506+
depths = variant.format("AD")
507+
depths = np.apply_along_axis(
508+
positive_pad, axis=1, arr=depths, length=depths.shape[1]
509+
)
510+
alleles = np.concatenate((alleles, depths), axis=1)
511+
if "PL" in variant.FORMAT:
512+
# TODO
513+
pass
514+
515+
max_unique_size = 1
516+
517+
def unique_pad(arr: np.ndarray, *, length: int):
518+
unique = np.unique(arr)
519+
520+
if unique[0] == 0:
521+
# We don't include the reference allele
522+
unique = unique[1:]
523+
524+
nonlocal max_unique_size
525+
max_unique_size = max(max_unique_size, len(unique))
526+
pad_length = length - len(unique)
527+
return np.pad(
528+
unique, (0, pad_length), mode="constant", constant_values=constants.INT_FILL
529+
)
530+
531+
alleles = np.apply_along_axis(
532+
unique_pad, axis=1, arr=alleles, length=alleles.shape[0]
533+
)
534+
alleles = alleles[:, :max_unique_size]
535+
536+
return alleles
508537

509538

510539
missing_value_map = {
@@ -1138,21 +1167,10 @@ def process_partition(self, partition_index):
11381167
for field in format_fields:
11391168
if field.full_name == "FORMAT/LAA":
11401169
laa_val = compute_laa_field(variant)
1141-
# Convert laa_val to a NumPy array
1142-
max_laa_len = max(len(laa) for laa in laa_val)
1143-
# At minimum, we want to have at least one value per sample
1144-
# so that the field is present.
1145-
max_laa_len = max(1, max_laa_len)
1146-
laa_val = [
1147-
sorted(laa)
1148-
+ [constants.INT_FILL] * (max_laa_len - len(laa))
1149-
for laa in laa_val
1150-
]
1151-
laa_val = np.array(laa_val)
11521170
tcw.append("FORMAT/LAA", laa_val)
1153-
continue
1154-
val = variant.format(field.name)
1155-
tcw.append(field.full_name, val)
1171+
else:
1172+
val = variant.format(field.name)
1173+
tcw.append(field.full_name, val)
11561174

11571175
# Note: an issue with updating the progress per variant here like
11581176
# this is that we get a significant pause at the end of the counter

tests/test_vcf_examples.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -548,31 +548,7 @@ def test_call_AD(self, ds):
548548
def test_call_LAA(self, ds):
549549
# The shape is (23, 3, 1).
550550
# None of the entries use any alternate alleles.
551-
call_LAA = [
552-
[[-2], [-2], [-2]],
553-
[[-2], [-2], [-2]],
554-
[[-2], [-2], [-2]],
555-
[[-2], [-2], [-2]],
556-
[[-2], [-2], [-2]],
557-
[[-2], [-2], [-2]],
558-
[[-2], [-2], [-2]],
559-
[[-2], [-2], [-2]],
560-
[[-2], [-2], [-2]],
561-
[[-2], [-2], [-2]],
562-
[[-2], [-2], [-2]],
563-
[[-2], [-2], [-2]],
564-
[[-2], [-2], [-2]],
565-
[[-2], [-2], [-2]],
566-
[[-2], [-2], [-2]],
567-
[[-2], [-2], [-2]],
568-
[[-2], [-2], [-2]],
569-
[[-2], [-2], [-2]],
570-
[[-2], [-2], [-2]],
571-
[[-2], [-2], [-2]],
572-
[[-2], [-2], [-2]],
573-
[[-2], [-2], [-2]],
574-
[[-2], [-2], [-2]],
575-
]
551+
call_LAA = np.full((23, 3, 1), -2, dtype=int)
576552
nt.assert_array_equal(ds.call_LAA.values, call_LAA)
577553

578554
def test_call_PID(self, ds):

0 commit comments

Comments
 (0)