Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest
import zarr
from bio2zarr import vcf2zarr
from bio2zarr import icf

from vcztools.stats import nrecords, stats

Expand Down Expand Up @@ -39,7 +39,7 @@ def test_stats__no_index(tmp_path):
original = pathlib.Path("tests/data/vcf") / "sample.vcf.gz"
# don't use cache here since we want to make sure vcz is not indexed
vcz = tmp_path.joinpath("intermediate.vcz")
vcf2zarr.convert([original], vcz, worker_processes=0, local_alleles=False)
icf.convert([original], vcz, worker_processes=0, local_alleles=False)

# delete the index created by vcf2zarr
root = zarr.open(vcz, mode="a")
Expand Down
18 changes: 6 additions & 12 deletions tests/test_vcf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as np
import pytest
import zarr
from bio2zarr import vcf2zarr
from cyvcf2 import VCF
from numpy.testing import assert_array_equal

Expand Down Expand Up @@ -301,15 +300,9 @@ def test_write_vcf__header_flags(tmp_path):
assert_vcfs_close(original, output)


def test_write_vcf__generate_header(tmp_path):
def test_write_vcf__generate_header():
original = pathlib.Path("tests/data/vcf") / "sample.vcf.gz"
# don't use cache here since we mutate the vcz
vcz = tmp_path.joinpath("intermediate.vcz")
vcf2zarr.convert([original], vcz, worker_processes=0, local_alleles=False)

# remove vcf_header
root = zarr.open(vcz, mode="r+")
del root.attrs["vcf_header"]
vcz = vcz_path_cache(original)

output_header = StringIO()
write_vcf(vcz, output_header, header_only=True, no_version=True)
Expand All @@ -324,9 +317,9 @@ def test_write_vcf__generate_header(tmp_path):
##INFO=<ID=DP,Number=1,Type=Integer,Description="Total Depth">
##INFO=<ID=H2,Number=0,Type=Flag,Description="HapMap2 membership">
##INFO=<ID=NS,Number=1,Type=Integer,Description="Number of Samples With Data">
##FILTER=<ID=PASS,Description="">
##FILTER=<ID=s50,Description="">
##FILTER=<ID=q10,Description="">
##FILTER=<ID=PASS,Description="All filters passed">
##FILTER=<ID=s50,Description="Less than 50% of samples have data">
##FILTER=<ID=q10,Description="Quality below 10">
##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">
##FORMAT=<ID=DP,Number=1,Type=Integer,Description="Read Depth">
##FORMAT=<ID=GQ,Number=1,Type=Integer,Description="Genotype Quality">
Expand All @@ -338,6 +331,7 @@ def test_write_vcf__generate_header(tmp_path):
""" # noqa: E501

# substitute value of source
root = zarr.open(vcz, mode="r+")
expected_vcf_header = expected_vcf_header.format(root.attrs["source"])

assert output_header.getvalue() == expected_vcf_header
Expand Down
66 changes: 62 additions & 4 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import cyvcf2
import numpy as np
from bio2zarr import vcf2zarr
from bio2zarr import icf


@contextmanager
Expand All @@ -29,6 +29,64 @@ def normalise_info_missingness(info_dict, key):
return value


def _get_header_field_dicts(vcf, header_type):
def to_dict(header_field):
d = header_field.info(extra=True)
del d[b"IDX"] # remove IDX since we don't care about ordering

# cyvcf2 duplicates some keys as strings and bytes, so remove the bytes one
for k in list(d.keys()):
if isinstance(k, bytes) and k.decode("utf-8") in d:
del d[k]
return d

return {
field["ID"]: to_dict(field)
for field in vcf.header_iter()
if field["HeaderType"] == header_type
}


def _assert_header_field_dicts_equivalent(field_dicts1, field_dicts2):
assert len(field_dicts1) == len(field_dicts2)

for id in field_dicts1.keys():
assert id in field_dicts2
field_dict1 = field_dicts1[id]
field_dict2 = field_dicts2[id]

assert len(field_dict1) == len(field_dict2)
# all fields should be the same, except Number="." which can match any value
for k in field_dict1.keys():
assert k in field_dict2
v1 = field_dict1[k]
v2 = field_dict2[k]
if k == "Number" and (v1 == "." or v2 == "."):
continue
assert v1 == v2, f"Failed in field {id} with key {k}"


def _assert_vcf_headers_equivalent(vcf1, vcf2):
    """Assert that two VCF headers agree on INFO, FORMAT, FILTER and CONTIG
    fields, ignoring field order. All other header lines are ignored.
    """
    for header_type in ("INFO", "FORMAT", "FILTER", "CONTIG"):
        fields1 = _get_header_field_dicts(vcf1, header_type)
        fields2 = _get_header_field_dicts(vcf2, header_type)
        _assert_header_field_dicts_equivalent(fields1, fields2)


def assert_vcfs_close(f1, f2, *, rtol=1e-05, atol=1e-03, allow_zero_variants=False):
"""Like :py:func:`numpy.testing.assert_allclose()`, but for VCF files.

