diff --git a/tests/test_stats.py b/tests/test_stats.py index 7a28c3f..d1de722 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -3,7 +3,7 @@ import pytest import zarr -from bio2zarr import vcf2zarr +from bio2zarr import icf from vcztools.stats import nrecords, stats @@ -39,7 +39,7 @@ def test_stats__no_index(tmp_path): original = pathlib.Path("tests/data/vcf") / "sample.vcf.gz" # don't use cache here since we want to make sure vcz is not indexed vcz = tmp_path.joinpath("intermediate.vcz") - vcf2zarr.convert([original], vcz, worker_processes=0, local_alleles=False) + icf.convert([original], vcz, worker_processes=0, local_alleles=False) # delete the index created by vcf2zarr root = zarr.open(vcz, mode="a") diff --git a/tests/test_vcf_writer.py b/tests/test_vcf_writer.py index 6ec9619..625eb66 100644 --- a/tests/test_vcf_writer.py +++ b/tests/test_vcf_writer.py @@ -6,7 +6,6 @@ import numpy as np import pytest import zarr -from bio2zarr import vcf2zarr from cyvcf2 import VCF from numpy.testing import assert_array_equal @@ -301,15 +300,9 @@ def test_write_vcf__header_flags(tmp_path): assert_vcfs_close(original, output) -def test_write_vcf__generate_header(tmp_path): +def test_write_vcf__generate_header(): original = pathlib.Path("tests/data/vcf") / "sample.vcf.gz" - # don't use cache here since we mutate the vcz - vcz = tmp_path.joinpath("intermediate.vcz") - vcf2zarr.convert([original], vcz, worker_processes=0, local_alleles=False) - - # remove vcf_header - root = zarr.open(vcz, mode="r+") - del root.attrs["vcf_header"] + vcz = vcz_path_cache(original) output_header = StringIO() write_vcf(vcz, output_header, header_only=True, no_version=True) @@ -324,9 +317,9 @@ def test_write_vcf__generate_header(tmp_path): ##INFO= ##INFO= ##INFO= -##FILTER= -##FILTER= -##FILTER= +##FILTER= +##FILTER= +##FILTER= ##FORMAT= ##FORMAT= ##FORMAT= @@ -338,6 +331,7 @@ def test_write_vcf__generate_header(tmp_path): """ # noqa: E501 # substitute value of source + root = zarr.open(vcz, mode="r+") expected_vcf_header = expected_vcf_header.format(root.attrs["source"]) assert output_header.getvalue() == expected_vcf_header diff --git a/tests/utils.py b/tests/utils.py index 6eb5c29..74ab821 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,7 +5,7 @@ import cyvcf2 import numpy as np -from bio2zarr import vcf2zarr +from bio2zarr import icf @contextmanager @@ -29,6 +29,64 @@ def normalise_info_missingness(info_dict, key): return value +def _get_header_field_dicts(vcf, header_type): + def to_dict(header_field): + d = header_field.info(extra=True) + del d[b"IDX"] # remove IDX since we don't care about ordering + + # cyvcf2 duplicates some keys as strings and bytes, so remove the bytes one + for k in list(d.keys()): + if isinstance(k, bytes) and k.decode("utf-8") in d: + del d[k] + return d + + return { + field["ID"]: to_dict(field) + for field in vcf.header_iter() + if field["HeaderType"] == header_type + } + + +def _assert_header_field_dicts_equivalent(field_dicts1, field_dicts2): + assert len(field_dicts1) == len(field_dicts2) + + for id in field_dicts1.keys(): + assert id in field_dicts2 + field_dict1 = field_dicts1[id] + field_dict2 = field_dicts2[id] + + assert len(field_dict1) == len(field_dict2) + # all fields should be the same, except Number="." which can match any value + for k in field_dict1.keys(): + assert k in field_dict2 + v1 = field_dict1[k] + v2 = field_dict2[k] + if k == "Number" and (v1 == "." or v2 == "."): + continue + assert v1 == v2, f"Failed in field {id} with key {k}" + + +def _assert_vcf_headers_equivalent(vcf1, vcf2): + # Only compare INFO, FORMAT, FILTER, CONTIG fields, ignoring order + # Other fields are ignored + + info1 = _get_header_field_dicts(vcf1, "INFO") + info2 = _get_header_field_dicts(vcf2, "INFO") + _assert_header_field_dicts_equivalent(info1, info2) + + format1 = _get_header_field_dicts(vcf1, "FORMAT") + format2 = _get_header_field_dicts(vcf2, "FORMAT") + _assert_header_field_dicts_equivalent(format1, format2) + + filter1 = _get_header_field_dicts(vcf1, "FILTER") + filter2 = _get_header_field_dicts(vcf2, "FILTER") + _assert_header_field_dicts_equivalent(filter1, filter2) + + contig1 = _get_header_field_dicts(vcf1, "CONTIG") + contig2 = _get_header_field_dicts(vcf2, "CONTIG") + _assert_header_field_dicts_equivalent(contig1, contig2) + + def assert_vcfs_close(f1, f2, *, rtol=1e-05, atol=1e-03, allow_zero_variants=False): """Like :py:func:`numpy.testing.assert_allclose()`, but for VCF files. @@ -48,7 +106,7 @@ def assert_vcfs_close(f1, f2, *, rtol=1e-05, atol=1e-03, allow_zero_variants=Fal Absolute tolerance. """ with open_vcf(f1) as vcf1, open_vcf(f2) as vcf2: - assert vcf1.raw_header == vcf2.raw_header + _assert_vcf_headers_equivalent(vcf1, vcf2) assert vcf1.samples == vcf2.samples count = 0 @@ -145,7 +203,7 @@ def vcz_path_cache(vcf_path): cached_vcz_path = (cache_path / vcf_path.name).with_suffix(".vcz") if not cached_vcz_path.exists(): if vcf_path.name.startswith("chr22"): - vcf2zarr.convert( + icf.convert( [vcf_path], cached_vcz_path, worker_processes=0, @@ -153,7 +211,7 @@ def vcz_path_cache(vcf_path): samples_chunk_size=10, ) else: - vcf2zarr.convert( + icf.convert( [vcf_path], cached_vcz_path, worker_processes=0, local_alleles=False ) return cached_vcz_path diff --git a/vcztools/vcf_writer.py b/vcztools/vcf_writer.py index 9953112..b3ced38 100644 --- a/vcztools/vcf_writer.py +++ b/vcztools/vcf_writer.py @@ -1,6 +1,5 @@ import io import logging -import re import sys from datetime import datetime from typing import Optional @@ -145,11 +144,9 @@ def write_vcf( ) if not no_header: - original_header = root.attrs.get("vcf_header", None) force_ac_an_header = not drop_genotypes and samples_selection is not None vcf_header = _generate_header( root, - original_header, sample_ids, no_version=no_version, force_ac_an=force_ac_an_header, @@ -300,7 +297,6 @@ def c_chunk_to_vcf( def _generate_header( ds, - original_header, sample_ids, *, no_version: bool = False, @@ -340,45 +336,12 @@ def _generate_header( if key in ("genotype", "genotype_phased"): continue format_fields.append(key) - if original_header is None: # generate entire header - # [1.4.1 File format] - print("##fileformat=VCFv4.3", file=output) - - if "source" in ds.attrs: - print(f'##source={ds.attrs["source"]}', file=output) - - else: # use original header fields where appropriate - unstructured_pattern = re.compile("##([^=]+)=([^<].*)") - structured_pattern = re.compile("##([^=]+)=(<.*)") - - for line in original_header.split("\n"): - if re.fullmatch(unstructured_pattern, line): - print(line, file=output) - else: - match = re.fullmatch(structured_pattern, line) - if match: - category = match.group(1) - id_pattern = re.compile("ID=([^,>]+)") - key = id_pattern.findall(line)[0] - if category not in ("contig", "FILTER", "INFO", "FORMAT"): - # output other structured fields - print(line, file=output) - # only output certain categories if in dataset - elif category == "contig" and key in contigs: - contigs.remove(key) - print(line, file=output) - elif category == "FILTER" and key in filters: - filters.remove(key) - print(line, file=output) - elif category == "INFO" and key in info_fields: - info_fields.remove(key) - print(line, file=output) - elif category == "FORMAT" and key in format_fields: - format_fields.remove(key) - print(line, file=output) - - # add all fields that are not in the original header - # or all fields if there was no original header + + # [1.4.1 File format] + print("##fileformat=VCFv4.3", file=output) + + if "source" in ds.attrs: + print(f'##source={ds.attrs["source"]}', file=output) # [1.4.2 Information field format] for key in info_fields: @@ -406,8 +369,17 @@ def _generate_header( ) # [1.4.3 Filter field format] - for filter in filters: - print(f'##FILTER=', file=output) + filter_descriptions = ( + ds["filter_description"] if "filter_description" in ds else None + ) + for i, filter in enumerate(filters): + filter_description = ( + "" if filter_descriptions is None else filter_descriptions[i] + ) + print( + f'##FILTER=', + file=output, + ) # [1.4.4 Individual format field format] for key in format_fields: @@ -430,9 +402,7 @@ def _generate_header( ) # [1.4.7 Contig field format] - contig_lengths = ( - ds.attrs["contig_lengths"] if "contig_lengths" in ds.attrs else None - ) + contig_lengths = ds["contig_length"] if "contig_length" in ds else None for i, contig in enumerate(contigs): if contig_lengths is None: print(f"##contig=", file=output)