Skip to content

Commit 737a2ae

Browse files
committed
WIP - switch to iterators from sources
1 parent 8112626 commit 737a2ae

File tree

5 files changed

+197
-202
lines changed

5 files changed

+197
-202
lines changed

bio2zarr/plink.py

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, path):
1717
self.num_samples = len(self.samples)
1818
self.root_attrs = {}
1919

20-
def write_alleles_to_buffered_array(self, ba, start, stop):
20+
def iter_alleles(self, start, stop, num_alleles):
2121
ref_field = self.bed.allele_1
2222
alt_field = self.bed.allele_2
2323

@@ -26,35 +26,23 @@ def write_alleles_to_buffered_array(self, ba, start, stop):
2626
ref_field[start:stop],
2727
alt_field[start:stop],
2828
):
29-
j = ba.next_buffer_row()
30-
ba.buff[j, :] = constants.STR_FILL
31-
ba.buff[j, 0] = ref
32-
ba.buff[j, 1] = alt
29+
alleles = np.full(num_alleles, constants.STR_FILL, dtype="O")
30+
alleles[0] = ref
31+
alleles[1 : 1 + len(alt)] = alt
32+
yield alleles
3333

34-
def write_other_field_to_buffered_array(self, ba, field_name, start, stop):
34+
def iter_field(self, field_name, shape, start, stop):
    """Yield the values of a named field for variants ``start`` to ``stop``.

    Args:
        field_name: Name of the field to read. Only ``"position"`` is
            available; any other name raises ``KeyError``.
        shape: Not used by this source.
        start: Index of the first variant to yield.
        stop: Index one past the last variant to yield.
    """
    columns = {
        "position": self.bed.bp_position,
    }
    yield from columns[field_name][start:stop]
40+
41+
def iter_genotypes(self, shape, start, stop=None):
    """Yield one ``(genotype_calls, phased)`` pair per variant.

    The writer calls ``source.iter_genotypes(shape, start, stop)``, the
    same signature the ICF source uses. The previous two-argument form
    ``iter_genotypes(start, stop)`` made that call raise ``TypeError``.
    Both call forms are accepted here so existing positional callers keep
    working.

    Args:
        shape: Per-variant shape of the destination genotype buffer.
            Accepted for signature parity with the other sources; it is
            not used.
        start: Index of the first variant to yield.
        stop: Index one past the last variant to yield. When omitted (or
            ``None``) the call is treated as the legacy two-argument form.

    Yields:
        Tuples of the genotype row and a boolean row of the same shape
        filled with ``False``.
    """
    if stop is None:
        # Legacy call form iter_genotypes(start, stop): shift the arguments.
        start, stop = shape, start
    gt_calls = self.bed.gts.values[start:stop]
    # PLINK data is unphased, so every phased flag is False.
    # NOTE(review): the flags take the same shape as the genotype rows;
    # confirm that matches the call_genotype_phased buffer rows.
    phased = np.zeros_like(gt_calls, dtype=bool)
    yield from zip(gt_calls, phased)
5846

5947