Expand All @@ -48,7 +106,7 @@ def assert_vcfs_close(f1, f2, *, rtol=1e-05, atol=1e-03, allow_zero_variants=Fal
Absolute tolerance.
"""
with open_vcf(f1) as vcf1, open_vcf(f2) as vcf2:
assert vcf1.raw_header == vcf2.raw_header
_assert_vcf_headers_equivalent(vcf1, vcf2)
assert vcf1.samples == vcf2.samples

count = 0
Expand Down Expand Up @@ -145,15 +203,15 @@ def vcz_path_cache(vcf_path):
cached_vcz_path = (cache_path / vcf_path.name).with_suffix(".vcz")
if not cached_vcz_path.exists():
if vcf_path.name.startswith("chr22"):
vcf2zarr.convert(
icf.convert(
[vcf_path],
cached_vcz_path,
worker_processes=0,
variants_chunk_size=10,
samples_chunk_size=10,
)
else:
vcf2zarr.convert(
icf.convert(
[vcf_path], cached_vcz_path, worker_processes=0, local_alleles=False
)
return cached_vcz_path
66 changes: 18 additions & 48 deletions vcztools/vcf_writer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import io
import logging
import re
import sys
from datetime import datetime
from typing import Optional
Expand Down Expand Up @@ -145,11 +144,9 @@ def write_vcf(
)

if not no_header:
original_header = root.attrs.get("vcf_header", None)
force_ac_an_header = not drop_genotypes and samples_selection is not None
vcf_header = _generate_header(
root,
original_header,
sample_ids,
no_version=no_version,
force_ac_an=force_ac_an_header,
Expand Down Expand Up @@ -300,7 +297,6 @@ def c_chunk_to_vcf(

def _generate_header(
ds,
original_header,
sample_ids,
*,
no_version: bool = False,
Expand Down Expand Up @@ -340,45 +336,12 @@ def _generate_header(
if key in ("genotype", "genotype_phased"):
continue
format_fields.append(key)
if original_header is None: # generate entire header
# [1.4.1 File format]
print("##fileformat=VCFv4.3", file=output)

if "source" in ds.attrs:
print(f'##source={ds.attrs["source"]}', file=output)

else: # use original header fields where appropriate
unstructured_pattern = re.compile("##([^=]+)=([^<].*)")
structured_pattern = re.compile("##([^=]+)=(<.*)")

for line in original_header.split("\n"):
if re.fullmatch(unstructured_pattern, line):
print(line, file=output)
else:
match = re.fullmatch(structured_pattern, line)
if match:
category = match.group(1)
id_pattern = re.compile("ID=([^,>]+)")
key = id_pattern.findall(line)[0]
if category not in ("contig", "FILTER", "INFO", "FORMAT"):
# output other structured fields
print(line, file=output)
# only output certain categories if in dataset
elif category == "contig" and key in contigs:
contigs.remove(key)
print(line, file=output)
elif category == "FILTER" and key in filters:
filters.remove(key)
print(line, file=output)
elif category == "INFO" and key in info_fields:
info_fields.remove(key)
print(line, file=output)
elif category == "FORMAT" and key in format_fields:
format_fields.remove(key)
print(line, file=output)

# add all fields that are not in the original header
# or all fields if there was no original header

# [1.4.1 File format]
print("##fileformat=VCFv4.3", file=output)

if "source" in ds.attrs:
print(f'##source={ds.attrs["source"]}', file=output)

# [1.4.2 Information field format]
for key in info_fields:
Expand Down Expand Up @@ -406,8 +369,17 @@ def _generate_header(
)

# [1.4.3 Filter field format]
for filter in filters:
print(f'##FILTER=<ID={filter},Description="">', file=output)
filter_descriptions = (
ds["filter_description"] if "filter_description" in ds else None
)
for i, filter in enumerate(filters):
filter_description = (
"" if filter_descriptions is None else filter_descriptions[i]
)
print(
f'##FILTER=<ID={filter},Description="{filter_description}">',
file=output,
)

# [1.4.4 Individual format field format]
for key in format_fields:
Expand All @@ -430,9 +402,7 @@ def _generate_header(
)

# [1.4.7 Contig field format]
contig_lengths = (
ds.attrs["contig_lengths"] if "contig_lengths" in ds.attrs else None
)
contig_lengths = ds["contig_length"] if "contig_length" in ds else None
for i, contig in enumerate(contigs):
if contig_lengths is None:
print(f"##contig=<ID={contig}>", file=output)
Expand Down
Loading