|
16 | 16 | from bio2zarr import schema, zarr_utils
|
17 | 17 |
|
18 | 18 | from .. import constants, core, provenance, vcf_utils
|
| 19 | +from functools import partial |
19 | 20 |
|
20 | 21 | logger = logging.getLogger(__name__)
|
21 | 22 |
|
| 23 | +def sanitise_value_bool(shape, value): |
| 24 | + x = True |
| 25 | + if value is None: |
| 26 | + x = False |
| 27 | + return x |
| 28 | + |
| 29 | + |
| 30 | +def sanitise_value_float_scalar(shape, value): |
| 31 | + x = value |
| 32 | + if value is None: |
| 33 | + x = [constants.FLOAT32_MISSING] |
| 34 | + return x[0] |
| 35 | + |
| 36 | + |
| 37 | +def sanitise_value_int_scalar(shape, value): |
| 38 | + x = value |
| 39 | + if value is None: |
| 40 | + x = [constants.INT_MISSING] |
| 41 | + else: |
| 42 | + x = sanitise_int_array(value, ndmin=1, dtype=np.int32) |
| 43 | + return x[0] |
| 44 | + |
| 45 | + |
| 46 | +def sanitise_value_string_scalar(shape, value): |
| 47 | + if value is None: |
| 48 | + return "." |
| 49 | + else: |
| 50 | + return value[0] |
| 51 | + |
| 52 | + |
| 53 | +def sanitise_value_string_1d(shape, value): |
| 54 | + if value is None: |
| 55 | + return np.full(shape, ".", dtype='O') |
| 56 | + else: |
| 57 | + value = drop_empty_second_dim(value) |
| 58 | + result = np.full(shape, "", dtype=value.dtype) |
| 59 | + result[:value.shape[0]] = value |
| 60 | + return result |
| 61 | + |
| 62 | + |
| 63 | +def sanitise_value_string_2d(shape, value): |
| 64 | + if value is None: |
| 65 | + return np.full(shape, ".", dtype='O') |
| 66 | + else: |
| 67 | + result = np.full(shape, "", dtype='O') |
| 68 | + if value.ndim == 2: |
| 69 | + result[:value.shape[0], :value.shape[1]] = value |
| 70 | + else: |
| 71 | + # Convert 1D array into 2D with appropriate shape |
| 72 | + for k, val in enumerate(value): |
| 73 | + result[k, :len(val)] = val |
| 74 | + return result |
| 75 | + |
| 76 | + |
| 77 | +def drop_empty_second_dim(value): |
| 78 | + assert len(value.shape) == 1 or value.shape[1] == 1 |
| 79 | + if len(value.shape) == 2 and value.shape[1] == 1: |
| 80 | + value = value[..., 0] |
| 81 | + return value |
| 82 | + |
| 83 | +def sanitise_value_float_1d(shape, value): |
| 84 | + if value is None: |
| 85 | + return np.full(shape, constants.FLOAT32_MISSING) |
| 86 | + else: |
| 87 | + value = np.array(value, ndmin=1, dtype=np.float32, copy=True) |
| 88 | + # numpy will map None values to Nan, but we need a |
| 89 | + # specific NaN |
| 90 | + value[np.isnan(value)] = constants.FLOAT32_MISSING |
| 91 | + value = drop_empty_second_dim(value) |
| 92 | + result = np.full(shape, constants.FLOAT32_FILL, dtype=np.float32) |
| 93 | + result[:value.shape[0]] = value |
| 94 | + print(result) |
| 95 | + return result |
| 96 | + |
| 97 | +def sanitise_value_float_2d(shape, value): |
| 98 | + if value is None: |
| 99 | + return np.full(shape, constants.FLOAT32_MISSING) |
| 100 | + else: |
| 101 | + value = np.array(value, ndmin=2, dtype=np.float32, copy=True) |
| 102 | + result = np.full(shape, constants.FLOAT32_FILL, dtype=np.float32) |
| 103 | + result[:, :value.shape[1]] = value |
| 104 | + print(result) |
| 105 | + return result |
| 106 | + |
| 107 | + |
| 108 | +def sanitise_int_array(value, ndmin, dtype): |
| 109 | + if isinstance(value, tuple): |
| 110 | + value = [ |
| 111 | + constants.VCF_INT_MISSING if x is None else x for x in value |
| 112 | + ] # NEEDS TEST |
| 113 | + value = np.array(value, ndmin=ndmin, copy=True) |
| 114 | + value[value == constants.VCF_INT_MISSING] = -1 |
| 115 | + value[value == constants.VCF_INT_FILL] = -2 |
| 116 | + # TODO watch out for clipping here! |
| 117 | + return value.astype(dtype) |
| 118 | + |
| 119 | + |
| 120 | +def sanitise_value_int_1d(shape, value): |
| 121 | + if value is None: |
| 122 | + return np.full(shape, -1) |
| 123 | + else: |
| 124 | + value = sanitise_int_array(value, 1, np.int32) |
| 125 | + value = drop_empty_second_dim(value) |
| 126 | + result = np.full(shape, -2, dtype=np.int32) |
| 127 | + result[:value.shape[0]] = value |
| 128 | + return result |
| 129 | + |
| 130 | + |
| 131 | +def sanitise_value_int_2d(shape, value): |
| 132 | + if value is None: |
| 133 | + return np.full(shape, -1) |
| 134 | + else: |
| 135 | + value = sanitise_int_array(value, 2, np.int32) |
| 136 | + result = np.full(shape, -2, dtype=np.int32) |
| 137 | + result[:, :value.shape[1]] = value |
| 138 | + return result |
| 139 | + |
22 | 140 |
|
23 | 141 | @dataclasses.dataclass
|
24 | 142 | class VcfFieldSummary(core.JsonDataclass):
|
@@ -572,35 +690,41 @@ def values(self):
|
572 | 690 |
|
573 | 691 | def sanitiser_factory(self, shape):
|
574 | 692 | """
|
575 |
| - Return a function that sanitised values from this column |
576 |
| - and writes into a buffer of the specified shape. |
| 693 | + Return a function that sanitises values from this column |
| 694 | + and returns a properly formatted array with the specified shape. |
| 695 | + |
| 696 | + Args: |
| 697 | + shape: The shape of the target buffer, used to determine how to format the output |
| 698 | + |
| 699 | + Returns: |
| 700 | + A function that takes a value and returns a sanitised version |
577 | 701 | """
|
578 |
| - assert len(shape) <= 3 |
| 702 | + assert len(shape) <= 2 |
579 | 703 | if self.vcf_field.vcf_type == "Flag":
|
580 |
| - assert len(shape) == 1 |
581 |
| - return zarr_utils.sanitise_value_bool |
| 704 | + assert len(shape) == 0 |
| 705 | + return partial(sanitise_value_bool, shape) |
582 | 706 | elif self.vcf_field.vcf_type == "Float":
|
583 |
| - if len(shape) == 1: |
584 |
| - return zarr_utils.sanitise_value_float_scalar |
585 |
| - elif len(shape) == 2: |
586 |
| - return zarr_utils.sanitise_value_float_1d |
| 707 | + if len(shape) == 0: |
| 708 | + return partial(sanitise_value_float_scalar, shape) |
| 709 | + elif len(shape) == 1: |
| 710 | + return partial(sanitise_value_float_1d, shape) |
587 | 711 | else:
|
588 |
| - return zarr_utils.sanitise_value_float_2d |
| 712 | + return partial(sanitise_value_float_2d, shape) |
589 | 713 | elif self.vcf_field.vcf_type == "Integer":
|
590 |
| - if len(shape) == 1: |
591 |
| - return zarr_utils.sanitise_value_int_scalar |
592 |
| - elif len(shape) == 2: |
593 |
| - return zarr_utils.sanitise_value_int_1d |
| 714 | + if len(shape) == 0: |
| 715 | + return partial(sanitise_value_int_scalar, shape) |
| 716 | + elif len(shape) == 1: |
| 717 | + return partial(sanitise_value_int_1d, shape) |
594 | 718 | else:
|
595 |
| - return zarr_utils.sanitise_value_int_2d |
| 719 | + return partial(sanitise_value_int_2d, shape) |
596 | 720 | else:
|
597 | 721 | assert self.vcf_field.vcf_type in ("String", "Character")
|
598 |
| - if len(shape) == 1: |
599 |
| - return zarr_utils.sanitise_value_string_scalar |
600 |
| - elif len(shape) == 2: |
601 |
| - return zarr_utils.sanitise_value_string_1d |
| 722 | + if len(shape) == 0: |
| 723 | + return partial(sanitise_value_string_scalar, shape) |
| 724 | + elif len(shape) == 1: |
| 725 | + return partial(sanitise_value_string_1d, shape) |
602 | 726 | else:
|
603 |
| - return zarr_utils.sanitise_value_string_2d |
| 727 | + return partial(sanitise_value_string_2d, shape) |
604 | 728 |
|
605 | 729 |
|
606 | 730 | @dataclasses.dataclass
|
@@ -790,40 +914,33 @@ def root_attrs(self):
|
790 | 914 | "vcf_header": self.vcf_header,
|
791 | 915 | }
|
792 | 916 |
|
793 |
| - def write_alleles_to_buffered_array(self, ba, start, stop): |
| 917 | + def iter_alleles(self, start, stop, num_alleles): |
794 | 918 | ref_field = self.fields["REF"]
|
795 | 919 | alt_field = self.fields["ALT"]
|
796 | 920 |
|
797 | 921 | for ref, alt in zip(
|
798 | 922 | ref_field.iter_values(start, stop),
|
799 | 923 | alt_field.iter_values(start, stop),
|
800 | 924 | ):
|
801 |
| - j = ba.next_buffer_row() |
802 |
| - ba.buff[j, :] = constants.STR_FILL |
803 |
| - ba.buff[j, 0] = ref[0] |
804 |
| - ba.buff[j, 1 : 1 + len(alt)] = alt |
| 925 | + alleles = np.full(num_alleles, constants.STR_FILL, dtype="O") |
| 926 | + alleles[0] = ref[0] |
| 927 | + alleles[1 : 1 + len(alt)] = alt |
| 928 | + yield alleles |
805 | 929 |
|
806 |
| - def write_other_field_to_buffered_array(self, ba, field_name, start, stop): |
| 930 | + def iter_field(self, field_name, shape, start, stop): |
807 | 931 | source_field = self.fields[field_name]
|
808 |
| - sanitiser = source_field.sanitiser_factory(ba.buff.shape) |
809 |
| - |
| 932 | + sanitiser = source_field.sanitiser_factory(shape) |
810 | 933 | for value in source_field.iter_values(start, stop):
|
811 |
| - # We write directly into the buffer in the sanitiser function |
812 |
| - # to make it easier to reason about dimension padding |
813 |
| - j = ba.next_buffer_row() |
814 |
| - sanitiser(ba.buff, j, value) |
| 934 | + yield sanitiser(value) |
815 | 935 |
|
816 |
| - def write_genotypes_to_buffered_array(self, gt, gt_phased, start, stop): |
| 936 | + def iter_genotypes(self, shape, start, stop): |
817 | 937 | source_field = self.fields["FORMAT/GT"]
|
818 | 938 | for value in source_field.iter_values(start, stop):
|
819 |
| - j = gt.next_buffer_row() |
820 |
| - zarr_utils.sanitise_value_int_2d( |
821 |
| - gt.buff, j, value[:, :-1] if value is not None else None |
822 |
| - ) |
823 |
| - j = gt_phased.next_buffer_row() |
824 |
| - zarr_utils.sanitise_value_int_1d( |
825 |
| - gt_phased.buff, j, value[:, -1] if value is not None else None |
826 |
| - ) |
| 939 | + genotypes = value[:, :-1] if value is not None else None |
| 940 | + phased = value[:, -1] if value is not None else None |
| 941 | + sanitised_genotypes = sanitise_value_int_2d(shape, genotypes) |
| 942 | + sanitised_phased = sanitise_value_int_1d(shape[:-1], phased) |
| 943 | + yield sanitised_genotypes, sanitised_phased |
827 | 944 |
|
828 | 945 |
|
829 | 946 | @dataclasses.dataclass
|
|
0 commit comments