diff --git a/bio2zarr/cli.py b/bio2zarr/cli.py index e37bd766..6e21576c 100644 --- a/bio2zarr/cli.py +++ b/bio2zarr/cli.py @@ -221,7 +221,6 @@ def show_work_summary(work_summary, json): @compressor @progress @worker_processes -@local_alleles def explode( vcfs, icf_path, @@ -231,7 +230,6 @@ def explode( compressor, progress, worker_processes, - local_alleles, ): """ Convert VCF(s) to intermediate columnar format @@ -245,7 +243,6 @@ def explode( column_chunk_size=column_chunk_size, compressor=get_compressor(compressor), show_progress=progress, - local_alleles=local_alleles, ) @@ -260,7 +257,6 @@ def explode( @verbose @progress @worker_processes -@local_alleles def dexplode_init( vcfs, icf_path, @@ -272,7 +268,6 @@ def dexplode_init( verbose, progress, worker_processes, - local_alleles, ): """ Initial step for distributed conversion of VCF(s) to intermediate columnar format @@ -289,7 +284,6 @@ def dexplode_init( worker_processes=worker_processes, compressor=get_compressor(compressor), show_progress=progress, - local_alleles=local_alleles, ) show_work_summary(work_summary, json) @@ -340,7 +334,8 @@ def inspect(path, verbose): @icf_path @variants_chunk_size @samples_chunk_size -def mkschema(icf_path, variants_chunk_size, samples_chunk_size): +@local_alleles +def mkschema(icf_path, variants_chunk_size, samples_chunk_size, local_alleles): """ Generate a schema for zarr encoding """ @@ -350,6 +345,7 @@ def mkschema(icf_path, variants_chunk_size, samples_chunk_size): stream, variants_chunk_size=variants_chunk_size, samples_chunk_size=samples_chunk_size, + local_alleles=local_alleles, ) diff --git a/bio2zarr/core.py b/bio2zarr/core.py index ac4fc293..af61d6d1 100644 --- a/bio2zarr/core.py +++ b/bio2zarr/core.py @@ -63,6 +63,27 @@ def chunk_aligned_slices(z, n, max_chunks=None): return slices +def first_dim_slice_iter(z, start, stop): + """ + Efficiently iterate over the specified slice of the first dimension of the zarr + array z. + """ + chunk_size = z.chunks[0] + first_chunk = start // chunk_size + last_chunk = (stop // chunk_size) + (stop % chunk_size != 0) + for chunk in range(first_chunk, last_chunk): + Z = z.blocks[chunk] + chunk_start = chunk * chunk_size + chunk_stop = chunk_start + chunk_size + slice_start = None + if start > chunk_start: + slice_start = start - chunk_start + slice_stop = None + if stop < chunk_stop: + slice_stop = stop - chunk_start + yield from Z[slice_start:slice_stop] + + def du(path): """ Return the total bytes stored at this path. diff --git a/bio2zarr/vcf2zarr/icf.py b/bio2zarr/vcf2zarr/icf.py index 8e313f8c..dd0bbc91 100644 --- a/bio2zarr/vcf2zarr/icf.py +++ b/bio2zarr/vcf2zarr/icf.py @@ -217,7 +217,7 @@ def make_field_def(name, vcf_type, vcf_number): return fields -def scan_vcf(path, target_num_partitions, *, local_alleles): +def scan_vcf(path, target_num_partitions): with vcf_utils.IndexedVcf(path) as indexed_vcf: vcf = indexed_vcf.vcf filters = [] @@ -237,10 +237,6 @@ def scan_vcf(path, target_num_partitions, *, local_alleles): pass_filter = filters.pop(pass_index) filters.insert(0, pass_filter) - # Indicates whether vcf2zarr can introduce local alleles - can_localize = False - should_add_laa_field = True - should_add_lpl_field = True fields = fixed_vcf_field_definitions() for h in vcf.header_iter(): if h["HeaderType"] in ["INFO", "FORMAT"]: @@ -249,36 +245,6 @@ def scan_vcf(path, target_num_partitions, *, local_alleles): field.vcf_type = "Integer" field.vcf_number = "." 
fields.append(field) - if field.category == "FORMAT": - if field.name == "PL": - can_localize = True - if field.name == "LAA": - should_add_laa_field = False - if field.name == "LPL": - should_add_lpl_field = False - - if local_alleles and can_localize: - if should_add_laa_field: - laa_field = VcfField( - category="FORMAT", - name="LAA", - vcf_type="Integer", - vcf_number=".", - description="1-based indices into ALT, indicating which alleles" - " are relevant (local) for the current sample", - summary=VcfFieldSummary(), - ) - fields.append(laa_field) - if should_add_lpl_field: - lpl_field = VcfField( - category="FORMAT", - name="LPL", - vcf_type="Integer", - vcf_number="LG", - description="Local-allele representation of PL", - summary=VcfFieldSummary(), - ) - fields.append(lpl_field) try: contig_lengths = vcf.seqlens @@ -315,14 +281,7 @@ def scan_vcf(path, target_num_partitions, *, local_alleles): return metadata, vcf.raw_header -def scan_vcfs( - paths, - show_progress, - target_num_partitions, - worker_processes=1, - *, - local_alleles, -): +def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1): logger.info( f"Scanning {len(paths)} VCFs attempting to split into {target_num_partitions}" f" partitions." @@ -346,7 +305,6 @@ def scan_vcfs( scan_vcf, path, max(1, target_num_partitions // len(paths)), - local_alleles=local_alleles, ) results = list(pwm.results_as_completed()) @@ -505,104 +463,6 @@ def sanitise_value_int_2d(buff, j, value): buff[j, :, : value.shape[1]] = value -def compute_laa_field(variant) -> np.ndarray: - """ - Computes the value of the LAA field for each sample given a variant. - - The LAA field is a list of one-based indices into the ALT alleles - that indicates which alternate alleles are observed in the sample. - - This method infers which alleles are observed from the GT field. - """ - sample_count = variant.num_called + variant.num_unknown - alt_allele_count = len(variant.ALT) - allele_count = alt_allele_count + 1 - allele_counts = np.zeros((sample_count, allele_count), dtype=int) - - if "GT" in variant.FORMAT: - # The last element of each sample's genotype indicates the phasing - # and is not an allele. 
- genotypes = variant.genotype.array()[:, :-1] - genotypes.clip(0, None, out=genotypes) - genotype_allele_counts = np.apply_along_axis( - np.bincount, axis=1, arr=genotypes, minlength=allele_count - ) - allele_counts += genotype_allele_counts - - allele_counts[:, 0] = 0 # We don't count the reference allele - max_row_length = 1 - - def nonzero_pad(arr: np.ndarray, *, length: int): - nonlocal max_row_length - alleles = arr.nonzero()[0] - max_row_length = max(max_row_length, len(alleles)) - pad_length = length - len(alleles) - return np.pad( - alleles, - (0, pad_length), - mode="constant", - constant_values=constants.INT_FILL, - ) - - alleles = np.apply_along_axis( - nonzero_pad, axis=1, arr=allele_counts, length=max(1, alt_allele_count) - ) - alleles = alleles[:, :max_row_length] - - return alleles - - -def compute_lpl_field(variant, laa_val: np.ndarray) -> np.ndarray: - assert laa_val is not None - - la_val = np.zeros((laa_val.shape[0], laa_val.shape[1] + 1), dtype=laa_val.dtype) - la_val[:, 1:] = laa_val - ploidy = variant.ploidy - - if "PL" not in variant.FORMAT: - sample_count = variant.num_called + variant.num_unknown - local_allele_count = la_val.shape[1] - - if ploidy == 1: - local_genotype_count = local_allele_count - elif ploidy == 2: - local_genotype_count = local_allele_count * (local_allele_count + 1) // 2 - else: - raise ValueError(f"Cannot handle ploidy = {ploidy}") - - return np.full((sample_count, local_genotype_count), constants.INT_MISSING) - - # Compute a and b - if ploidy == 1: - a = la_val - b = np.zeros_like(la_val) - elif ploidy == 2: - repeats = np.arange(1, la_val.shape[1] + 1) - b = np.repeat(la_val, repeats, axis=1) - arange_tile = np.tile(np.arange(la_val.shape[1]), (la_val.shape[1], 1)) - tril_indices = np.tril_indices_from(arange_tile) - a_index = np.tile(arange_tile[tril_indices], (b.shape[0], 1)) - row_index = np.arange(la_val.shape[0]).reshape(-1, 1) - a = la_val[row_index, a_index] - else: - raise ValueError(f"Cannot handle ploidy = {ploidy}") - - # Compute n, the local indices of the PL field - n = (b * (b + 1) / 2 + a).astype(int) - - pl_val = variant.format("PL") - pl_val[pl_val == constants.VCF_INT_MISSING] = constants.INT_MISSING - # When the PL value is missing in all samples, pl_val has shape (sample_count, 1). - # In that case, we need to broadcast the PL value. 
- if pl_val.shape[1] < n.shape[1]: - pl_val = np.broadcast_to(pl_val, n.shape) - row_index = np.arange(pl_val.shape[0]).reshape(-1, 1) - lpl_val = pl_val[row_index, n] - lpl_val[b == constants.INT_FILL] = constants.INT_FILL - - return lpl_val - - missing_value_map = { "Integer": constants.INT_MISSING, "Float": constants.FLOAT32_MISSING, @@ -1107,14 +967,11 @@ def init( target_num_partitions=None, show_progress=False, compressor=None, - local_alleles=None, ): if self.path.exists(): raise ValueError(f"ICF path already exists: {self.path}") if compressor is None: compressor = ICF_DEFAULT_COMPRESSOR - if local_alleles is None: - local_alleles = False vcfs = [pathlib.Path(vcf) for vcf in vcfs] target_num_partitions = max(target_num_partitions, len(vcfs)) @@ -1124,7 +981,6 @@ def init( worker_processes=worker_processes, show_progress=show_progress, target_num_partitions=target_num_partitions, - local_alleles=local_alleles, ) check_field_clobbering(icf_metadata) self.metadata = icf_metadata @@ -1207,17 +1063,6 @@ def process_partition(self, partition_index): else: format_fields.append(field) - format_field_names = [format_field.name for format_field in format_fields] - if "LAA" in format_field_names and "LPL" in format_field_names: - laa_index = format_field_names.index("LAA") - lpl_index = format_field_names.index("LPL") - # LAA needs to come before LPL - if lpl_index < laa_index: - format_fields[laa_index], format_fields[lpl_index] = ( - format_fields[lpl_index], - format_fields[laa_index], - ) - last_position = None with IcfPartitionWriter( self.metadata, @@ -1245,18 +1090,8 @@ def process_partition(self, partition_index): else: val = variant.genotype.array() tcw.append("FORMAT/GT", val) - laa_val = None for field in format_fields: - if field.name == "LAA": - if "LAA" not in variant.FORMAT: - laa_val = compute_laa_field(variant) - else: - laa_val = variant.format("LAA") - val = laa_val - elif field.name == "LPL" and "LPL" not in variant.FORMAT: - val = compute_lpl_field(variant, laa_val) - else: - val = variant.format(field.name) + val = variant.format(field.name) tcw.append(field.full_name, val) # Note: an issue with updating the progress per variant here like @@ -1352,7 +1187,6 @@ def explode( worker_processes=1, show_progress=False, compressor=None, - local_alleles=None, ): writer = IntermediateColumnarFormatWriter(icf_path) writer.init( @@ -1363,7 +1197,6 @@ def explode( show_progress=show_progress, column_chunk_size=column_chunk_size, compressor=compressor, - local_alleles=local_alleles, ) writer.explode(worker_processes=worker_processes, show_progress=show_progress) writer.finalise() @@ -1379,7 +1212,6 @@ def explode_init( worker_processes=1, show_progress=False, compressor=None, - local_alleles=None, ): writer = IntermediateColumnarFormatWriter(icf_path) return writer.init( @@ -1389,7 +1221,6 @@ def explode_init( show_progress=show_progress, column_chunk_size=column_chunk_size, compressor=compressor, - local_alleles=local_alleles, ) diff --git a/bio2zarr/vcf2zarr/vcz.py b/bio2zarr/vcf2zarr/vcz.py index 1096bb4a..a23888ab 100644 --- a/bio2zarr/vcf2zarr/vcz.py +++ b/bio2zarr/vcf2zarr/vcz.py @@ -182,6 +182,61 @@ def variant_chunk_nbytes(self): ZARR_SCHEMA_FORMAT_VERSION = "0.4" +def convert_local_allele_field_types(fields): + """ + Update the specified list of fields to include the LAA field, and to convert + any supported localisable fields to the L* counterpart. 
+ + Note that we currently support only two ALT alleles per sample, and so the + dimensions of these fields are fixed by that requirement. Later versions may + use summary data stored in the ICF to make different choices, if information + about subsequent alleles (not in the actual genotype calls) should also be + stored. + """ + fields_by_name = {field.name: field for field in fields} + gt = fields_by_name["call_genotype"] + if gt.shape[-1] != 2: + raise ValueError("Local alleles only supported on diploid data") + + # TODO check if LA is already in here + + shape = gt.shape[:-1] + chunks = gt.chunks[:-1] + + la = ZarrArraySpec.new( + vcf_field=None, + name="call_LA", + dtype="i1", + shape=gt.shape, + chunks=gt.chunks, + dimensions=gt.dimensions, # FIXME + description=( + "0-based indices into REF+ALT, indicating which alleles" + " are relevant (local) for the current sample" + ), + ) + ad = fields_by_name.get("call_AD", None) + if ad is not None: + # TODO check if call_LAD is in the list already + ad.name = "call_LAD" + ad.vcf_field = None + ad.shape = (*shape, 2) + ad.chunks = (*chunks, 2) + ad.description += " (local-alleles)" + # TODO fix dimensions + + pl = fields_by_name.get("call_PL", None) + if pl is not None: + # TODO check if call_LPL is in the list already + pl.name = "call_LPL" + pl.vcf_field = None + pl.shape = (*shape, 3) + pl.chunks = (*chunks, 3) + pl.description += " (local-alleles)" + # TODO fix dimensions + return [*fields, la] + + @dataclasses.dataclass class VcfZarrSchema(core.JsonDataclass): format_version: str @@ -232,13 +287,17 @@ def fromjson(s): return VcfZarrSchema.fromdict(json.loads(s)) @staticmethod - def generate(icf, variants_chunk_size=None, samples_chunk_size=None): + def generate( + icf, variants_chunk_size=None, samples_chunk_size=None, local_alleles=None + ): m = icf.num_records n = icf.num_samples if samples_chunk_size is None: samples_chunk_size = 10_000 if variants_chunk_size is None: variants_chunk_size = 1000 + if local_alleles is None: + local_alleles = False logger.info( f"Generating schema with chunks={variants_chunk_size, samples_chunk_size}" ) @@ -365,6 +424,9 @@ def fixed_field_spec( ) ) + if local_alleles: + array_specs = convert_local_allele_field_types(array_specs) + return VcfZarrSchema( format_version=ZARR_SCHEMA_FORMAT_VERSION, samples_chunk_size=samples_chunk_size, @@ -462,6 +524,84 @@ def fromdict(d): return ret + +def compute_la_field(genotypes): + """ + Computes the value of the LA field for each sample given the genotypes + for a variant. The LA field lists the unique alleles observed for + each sample, including the REF. + """ + v = 2**31 - 1 + if np.any(genotypes >= v): + raise ValueError("Extreme allele value not supported") + G = genotypes.astype(np.int32) + if len(G) > 0: + # Anything < 0 gets mapped to -2 (pad) in the output, which comes last. + # So, to get this sorting correctly, we remap to the largest value for + # sorting, then map back. We promote the genotypes up to 32 bit for convenience + # here, assuming that we'll never have an allele of 2**31 - 1.
+ assert np.all(G != v) + G[G < 0] = v + G.sort(axis=1) + G[G[:, 0] == G[:, 1], 1] = -2 + # Equal values result in padding also + G[G == v] = -2 + return G.astype(genotypes.dtype) + + +def compute_lad_field(ad, la): + assert ad.shape[0] == la.shape[0] + assert la.shape[1] == 2 + lad = np.full((ad.shape[0], 2), -2, dtype=ad.dtype) + homs = np.where((la[:, 0] != -2) & (la[:, 1] == -2)) + lad[homs, 0] = ad[homs, la[homs, 0]] + hets = np.where(la[:, 1] != -2) + lad[hets, 0] = ad[hets, la[hets, 0]] + lad[hets, 1] = ad[hets, la[hets, 1]] + return lad + + +def pl_index(a, b): + """ + Returns the PL index for alleles a and b. + """ + return b * (b + 1) // 2 + a + + +def compute_lpl_field(pl, la): + lpl = np.full((pl.shape[0], 3), -2, dtype=pl.dtype) + + homs = np.where((la[:, 0] != -2) & (la[:, 1] == -2)) + a = la[homs, 0] + lpl[homs, 0] = pl[homs, pl_index(a, a)] + + hets = np.where(la[:, 1] != -2)[0] + a = la[hets, 0] + b = la[hets, 1] + lpl[hets, 0] = pl[hets, pl_index(a, a)] + lpl[hets, 1] = pl[hets, pl_index(a, b)] + lpl[hets, 2] = pl[hets, pl_index(b, b)] + + return lpl + + +@dataclasses.dataclass +class LocalisableFieldDescriptor: + array_name: str + vcf_field: str + sanitise: callable + convert: callable + + +localisable_fields = [ + LocalisableFieldDescriptor( + "call_LAD", "FORMAT/AD", icf.sanitise_int_array, compute_lad_field + ), + LocalisableFieldDescriptor( + "call_LPL", "FORMAT/PL", icf.sanitise_int_array, compute_lpl_field + ), +] + + @dataclasses.dataclass class VcfZarrWriteSummary(core.JsonDataclass): num_partitions: int @@ -494,6 +634,12 @@ def has_genotypes(self): return True return False + def has_local_alleles(self): + for field in self.schema.fields: + if field.name == "call_LA" and field.vcf_field is None: + return True + return False + ####################### # init ####################### @@ -686,6 +832,9 @@ def encode_partition(self, partition_index): self.encode_array_partition(array_spec, partition_index) if self.has_genotypes(): self.encode_genotypes_partition(partition_index) + if self.has_local_alleles(): + self.encode_local_alleles_partition(partition_index) + self.encode_local_allele_fields_partition(partition_index) final_path = self.partition_path(partition_index) logger.info(f"Finalising {partition_index} at {final_path}") @@ -757,6 +906,54 @@ def encode_genotypes_partition(self, partition_index): self.finalise_partition_array(partition_index, "call_genotype_mask") self.finalise_partition_array(partition_index, "call_genotype_phased") + def encode_local_alleles_partition(self, partition_index): + partition = self.metadata.partitions[partition_index] + call_LA_array = self.init_partition_array(partition_index, "call_LA") + call_LA = core.BufferedArray(call_LA_array, partition.start) + + gt_array = zarr.open_array( + store=self.wip_partition_array_path(partition_index, "call_genotype"), + mode="r", + ) + for genotypes in core.first_dim_slice_iter( + gt_array, partition.start, partition.stop + ): + la = compute_la_field(genotypes) + j = call_LA.next_buffer_row() + call_LA.buff[j] = la + + call_LA.flush() + self.finalise_partition_array(partition_index, "call_LA") + + def encode_local_allele_fields_partition(self, partition_index): + partition = self.metadata.partitions[partition_index] + la_array = zarr.open_array( + store=self.wip_partition_array_path(partition_index, "call_LA"), + mode="r", + ) + field_map = self.schema.field_map() + # We go through the localisable fields one-by-one so that we don't need to + # keep several large arrays in memory at once for
each partition. + for descriptor in localisable_fields: + if descriptor.array_name not in field_map: + continue + assert field_map[descriptor.array_name].vcf_field is None + + array = self.init_partition_array(partition_index, descriptor.array_name) + buff = core.BufferedArray(array, partition.start) + source = self.icf.fields[descriptor.vcf_field].iter_values( + partition.start, partition.stop + ) + for la in core.first_dim_slice_iter( + la_array, partition.start, partition.stop + ): + raw_value = next(source) + value = descriptor.sanitise(raw_value, 2, raw_value.dtype) + j = buff.next_buffer_row() + buff.buff[j] = descriptor.convert(value, la) + buff.flush() + self.finalise_partition_array(partition_index, descriptor.array_name) + def encode_alleles_partition(self, partition_index): array_name = "variant_allele" alleles_array = self.init_partition_array(partition_index, array_name) @@ -1035,12 +1232,20 @@ def encode_all_partitions( pwm.submit(self.encode_partition, partition_index) -def mkschema(if_path, out, *, variants_chunk_size=None, samples_chunk_size=None): +def mkschema( + if_path, + out, + *, + variants_chunk_size=None, + samples_chunk_size=None, + local_alleles=None, +): store = icf.IntermediateColumnarFormat(if_path) spec = VcfZarrSchema.generate( store, variants_chunk_size=variants_chunk_size, samples_chunk_size=samples_chunk_size, + local_alleles=local_alleles, ) out.write(spec.asjson()) @@ -1054,6 +1259,7 @@ def encode( max_variant_chunks=None, dimension_separator=None, max_memory=None, + local_alleles=None, worker_processes=1, show_progress=False, ): @@ -1066,6 +1272,7 @@ def encode( schema_path=schema_path, variants_chunk_size=variants_chunk_size, samples_chunk_size=samples_chunk_size, + local_alleles=local_alleles, max_variant_chunks=max_variant_chunks, dimension_separator=dimension_separator, ) @@ -1087,6 +1294,7 @@ def encode_init( schema_path=None, variants_chunk_size=None, samples_chunk_size=None, + local_alleles=None, max_variant_chunks=None, dimension_separator=None, max_memory=None, @@ -1099,6 +1307,7 @@ def encode_init( icf_store, variants_chunk_size=variants_chunk_size, samples_chunk_size=samples_chunk_size, + local_alleles=local_alleles, ) else: logger.info(f"Reading schema from {schema_path}") @@ -1151,7 +1360,6 @@ def convert( vcfs, worker_processes=worker_processes, show_progress=show_progress, - local_alleles=local_alleles, ) encode( icf_path, @@ -1160,6 +1368,7 @@ def convert( samples_chunk_size=samples_chunk_size, worker_processes=worker_processes, show_progress=show_progress, + local_alleles=local_alleles, ) diff --git a/tests/data/vcf/local_alleles.vcf.gz b/tests/data/vcf/local_alleles.vcf.gz deleted file mode 100644 index 5ee4c11b..00000000 Binary files a/tests/data/vcf/local_alleles.vcf.gz and /dev/null differ diff --git a/tests/data/vcf/local_alleles.vcf.gz.csi b/tests/data/vcf/local_alleles.vcf.gz.csi deleted file mode 100644 index d338d60e..00000000 Binary files a/tests/data/vcf/local_alleles.vcf.gz.csi and /dev/null differ diff --git a/tests/test_cli.py b/tests/test_cli.py index 2ede6ee2..c6be17ee 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -14,7 +14,6 @@ compressor=None, worker_processes=1, show_progress=True, - local_alleles=False, ) DEFAULT_DEXPLODE_PARTITION_ARGS = dict() @@ -24,7 +23,6 @@ column_chunk_size=64, compressor=None, show_progress=True, - local_alleles=False, ) DEFAULT_ENCODE_ARGS = dict( @@ -49,6 +47,12 @@ DEFAULT_DENCODE_FINALISE_ARGS = dict(show_progress=True) +DEFAULT_MKSHCHEMA_ARGS = dict( + variants_chunk_size=None, +
samples_chunk_size=None, + local_alleles=False, +) + DEFAULT_CONVERT_ARGS = dict( variants_chunk_size=None, samples_chunk_size=None, @@ -297,29 +301,6 @@ def test_vcf_explode_missing_and_existing_vcf(self, mocked, tmp_path): assert "'no_such_file' does not exist" in result.stderr mocked.assert_not_called() - @pytest.mark.parametrize("local_alleles", [False, True]) - @mock.patch("bio2zarr.vcf2zarr.explode") - def test_vcf_explode_local_alleles(self, mocked, tmp_path, local_alleles): - icf_path = tmp_path / "icf" - runner = ct.CliRunner(mix_stderr=False) - - if local_alleles: - local_alleles_flag = "--local-alleles" - else: - local_alleles_flag = "--no-local-alleles" - - result = runner.invoke( - cli.vcf2zarr_main, - f"explode {self.vcf_path} {icf_path} {local_alleles_flag}", - catch_exceptions=False, - ) - assert result.exit_code == 0 - assert len(result.stdout) == 0 - assert len(result.stderr) == 0 - args = dict(DEFAULT_EXPLODE_ARGS) - args["local_alleles"] = local_alleles - mocked.assert_called_once_with(str(icf_path), (self.vcf_path,), **args) - @pytest.mark.parametrize(("progress", "flag"), [(True, "-P"), (False, "-Q")]) @mock.patch("bio2zarr.vcf2zarr.explode_init", return_value=FakeWorkSummary(5)) def test_vcf_dexplode_init(self, mocked, tmp_path, progress, flag): @@ -462,7 +443,7 @@ def test_mkschema(self, mocked, tmp_path): runner = ct.CliRunner(mix_stderr=False) result = runner.invoke( cli.vcf2zarr_main, - f"mkschema {tmp_path} --variants-chunk-size=3 " "--samples-chunk-size=4", + f"mkschema {tmp_path} --variants-chunk-size=3 --samples-chunk-size=4", catch_exceptions=False, ) assert result.exit_code == 0 @@ -726,6 +707,29 @@ def test_mkschema(self, tmp_path): assert d["samples_chunk_size"] == 2 assert d["variants_chunk_size"] == 3 + @pytest.mark.parametrize("local_alleles", [False, True]) + def test_mkschema_local_alleles(self, tmp_path, local_alleles): + icf_path = tmp_path / "icf" + local_alleles_flag = {True: "--local-alleles", False: "--no-local-alleles"}[ + local_alleles + ] + runner = ct.CliRunner(mix_stderr=False) + result = runner.invoke( + cli.vcf2zarr_main, + f"explode {self.vcf_path} {icf_path}", + catch_exceptions=False, + ) + assert result.exit_code == 0 + result = runner.invoke( + cli.vcf2zarr_main, + f"mkschema {icf_path} {local_alleles_flag}", + catch_exceptions=False, + ) + assert result.exit_code == 0 + d = json.loads(result.stdout) + call_LA_exists = "call_LA" in [f["name"] for f in d["fields"]] + assert call_LA_exists == local_alleles + def test_encode(self, tmp_path): icf_path = tmp_path / "icf" zarr_path = tmp_path / "zarr" diff --git a/tests/test_core.py b/tests/test_core.py index 16e2c0c0..3607578f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -194,6 +194,41 @@ def test_5_chunk_1(self, n, expected): assert result == expected +class TestFirstDimSliceIter: + @pytest.mark.parametrize("chunk_size", [1, 3, 4, 5]) + @pytest.mark.parametrize( + ("size", "start", "stop"), + [ + (10, 0, 4), + (10, 0, 8), + (10, 0, 10), + (10, 4, 4), + (10, 4, 8), + (10, 4, 10), + (10, 0, 5), + (10, 0, 3), + (10, 0, 9), + (10, 1, 5), + (10, 1, 1), + (10, 1, 2), + (10, 1, 3), + (10, 1, 4), + (10, 1, 10), + (10, 5, 5), + (10, 5, 6), + (10, 5, 7), + (5, 0, 5), + (5, 1, 1), + (5, 1, 3), + ], + ) + def test_examples(self, chunk_size, size, start, stop): + a = np.arange(size, dtype=int) + z = zarr.empty(size, chunks=chunk_size, dtype=int) + z[:] = a + assert list(core.first_dim_slice_iter(z, start, stop)) == list(a[start:stop]) + + @pytest.mark.skipif(sys.platform != "linux", 
reason="Only valid on Linux") @pytest.mark.parametrize( ("path", "expected"), @@ -202,8 +237,8 @@ def test_5_chunk_1(self, n, expected): # It works in CI on Linux, but it'll probably break at some point. # It's also necessary to update these numbers each time a new data # file gets added - ("tests/data", 4977391), - ("tests/data/vcf", 4965254), + ("tests/data", 4976329), + ("tests/data/vcf", 4964192), ("tests/data/vcf/sample.vcf.gz", 1089), ], ) diff --git a/tests/test_icf.py b/tests/test_icf.py index 954a5102..b59ae1dd 100644 --- a/tests/test_icf.py +++ b/tests/test_icf.py @@ -25,7 +25,7 @@ class TestSmallExample: @pytest.fixture(scope="class") def icf(self, tmp_path_factory): out = tmp_path_factory.mktemp("data") / "example.exploded" - return vcf2zarr.explode(out, [self.data_path], local_alleles=False) + return vcf2zarr.explode(out, [self.data_path]) def test_format_version(self, icf): assert icf.metadata.format_version == icf_mod.ICF_METADATA_FORMAT_VERSION @@ -91,46 +91,6 @@ def test_INFO_NS(self, icf): assert icf["INFO/NS"].values == [None, None, 3, 3, 2, 3, 3, None, None] -class TestLocalAllelesExample: - data_path = "tests/data/vcf/local_alleles.vcf.gz" - - fields = ( - "ALT", - "CHROM", - "FILTERS", - "FORMAT/AD", - "FORMAT/DP", - "FORMAT/GQ", - "FORMAT/GT", - "FORMAT/LAA", - "FORMAT/LPL", - "FORMAT/PL", - "ID", - "INFO/AA", - "INFO/AC", - "INFO/AF", - "INFO/AN", - "INFO/DB", - "INFO/DP", - "INFO/H2", - "INFO/NS", - "POS", - "QUAL", - "REF", - "rlen", - ) - - @pytest.fixture(scope="class") - def icf(self, tmp_path_factory): - out = tmp_path_factory.mktemp("data") / "example.exploded" - return vcf2zarr.explode(out, [self.data_path]) - - def test_summary_table(self, icf): - data = icf.summary_table() - fields = [d["name"] for d in data] - assert tuple(sorted(fields)) == self.fields - - class TestIcfWriterExample: data_path = "tests/data/vcf/sample.vcf.gz" diff --git a/tests/test_local_alleles.py b/tests/test_local_alleles.py new file mode 100644 index 00000000..91c183ac --- /dev/null +++ b/tests/test_local_alleles.py @@ -0,0 +1,108 @@ +import numpy as np +import numpy.testing as nt +import pytest + +from bio2zarr.vcf2zarr import vcz + + +class TestComputeLA: + @pytest.mark.parametrize( + ("genotypes", "expected"), + [ + ([], []), + ([[0, 0]], [[0, -2]]), + ([[0, 0], [0, 0]], [[0, -2], [0, -2]]), + ([[1, 1], [0, 0]], [[1, -2], [0, -2]]), + ([[0, 1], [3, 2], [3, 0]], [[0, 1], [2, 3], [0, 3]]), + ([[0, 0], [2, 3]], [[0, -2], [2, 3]]), + ([[2, 3], [0, 0]], [[2, 3], [0, -2]]), + ([[128, 0], [6, 5]], [[0, 128], [5, 6]]), + ([[0, -1], [-1, 5]], [[0, -2], [5, -2]]), + ([[-1, -1], [-1, 5]], [[-2, -2], [5, -2]]), + ], + ) + def test_simple_examples(self, genotypes, expected): + G = np.array(genotypes) + result = vcz.compute_la_field(G) + nt.assert_array_equal(result, expected) + + def test_extreme_value(self): + G = np.array([[0, 2**32 - 1]]) + with pytest.raises(ValueError, match="Extreme"): + vcz.compute_la_field(G) + + +class TestComputeLAD: + @pytest.mark.parametrize( + ("ad", "la", "expected"), + [ + # Missing data + ([[0, 0]], [[-2, -2]], [[-2, -2]]), + # 0/0 calls + ([[10, 0]], [[0, -2]], [[10, -2]]), + ([[10, 0, 0]], [[0, -2]], [[10, -2]]), + ([[10, 0, 0], [11, 0, 0]], [[0, -2], [0, -2]], [[10, -2], [11, -2]]), + # 0/1 calls + ([[10, 11]], [[0, 1]], [[10, 11]]), + ([[10, 11], [12, 0]], [[0, 1], [0, -2]], [[10, 11], [12, -2]]), + # 0/2 calls + ([[10, 0, 11]], [[0, 2]], [[10, 11]]), + ([[10, 0, 11], [10, 11, 0]], [[0, 2], [0, 1]], [[10, 11], [10, 11]]), + ( + [[10, 0, 11], [10, 11, 0], 
[12, 0, 0]], + [[0, 2], [0, 1], [0, -2]], + [[10, 11], [10, 11], [12, -2]], + ), + # 1/2 calls + ([[0, 10, 11]], [[1, 2]], [[10, 11]]), + ([[0, 10, 11], [12, 0, 13]], [[1, 2], [0, 2]], [[10, 11], [12, 13]]), + ( + [[0, 10, 11], [12, 0, 13], [14, 0, 0]], + [[1, 2], [0, 2], [0, -2]], + [[10, 11], [12, 13], [14, -2]], + ), + ], + ) + def test_simple_examples(self, ad, la, expected): + result = vcz.compute_lad_field(np.array(ad), np.array(la)) + nt.assert_array_equal(result, expected) + + +# PL translation indexes: +# a b i +# 0 0 0 +# 0 1 1 +# 0 2 3 +# 0 3 6 +# 1 1 2 +# 1 2 4 +# 1 3 7 +# 2 2 5 +# 2 3 8 +# 3 3 9 + + +class TestComputeLPL: + @pytest.mark.parametrize( + ("pl", "la", "expected"), + [ + # Missing + ([range(3)], [[-2, -2]], [[-2, -2, -2]]), + # 0/0 calls + ([range(3)], [[0, -2]], [[0, -2, -2]]), + # 0/0 calls + ([[-1, -1, -1]], [[0, -2]], [[-1, -2, -2]]), + # 1/1 calls + ([range(3)], [[1, -2]], [[2, -2, -2]]), + ([range(3), range(3)], [[0, -2], [1, -2]], [[0, -2, -2], [2, -2, -2]]), + # 2/2 calls + ([range(6)], [[2, -2]], [[5, -2, -2]]), + # 0/1 calls + ([range(3)], [[0, 1]], [[0, 1, 2]]), + # 0/2 calls + ([range(6)], [[0, 2]], [[0, 3, 5]]), + ], + ) + def test_simple_examples(self, pl, la, expected): + result = vcz.compute_lpl_field(np.array(pl), np.array(la)) + nt.assert_array_equal(result, expected) diff --git a/tests/test_vcf_examples.py b/tests/test_vcf_examples.py index 1f1d0216..c1cb2ae8 100644 --- a/tests/test_vcf_examples.py +++ b/tests/test_vcf_examples.py @@ -243,11 +243,6 @@ def test_call_HQ(self, ds): ] nt.assert_array_equal(ds["call_HQ"], call_HQ) - def test_no_local_alleles(self, ds): - # The small example VCF does not have a PL field - assert "call_LAA" not in ds - assert "call_LPL" not in ds - def test_no_genotypes(self, ds, tmp_path): path = "tests/data/vcf/sample_no_genotypes.vcf.gz" out = tmp_path / "example.vcf.zarr" @@ -444,79 +439,52 @@ def test_vcf_field_description(self, ds, field, description): assert ds[field].attrs["description"] == description -class TestLocalAllelesExample: - data_path = "tests/data/vcf/local_alleles.vcf.gz" +class TestSmallExampleLocalAlleles: + data_path = "tests/data/vcf/sample.vcf.gz" @pytest.fixture(scope="class") def ds(self, tmp_path_factory): - out = tmp_path_factory.mktemp("data") / "local_alleles.vcf.zarr" - vcf2zarr.convert([self.data_path], out, worker_processes=0) + out = tmp_path_factory.mktemp("data") / "example.vcf.zarr" + vcf2zarr.convert([self.data_path], out, local_alleles=True) return sg.load_dataset(out) - def test_call_LAA(self, ds): - call_LAA = [ - [[1, -2, -2], [-2, -2, -2]], - [[-2, -2, -2], [1, -2, -2]], - [[1, -2, -2], [2, -2, -2]], - [[1, 2, 3], [2, 3, -2]], - [[-2, -2, -2], [-2, -2, -2]], - [[-2, -2, -2], [1, -2, -2]], - [[1, -2, -2], [-1, -2, -2]], - [[2, -2, -2], [2, -2, -2]], - [[-2, -2, -2], [-2, -2, -2]], - [[-2, -2, -2], [-2, -2, -2]], - [[-2, -2, -2], [1, -2, -2]], - ] - nt.assert_array_equal(ds.call_LAA.values, call_LAA) - - def test_call_LPL(self, ds): - call_LPL = [ - [ - [100, 0, 105, -2, -2, -2, -2, -2, -2, -2], - [0, -2, -2, -2, -2, -2, -2, -2, -2, -2], - ], - [ - [0, -2, -2, -2, -2, -2, -2, -2, -2, -2], - [154, 22, 0, -2, -2, -2, -2, -2, -2, -2], - ], - [ - [1002, 55, 1002, -2, -2, -2, -2, -2, -2, -2], - [154, 154, 102, -2, -2, -2, -2, -2, -2, -2], - ], - [ - [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], - ], - [ - [30, -2, -2, -2, -2, -2, -2, -2, -2, -2], - [30, -2, -2, -2, -2, -2, -2, -2, -2, -2], - ], - [ - [0, 30, -2, -2, -2, -2, -2, -2, -2, -2], - 
[0, 0, -2, -2, -2, -2, -2, -2, -2, -2], - ], - [ - [30, 30, 30, -2, -2, -2, -2, -2, -2, -2], - [30, 30, 30, -2, -2, -2, -2, -2, -2, -2], - ], - [ - [10, 40, 60, -2, -2, -2, -2, -2, -2, -2], - [10, 40, 60, -2, -2, -2, -2, -2, -2, -2], - ], - [ - [30, -2, -2, -2, -2, -2, -2, -2, -2, -2], - [30, -2, -2, -2, -2, -2, -2, -2, -2, -2], - ], + def test_call_LA(self, ds): + call_genotype = np.array( [ - [-1, -2, -2, -2, -2, -2, -2, -2, -2, -2], - [-1, -2, -2, -2, -2, -2, -2, -2, -2, -2], + [[0, 0], [0, 0], [0, 1]], + [[0, 0], [0, 0], [0, 1]], + [[0, 0], [1, 0], [1, 1]], + [[0, 0], [0, 1], [0, 0]], + [[1, 2], [2, 1], [2, 2]], + [[0, 0], [0, 0], [0, 0]], + [[0, 1], [0, 2], [-1, -1]], + [[0, 0], [0, 0], [-1, -1]], + # FIXME this depends on "mixed ploidy" interpretation. + [[0, -2], [0, 1], [0, 2]], ], + dtype="i1", + ) + nt.assert_array_equal(ds["call_genotype"], call_genotype) + nt.assert_array_equal(ds["call_genotype_mask"], call_genotype < 0) + + call_LA = np.array( [ - [-1, -1, -2, -2, -2, -2, -2, -2, -2, -2], - [-1, -1, -2, -2, -2, -2, -2, -2, -2, -2], + [[0, -2], [0, -2], [0, 1]], + [[0, -2], [0, -2], [0, 1]], + [[0, -2], [0, 1], [1, -2]], + [[0, -2], [0, 1], [0, -2]], + [[1, 2], [1, 2], [2, -2]], + [[0, -2], [0, -2], [0, -2]], + [[0, 1], [0, 2], [-2, -2]], + [[0, -2], [0, -2], [-2, -2]], + [[0, -2], [0, 1], [0, 2]], ], - ] - nt.assert_array_equal(ds.call_LPL.values, call_LPL) + ) + nt.assert_array_equal(ds.call_LA.values, call_LA) + + @pytest.mark.parametrize("field", ["call_LPL", "call_LAD"]) + def test_no_localised_fields(self, ds, field): + assert field not in ds class TestTriploidExample: @@ -530,11 +498,11 @@ def ds(self, tmp_path_factory, request): @pytest.mark.parametrize("name", ["triploid", "triploid2", "triploid3"]) def test_error_with_local_alleles(self, tmp_path_factory, name): data_path = f"tests/data/vcf/{name}.vcf.gz" - icf_path = tmp_path_factory.mktemp("data") / f"{name}.icf" - with pytest.raises(ValueError, match=re.escape("Cannot handle ploidy = 3")): - vcf2zarr.explode( - icf_path, [data_path], worker_processes=0, local_alleles=True - ) + out = tmp_path_factory.mktemp("data") / "example.vcf.zarr" + with pytest.raises( + ValueError, match=re.escape("Local alleles only supported on diploid") + ): + vcf2zarr.convert([data_path], out, local_alleles=True) def test_ok_without_local_alleles(self, ds): nt.assert_array_equal(ds.call_genotype.values, [[[0, 0, 0]]]) @@ -546,7 +514,7 @@ class Test1000G2020Example: @pytest.fixture(scope="class") def ds(self, tmp_path_factory): out = tmp_path_factory.mktemp("data") / "example.vcf.zarr" - vcf2zarr.convert([self.data_path], out, worker_processes=0, local_alleles=True) + vcf2zarr.convert([self.data_path], out, worker_processes=0) return sg.load_dataset(out) def test_position(self, ds): @@ -645,20 +613,96 @@ def test_call_AD(self, ds): ] nt.assert_array_equal(ds.call_AD.values, call_AD) - def test_call_LAA(self, ds): + def test_call_PID(self, ds): + call_PGT = ds["call_PGT"].values + assert np.all(call_PGT == ".") + assert call_PGT.shape == (23, 3) + + +class Test1000G2020ExampleLocalAlleles: + data_path = "tests/data/vcf/1kg_2020_chrM.vcf.gz" + + @pytest.fixture(scope="class") + def ds(self, tmp_path_factory): + out = tmp_path_factory.mktemp("data") / "example.vcf.zarr" + vcf2zarr.convert([self.data_path], out, worker_processes=0, local_alleles=True) + return sg.load_dataset(out) + + def test_position(self, ds): + # fmt: off + pos = [ + 26, 35, 40, 41, 42, 46, 47, 51, 52, 53, 54, 55, 56, + 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, + ] + # fmt: 
on + nt.assert_array_equal(ds.variant_position.values, pos) + + def test_alleles(self, ds): + alleles = [ + ["C", "T", "", "", ""], + ["G", "A", "", "", ""], + ["T", "C", "", "", ""], + ["C", "T", "CT", "", ""], + ["T", "TC", "C", "TG", ""], + ["T", "C", "", "", ""], + ["G", "A", "", "", ""], + ["T", "C", "", "", ""], + ["T", "C", "", "", ""], + ["G", "A", "", "", ""], + ["G", "A", "", "", ""], + ["TA", "TAA", "T", "CA", "AA"], + ["ATT", "*", "ATTT", "ACTT", "A"], + ["T", "C", "G", "*", "TC"], + ["T", "A", "C", "*", ""], + ["T", "A", "", "", ""], + ["T", "A", "", "", ""], + ["C", "A", "T", "", ""], + ["G", "A", "", "", ""], + ["T", "C", "A", "", ""], + ["C", "T", "CT", "A", ""], + ["TG", "T", "CG", "TGG", "TCGG"], + ["G", "T", "*", "A", ""], + ] + nt.assert_array_equal(ds.variant_allele.values, alleles) + + def test_call_LAD(self, ds): + call_LAD = [ + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + [[446, -2], [393, -2], [486, -2]], + ] + nt.assert_array_equal(ds.call_LAD.values, call_LAD) + + def test_call_LA(self, ds): # All the genotypes are 0/0 - call_LAA = np.full((23, 3, 1), -2) - nt.assert_array_equal(ds.call_LAA.values, call_LAA) + call_LA = np.full((23, 3, 2), -2) + call_LA[:, :, 0] = 0 + nt.assert_array_equal(ds.call_LA.values, call_LA) def test_call_LPL(self, ds): call_LPL = np.tile([0, -2, -2], (23, 3, 1)) nt.assert_array_equal(ds.call_LPL.values, call_LPL) - def test_call_PID(self, ds): - call_PGT = ds["call_PGT"].values - assert np.all(call_PGT == ".") - assert call_PGT.shape == (23, 3) - class Test1000G2020AnnotationsExample: data_path = "tests/data/vcf/1kg_2020_chr20_annotations.bcf"
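A minimal usage sketch of the new core.first_dim_slice_iter helper, mirroring TestFirstDimSliceIter above (the array contents are illustrative only):

import numpy as np
import zarr

from bio2zarr import core

# Rows [start, stop) are yielded one at a time, reading the underlying zarr
# array chunk by chunk rather than materialising the whole slice at once.
z = zarr.empty(10, chunks=3, dtype=int)
z[:] = np.arange(10)
assert list(core.first_dim_slice_iter(z, 4, 8)) == [4, 5, 6, 7]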
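How the genotypes relate to the new call_LA and call_LAD values can be sketched with the helpers added in vcz.py; the example values follow the conventions of the new tests (-1 for a missing call, -2 for the fill/pad sentinel):

import numpy as np

from bio2zarr.vcf2zarr import vcz

# Three diploid calls at one site: 0/1, 3/2 and 0/0.
genotypes = np.array([[0, 1], [3, 2], [0, 0]], dtype="i1")
la = vcz.compute_la_field(genotypes)
# [[0, 1], [2, 3], [0, -2]] -- sorted local alleles, padded with -2

# AD has one depth per REF/ALT allele; LAD keeps only the local alleles.
ad = np.array([[10, 11, 0, 0], [0, 0, 12, 13], [14, 0, 0, 0]])
lad = vcz.compute_lad_field(ad, la)
# [[10, 11], [12, 13], [14, -2]]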
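The LPL values are picked out of the dense PL field using the standard VCF genotype-likelihood ordering, where pl_index(a, b) gives the PL position of genotype a/b for a <= b; a sketch following the cases in the new test_local_alleles.py:

import numpy as np

from bio2zarr.vcf2zarr import vcz

# For diploid alleles a <= b, the PL entry for genotype a/b sits at b * (b + 1) / 2 + a.
assert [vcz.pl_index(0, 2), vcz.pl_index(1, 2), vcz.pl_index(2, 2)] == [3, 4, 5]

# A 0/2 call keeps only the likelihoods for the local genotypes 0/0, 0/2 and 2/2.
pl = np.array([[0, 1, 2, 3, 4, 5]])
la = np.array([[0, 2]])
lpl = vcz.compute_lpl_field(pl, la)
# [[0, 3, 5]]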
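At the user level the option now applies when generating the schema and encoding, rather than when exploding; a sketch of the resulting workflow, assuming the same imports and test data used by the test suite and an illustrative output path:

import sgkit as sg

from bio2zarr import vcf2zarr

# Local-allele fields are derived at encode time from call_genotype and the
# dense AD/PL values stored in the ICF, so explode no longer takes --local-alleles.
vcf2zarr.convert(
    ["tests/data/vcf/sample.vcf.gz"], "sample.vcf.zarr", local_alleles=True
)
ds = sg.load_dataset("sample.vcf.zarr")
assert "call_LA" in ds
# sample.vcf.gz has no FORMAT/AD or FORMAT/PL, so there is nothing to localise.
assert "call_LAD" not in ds and "call_LPL" not in ds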