6048
# Import here to avoid circular import
@@ -82,7 +70,7 @@ def generate_schema(
8270

8371
array_specs = [
8472
schema.ZarrArraySpec.new(
85-
vcf_field="POS",
73+
vcf_field="position",
8674
name="variant_position",
8775
dtype="i4",
8876
shape=[m],

bio2zarr/vcf2zarr/icf.py

Lines changed: 158 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,127 @@
1616
from bio2zarr import schema, zarr_utils
1717

1818
from .. import constants, core, provenance, vcf_utils
19+
from functools import partial
1920

2021
logger = logging.getLogger(__name__)
2122

23+
def sanitise_value_bool(shape, value):
    """Return ``False`` when ``value`` is ``None`` and ``True`` otherwise.

    ``shape`` is accepted so that all sanitisers share one signature; it is
    not used.
    """
    return value is not None
28+
29+
30+
def sanitise_value_float_scalar(shape, value):
    """Return the first element of ``value``, or the float missing sentinel.

    ``None`` maps to ``constants.FLOAT32_MISSING``. ``shape`` is not used.
    """
    if value is None:
        return constants.FLOAT32_MISSING
    return value[0]
35+
36+
37+
def sanitise_value_int_scalar(shape, value):
    """Return the first sanitised integer of ``value``.

    ``None`` maps to ``constants.INT_MISSING``. Any other value is passed
    through ``sanitise_int_array`` as an int32 array of at least one
    dimension, and its first element is returned. ``shape`` is not used.
    """
    if value is None:
        return constants.INT_MISSING
    return sanitise_int_array(value, ndmin=1, dtype=np.int32)[0]
44+
45+
46+
def sanitise_value_string_scalar(shape, value):
    """Return the first element of ``value``, or ``"."`` when it is ``None``.

    ``shape`` is not used.
    """
    return "." if value is None else value[0]
51+
52+
53+
def sanitise_value_string_1d(shape, value):
    """Return ``value`` padded with empty strings to a 1-D array of ``shape``.

    ``None`` yields an object array of ``shape`` filled with ``"."``.
    Otherwise a trailing length-1 second dimension is dropped, and the
    values are copied to the front of an array of the input's dtype that is
    pre-filled with ``""``.
    """
    if value is None:
        return np.full(shape, ".", dtype="O")
    flat = drop_empty_second_dim(value)
    padded = np.full(shape, "", dtype=flat.dtype)
    padded[: flat.shape[0]] = flat
    return padded
61+
62+
63+
def sanitise_value_string_2d(shape, value):
    """Return ``value`` padded with empty strings to a 2-D array of ``shape``.

    ``None`` yields an object array of ``shape`` filled with ``"."``. A 2-D
    input is copied into the top-left corner of the result. Any other input
    is treated as a sequence of rows, each written to the start of its own
    result row.
    """
    if value is None:
        return np.full(shape, ".", dtype="O")
    grid = np.full(shape, "", dtype="O")
    if value.ndim == 2:
        n_rows, n_cols = value.shape
        grid[:n_rows, :n_cols] = value
        return grid
    # Convert a 1-D array of per-row sequences into the padded 2-D layout.
    for row_index, row in enumerate(value):
        grid[row_index, : len(row)] = row
    return grid
75+
76+
77+
def drop_empty_second_dim(value):
    """Return ``value`` with a length-1 second axis removed.

    A 2-D array of shape ``(n, 1)`` becomes 1-D of shape ``(n,)``; every
    other accepted input is returned unchanged. The input must be 1-D or
    have a second axis of length 1, otherwise the assertion fails.
    """
    dims = value.shape
    assert len(dims) == 1 or dims[1] == 1
    is_column = len(dims) == 2 and dims[1] == 1
    return value[..., 0] if is_column else value
82+
83+
def sanitise_value_float_1d(shape, value):
    """Return ``value`` as a float32 array of ``shape``, padded with the fill value.

    ``None`` yields an array of ``shape`` filled with
    ``constants.FLOAT32_MISSING``. Otherwise the input is copied to float32,
    its NaNs are replaced by ``constants.FLOAT32_MISSING``, a trailing
    length-1 second dimension is dropped, and the values are written to the
    front of an array pre-filled with ``constants.FLOAT32_FILL``.

    The leftover debugging ``print(result)`` has been removed; it wrote to
    stdout for every sanitised value.
    """
    if value is None:
        return np.full(shape, constants.FLOAT32_MISSING)
    value = np.array(value, ndmin=1, dtype=np.float32, copy=True)
    # numpy will map None values to NaN, but we need a specific NaN
    value[np.isnan(value)] = constants.FLOAT32_MISSING
    value = drop_empty_second_dim(value)
    result = np.full(shape, constants.FLOAT32_FILL, dtype=np.float32)
    result[: value.shape[0]] = value
    return result
96+
97+
def sanitise_value_float_2d(shape, value):
    """Return ``value`` as a 2-D float32 array of ``shape``, padded with the fill value.

    ``None`` yields an array of ``shape`` filled with
    ``constants.FLOAT32_MISSING``. Otherwise the input is copied to a
    float32 array of at least two dimensions and written into the leading
    columns of an array pre-filled with ``constants.FLOAT32_FILL``.

    The leftover debugging ``print(result)`` has been removed; it wrote to
    stdout for every sanitised value.
    """
    if value is None:
        return np.full(shape, constants.FLOAT32_MISSING)
    value = np.array(value, ndmin=2, dtype=np.float32, copy=True)
    result = np.full(shape, constants.FLOAT32_FILL, dtype=np.float32)
    result[:, : value.shape[1]] = value
    return result
106+
107+
108+
def sanitise_int_array(value, ndmin, dtype):
    """Copy ``value`` into an array of ``dtype`` with VCF sentinels recoded.

    ``None`` entries of a tuple input are first replaced by
    ``constants.VCF_INT_MISSING``. After conversion to an array of at least
    ``ndmin`` dimensions, ``constants.VCF_INT_MISSING`` becomes ``-1`` and
    ``constants.VCF_INT_FILL`` becomes ``-2``.
    """
    if isinstance(value, tuple):
        # NEEDS TEST
        value = [
            constants.VCF_INT_MISSING if item is None else item for item in value
        ]
    recoded = np.array(value, ndmin=ndmin, copy=True)
    # Applied in this order: missing first, then fill.
    replacements = (
        (constants.VCF_INT_MISSING, -1),
        (constants.VCF_INT_FILL, -2),
    )
    for sentinel, code in replacements:
        recoded[recoded == sentinel] = code
    # TODO watch out for clipping here!
    return recoded.astype(dtype)
118+
119+
120+
def sanitise_value_int_1d(shape, value):
    """Return ``value`` as an int32 array of ``shape``, padded with ``-2``.

    ``None`` yields an array of ``shape`` filled with ``-1``. Otherwise the
    input is sanitised with ``sanitise_int_array``, a trailing length-1
    second dimension is dropped, and the values are written to the front of
    an array pre-filled with ``-2``.
    """
    if value is None:
        return np.full(shape, -1)
    ints = drop_empty_second_dim(sanitise_int_array(value, 1, np.int32))
    padded = np.full(shape, -2, dtype=np.int32)
    padded[: ints.shape[0]] = ints
    return padded
129+
130+
131+
def sanitise_value_int_2d(shape, value):
    """Return ``value`` as a 2-D int32 array of ``shape``, padded with ``-2``.

    ``None`` yields an array of ``shape`` filled with ``-1``. Otherwise the
    input is sanitised with ``sanitise_int_array`` as an array of at least
    two dimensions and written into the leading columns of an array
    pre-filled with ``-2``.
    """
    if value is None:
        return np.full(shape, -1)
    ints = sanitise_int_array(value, 2, np.int32)
    padded = np.full(shape, -2, dtype=np.int32)
    padded[:, : ints.shape[1]] = ints
    return padded
139+
22140

23141
@dataclasses.dataclass
24142
class VcfFieldSummary(core.JsonDataclass):
@@ -572,35 +690,41 @@ def values(self):
572690

573691
def sanitiser_factory(self, shape):
    """
    Return a one-argument function that sanitises values from this column
    into the per-variant ``shape``.

    Args:
        shape: Per-variant shape of the target array (rank 0, 1 or 2). The
            rank selects the scalar, 1-D or 2-D sanitiser for the field's
            VCF type.

    Returns:
        The chosen sanitiser with ``shape`` already bound, so it takes only
        the value to sanitise.
    """
    rank = len(shape)
    assert rank <= 2
    vcf_type = self.vcf_field.vcf_type
    if vcf_type == "Flag":
        assert rank == 0
        return partial(sanitise_value_bool, shape)
    if vcf_type == "Float":
        by_rank = (
            sanitise_value_float_scalar,
            sanitise_value_float_1d,
            sanitise_value_float_2d,
        )
    elif vcf_type == "Integer":
        by_rank = (
            sanitise_value_int_scalar,
            sanitise_value_int_1d,
            sanitise_value_int_2d,
        )
    else:
        assert vcf_type in ("String", "Character")
        by_rank = (
            sanitise_value_string_scalar,
            sanitise_value_string_1d,
            sanitise_value_string_2d,
        )
    # Any rank above 1 uses the 2-D sanitiser.
    return partial(by_rank[min(rank, 2)], shape)
604728

605729

606730
@dataclasses.dataclass
@@ -790,40 +914,33 @@ def root_attrs(self):
790914
"vcf_header": self.vcf_header,
791915
}
792916

793-
def write_alleles_to_buffered_array(self, ba, start, stop):
917+
def iter_alleles(self, start, stop, num_alleles):
    """Yield one allele array per variant in ``[start, stop)``.

    Each array is an object array of length ``num_alleles`` pre-filled with
    ``constants.STR_FILL``. Slot 0 holds the first REF value and the
    following slots hold the ALT values.
    """
    ref_column = self.fields["REF"]
    alt_column = self.fields["ALT"]
    ref_values = ref_column.iter_values(start, stop)
    alt_values = alt_column.iter_values(start, stop)
    for ref, alt in zip(ref_values, alt_values):
        row = np.full(num_alleles, constants.STR_FILL, dtype="O")
        row[0] = ref[0]
        row[1 : 1 + len(alt)] = alt
        yield row
805929

806-
def write_other_field_to_buffered_array(self, ba, field_name, start, stop):
930+
def iter_field(self, field_name, shape, start, stop):
    """Yield the sanitised values of ``field_name`` for ``[start, stop)``.

    The sanitiser is obtained once from the field's ``sanitiser_factory``
    for the given per-variant ``shape`` and applied lazily to each value.
    """
    column = self.fields[field_name]
    clean = column.sanitiser_factory(shape)
    yield from map(clean, column.iter_values(start, stop))
815935

816-
def write_genotypes_to_buffered_array(self, gt, gt_phased, start, stop):
936+
def iter_genotypes(self, shape, start, stop):
    """Yield ``(genotypes, phased)`` pairs for variants in ``[start, stop)``.

    Each stored ``FORMAT/GT`` value carries the phasing information in its
    last column; the other columns are the genotype calls. The calls are
    sanitised to ``shape`` and the phasing column to ``shape[:-1]``. A
    ``None`` value is passed to both sanitisers as ``None``.
    """
    for packed in self.fields["FORMAT/GT"].iter_values(start, stop):
        if packed is None:
            calls = None
            phase_column = None
        else:
            calls = packed[:, :-1]
            phase_column = packed[:, -1]
        clean_calls = sanitise_value_int_2d(shape, calls)
        clean_phase = sanitise_value_int_1d(shape[:-1], phase_column)
        yield clean_calls, clean_phase
827944

828945

829946
@dataclasses.dataclass

bio2zarr/writer.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,10 @@ class LocalisableFieldDescriptor:
8787

8888
localisable_fields = [
8989
LocalisableFieldDescriptor(
90-
"call_LAD", "FORMAT/AD", zarr_utils.sanitise_int_array, compute_lad_field
90+
"call_LAD", "FORMAT/AD", icf.sanitise_int_array, compute_lad_field
9191
),
9292
LocalisableFieldDescriptor(
93-
"call_LPL", "FORMAT/PL", zarr_utils.sanitise_int_array, compute_lpl_field
93+
"call_LPL", "FORMAT/PL", icf.sanitise_int_array, compute_lpl_field
9494
),
9595
]
9696

@@ -419,24 +419,33 @@ def finalise_partition_array(self, partition_index, buffered_array):
419419
def encode_array_partition(self, array_spec, partition_index):
    """Encode one partition of the array described by ``array_spec``.

    Values for the partition's variant range are pulled from the source's
    ``iter_field`` iterator and written row by row into the partition's
    buffered array, which is then finalised.
    """
    bounds = self.metadata.partitions[partition_index]
    out = self.init_partition_array(partition_index, array_spec.name)
    values = self.source.iter_field(
        array_spec.vcf_field,
        # Per-variant shape: the buffer shape without the leading row axis.
        out.buff.shape[1:],
        bounds.start,
        bounds.stop,
    )
    for item in values:
        # Ask for the row index first, then index the buffer.
        row = out.next_buffer_row()
        out.buff[row] = item

    self.finalise_partition_array(partition_index, out)
429432

430433
def encode_genotypes_partition(self, partition_index):
    """Encode the genotype and phasing arrays for one partition.

    Pairs from the source's ``iter_genotypes`` iterator are written row by
    row into the ``call_genotype`` and ``call_genotype_phased`` buffered
    arrays, which are then finalised in that order.
    """
    bounds = self.metadata.partitions[partition_index]
    calls_out = self.init_partition_array(partition_index, "call_genotype")
    phased_out = self.init_partition_array(partition_index, "call_genotype_phased")

    pairs = self.source.iter_genotypes(
        # Per-variant shape: the buffer shape without the leading row axis.
        calls_out.buff.shape[1:],
        bounds.start,
        bounds.stop,
    )
    for calls, phase in pairs:
        # Ask for each row index first, then index that buffer.
        calls_row = calls_out.next_buffer_row()
        calls_out.buff[calls_row] = calls

        phased_row = phased_out.next_buffer_row()
        phased_out.buff[phased_row] = phase

    for finished in (calls_out, phased_out):
        self.finalise_partition_array(partition_index, finished)
442451

@@ -504,9 +513,9 @@ def encode_alleles_partition(self, partition_index):
504513
alleles = self.init_partition_array(partition_index, "variant_allele")
505514
partition = self.metadata.partitions[partition_index]
506515

507-
self.source.write_alleles_to_buffered_array(
508-
alleles, partition.start, partition.stop
509-
)
516+
for value in self.source.iter_alleles(partition.start, partition.stop, alleles.array.shape[1]):
517+
j = alleles.next_buffer_row()
518+
alleles.buff[j] = value
510519

511520
self.finalise_partition_array(partition_index, alleles)
512521

0 commit comments

Comments
 (0)