Skip to content

Commit 11a15d3

Browse files
Fixup genotypes iterator and call_LAA/call_LAD
1 parent 92d19c8 commit 11a15d3

File tree

3 files changed

+82
-20
lines changed

3 files changed

+82
-20
lines changed

bio2zarr/core.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,27 @@ def chunk_aligned_slices(z, n, max_chunks=None):
6363
return slices
6464

6565

66+
def first_dim_slice_iter(z, start, stop):
    """
    Efficiently iterate over the specified slice of the first dimension of the zarr
    array z.

    The array is read one block at a time along its first dimension via
    ``z.blocks``; only blocks overlapping ``[start, stop)`` are accessed, and
    the first and last blocks are trimmed to the requested bounds.

    Like ``z[start:stop]``, a ``stop`` beyond the end of the array is clamped
    to the array length, and an empty range yields nothing without touching
    any block. ``start`` and ``stop`` are expected to be non-negative;
    negative (wrap-around) indices are not supported.
    """
    chunk_size = z.chunks[0]
    # Clamp so we never ask for a block past the end of the array.
    stop = min(stop, z.shape[0])
    if start >= stop:
        # Nothing to yield; avoid loading a block just to slice it to empty.
        return

    first_chunk = start // chunk_size
    # Ceiling division: index one past the last block containing a row < stop.
    last_chunk = -(-stop // chunk_size)
    for chunk in range(first_chunk, last_chunk):
        block = z.blocks[chunk]
        chunk_start = chunk * chunk_size
        # Translate the global bounds into offsets within this block; interior
        # blocks end up with the full [0, chunk_size) range.
        lo = max(start - chunk_start, 0)
        hi = min(stop - chunk_start, chunk_size)
        yield from block[lo:hi]
85+
86+
6687
def du(path):
6788
"""
6889
Return the total bytes stored at this path.

bio2zarr/vcf2zarr/vcz.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -550,15 +550,20 @@ def compute_laa_field(genotypes):
550550

551551

552552
def compute_lad_field(ad, laa):
    """
    Build the two-column local allele depth array from ``ad`` and ``laa``.

    ``laa`` is indexed as a 2D array with (at least) two columns, where the
    value -2 marks an unset entry; ``ad`` is indexed as ``ad[row, allele]``.
    Presumably rows are samples and ``ad`` holds per-allele depths for one
    variant -- confirm against the caller.

    For each row the result (same dtype as ``ad``, filled with -2) is:

    - both ``laa`` entries unset: column 0 is ``ad[row, 0]``;
    - only the first entry set: column 0 is ``ad[row, 0]`` and column 1 is
      ``ad[row, laa[row, 0]]``;
    - both entries set: the columns are ``ad[row, laa[row, 0]]`` and
      ``ad[row, laa[row, 1]]``.

    Rows where only the second ``laa`` entry is set are left as -2.

    Any exception raised while indexing is logged together with the inputs
    and then re-raised unchanged.
    """
    try:
        lad = np.full((ad.shape[0], 2), -2, dtype=ad.dtype)
        # Evaluate each column's "is set" mask once and combine below.
        first_set = laa[:, 0] != -2
        second_set = laa[:, 1] != -2

        ref_ref = np.where(~first_set & ~second_set)[0]
        lad[ref_ref, 0] = ad[ref_ref, 0]

        ref_alt = np.where(first_set & ~second_set)[0]
        lad[ref_alt, 0] = ad[ref_alt, 0]
        lad[ref_alt, 1] = ad[ref_alt, laa[ref_alt, 0]]

        alt_alt = np.where(first_set & second_set)[0]
        lad[alt_alt, 0] = ad[alt_alt, laa[alt_alt, 0]]
        lad[alt_alt, 1] = ad[alt_alt, laa[alt_alt, 1]]
    except Exception:
        # Imported here so the cost is only paid on the failure path.
        import logging

        # Record the offending inputs via logging rather than printing to
        # stdout, then re-raise with the original traceback intact.
        logging.getLogger(__name__).error(
            "compute_lad_field failed: ad = %s laa = %s", ad, laa
        )
        raise
    return lad
563568

564569

@@ -875,21 +880,22 @@ def encode_local_alleles_partition(self, partition_index):
875880
call_AD_source = self.icf.fields["FORMAT/AD"].iter_values(
876881
partition.start, partition.stop
877882
)
878-
879883
gt_array = zarr.open_array(
880884
store=self.wip_partition_array_path(partition_index, "call_genotype"),
881885
mode="r",
882886
)
883-
for chunk_index in range(gt_array.cdata_shape[0]):
884-
for genotypes in gt_array.blocks[chunk_index]:
885-
laa = compute_laa_field(genotypes)
886-
j = call_LAA.next_buffer_row()
887-
call_LAA.buff[j] = laa
888-
889-
ad = next(call_AD_source)
890-
j = call_LAD.next_buffer_row()
891-
lad = compute_lad_field(ad, laa)
892-
call_LAD.buff[j] = lad
887+
for genotypes in core.first_dim_slice_iter(
888+
gt_array, partition.start, partition.stop
889+
):
890+
laa = compute_laa_field(genotypes)
891+
j = call_LAA.next_buffer_row()
892+
call_LAA.buff[j] = laa
893+
894+
ad = next(call_AD_source)
895+
k = call_LAD.next_buffer_row()
896+
assert j == k
897+
lad = compute_lad_field(ad, laa)
898+
call_LAD.buff[j] = lad
893899

894900
call_LAA.flush()
895901
self.finalise_partition_array(partition_index, "call_LAA")

tests/test_core.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,41 @@ def test_5_chunk_1(self, n, expected):
194194
assert result == expected
195195

196196

197+
class TestFirstDimSliceIter:
    """Checks core.first_dim_slice_iter against plain numpy slicing."""

    @pytest.mark.parametrize("chunk_size", [1, 3, 4, 5])
    @pytest.mark.parametrize(
        ("size", "start", "stop"),
        [
            (10, 0, 4),
            (10, 0, 8),
            (10, 0, 10),
            (10, 4, 4),
            (10, 4, 8),
            (10, 4, 10),
            (10, 0, 5),
            (10, 0, 3),
            (10, 0, 9),
            (10, 1, 5),
            (10, 1, 1),
            (10, 1, 2),
            (10, 1, 3),
            (10, 1, 4),
            (10, 1, 10),
            (10, 5, 5),
            (10, 5, 6),
            (10, 5, 7),
            (5, 0, 5),
            (5, 1, 1),
            (5, 1, 3),
        ],
    )
    def test_examples(self, chunk_size, size, start, stop):
        """
        Iterating [start, stop) over a chunked 1D zarr array must give the
        same values as slicing the equivalent numpy array.
        """
        source = np.arange(size, dtype=int)
        # Mirror the numpy data into a zarr array with the chunking under test.
        zarr_array = zarr.empty(size, chunks=chunk_size, dtype=int)
        zarr_array[:] = source

        observed = list(core.first_dim_slice_iter(zarr_array, start, stop))
        expected = list(source[start:stop])
        assert observed == expected
230+
231+
197232
@pytest.mark.skipif(sys.platform != "linux", reason="Only valid on Linux")
198233
@pytest.mark.parametrize(
199234
("path", "expected"),

0 commit comments

Comments
 (0)