Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# 0.1.5 2025-03-xx

- Add support for merging contig IDs across multiple VCFs (#342)
- Add support for merging contig IDs across multiple VCFs (#335)

- Add support for unindexed (and uncompressed) VCFs (#337)

# 0.1.4 2025-03-10

Expand Down
4 changes: 2 additions & 2 deletions bio2zarr/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,8 @@ def vcfpartition(vcfs, verbose, num_partitions, partition_size):
num_parts_per_path = max(1, num_partitions // len(vcfs))

for vcf_path in vcfs:
indexed_vcf = vcf_utils.IndexedVcf(vcf_path)
regions = indexed_vcf.partition_into_regions(
vcf_file = vcf_utils.VcfFile(vcf_path)
regions = vcf_file.partition_into_regions(
num_parts=num_parts_per_path, target_part_size=partition_size
)
for region in regions:
Expand Down
12 changes: 6 additions & 6 deletions bio2zarr/vcf2zarr/icf.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ def make_field_def(name, vcf_type, vcf_number):


def scan_vcf(path, target_num_partitions):
with vcf_utils.IndexedVcf(path) as indexed_vcf:
vcf = indexed_vcf.vcf
with vcf_utils.VcfFile(path) as vcf_file:
vcf = vcf_file.vcf
filters = []
pass_index = -1
for h in vcf.header_iter():
Expand Down Expand Up @@ -270,10 +270,10 @@ def scan_vcf(path, target_num_partitions):
filters=filters,
fields=fields,
partitions=[],
num_records=sum(indexed_vcf.contig_record_counts().values()),
num_records=sum(vcf_file.contig_record_counts().values()),
)

regions = indexed_vcf.partition_into_regions(num_parts=target_num_partitions)
regions = vcf_file.partition_into_regions(num_parts=target_num_partitions)
for region in regions:
metadata.partitions.append(
VcfPartition(
Expand Down Expand Up @@ -1093,9 +1093,9 @@ def process_partition(self, partition_index):
self.path,
partition_index,
) as tcw:
with vcf_utils.IndexedVcf(partition.vcf_path) as ivcf:
with vcf_utils.VcfFile(partition.vcf_path) as vcf:
num_records = 0
for variant in ivcf.variants(partition.region):
for variant in vcf.variants(partition.region):
num_records += 1
last_position = variant.POS
tcw.append("CHROM", variant.CHROM)
Expand Down
76 changes: 54 additions & 22 deletions bio2zarr/vcf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ class Region:
end: Optional[int] = None

def __post_init__(self):
if self.start is not None:
assert self.contig is not None
if self.start is None:
self.start = 1
else:
self.start = int(self.start)
assert self.start > 0
if self.end is not None:
Expand Down Expand Up @@ -393,9 +396,12 @@ class VcfIndexType(Enum):
TABIX = ".tbi"


class IndexedVcf(contextlib.AbstractContextManager):
class VcfFile(contextlib.AbstractContextManager):
def __init__(self, vcf_path, index_path=None):
self.vcf = None
self.file_type = None
self.index_type = None

vcf_path = pathlib.Path(vcf_path)
if not vcf_path.exists():
raise FileNotFoundError(vcf_path)
Expand All @@ -408,30 +414,34 @@ def __init__(self, vcf_path, index_path=None):
vcf_path.suffix + VcfIndexType.CSI.value
)
if not index_path.exists():
raise FileNotFoundError(
f"Cannot find .tbi or .csi file for {vcf_path}"
)
# No supported index found
index_path = None
else:
index_path = pathlib.Path(index_path)
if not index_path.exists():
raise FileNotFoundError(
f"Specified index path {index_path} does not exist"
)

self.vcf_path = vcf_path
self.index_path = index_path
self.file_type = None
self.index_type = None

if index_path.suffix == VcfIndexType.CSI.value:
self.index_type = VcfIndexType.CSI
elif index_path.suffix == VcfIndexType.TABIX.value:
self.index_type = VcfIndexType.TABIX
self.file_type = VcfFileType.VCF
else:
raise ValueError("Only .tbi or .csi indexes are supported.")
if index_path is not None:
if index_path.suffix == VcfIndexType.CSI.value:
self.index_type = VcfIndexType.CSI
elif index_path.suffix == VcfIndexType.TABIX.value:
self.index_type = VcfIndexType.TABIX
self.file_type = VcfFileType.VCF
else:
raise ValueError("Only .tbi or .csi indexes are supported.")

self.vcf = cyvcf2.VCF(vcf_path)
self.vcf.set_index(str(self.index_path))
if self.index_path is not None:
self.vcf.set_index(str(self.index_path))

logger.debug(f"Loaded {vcf_path} with index {self.index_path}")
self.sequence_names = None

self.index = None
if self.index_type == VcfIndexType.CSI:
# Determine the file-type based on the "aux" field.
self.index = read_csi(self.index_path)
Expand All @@ -441,9 +451,17 @@ def __init__(self, vcf_path, index_path=None):
self.sequence_names = self.index.parse_vcf_aux()
else:
self.sequence_names = self.vcf.seqnames
else:
elif self.index_type == VcfIndexType.TABIX:
self.index = read_tabix(self.index_path)
self.file_type = VcfFileType.VCF
self.sequence_names = self.index.sequence_names
else:
assert self.index is None
var = next(self.vcf)
self.sequence_names = [var.CHROM]
self.vcf.close()
# There doesn't seem to be a way to reset the iterator
self.vcf = cyvcf2.VCF(vcf_path)

def __exit__(self, exc_type, exc_val, exc_tb):
if self.vcf is not None:
Expand All @@ -452,6 +470,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return False

def contig_record_counts(self):
if self.index is None:
return {self.sequence_names[0]: RECORD_COUNT_UNKNOWN}
d = dict(zip(self.sequence_names, self.index.record_counts))
if self.file_type == VcfFileType.BCF:
d = {k: v for k, v in d.items() if v > 0}
Expand All @@ -460,12 +480,21 @@ def contig_record_counts(self):
def count_variants(self, region):
    """Return the number of variant records that fall within *region*."""
    total = 0
    for _record in self.variants(region):
        total += 1
    return total

def variants(self, region):
start = 1 if region.start is None else region.start
for var in self.vcf(str(region)):
# Need to filter because of indels overlapping the region
if var.POS >= start:
def variants(self, region=None):
    """Yield variant records, optionally restricted to *region*.

    With an index present, delegates the region query to cyvcf2.
    Without an index the file can only be streamed sequentially, which
    is valid only for single-contig files.
    """
    if self.index is None:
        # Unindexed mode: sequence_names holds exactly one contig
        # (taken from the first record's CHROM in __init__), so a
        # region request can only legitimately target that contig.
        contig = self.sequence_names[0]
        if region is not None:
            assert region.contig == contig
        for var in self.vcf:
            # A second contig appearing mid-stream means we cannot
            # honour region semantics without an index.
            if var.CHROM != contig:
                raise ValueError("Multi-contig VCFs must be indexed")
            yield var
    else:
        start = 1 if region.start is None else region.start
        for var in self.vcf(str(region)):
            # Need to filter because of indels overlapping the region
            if var.POS >= start:
                yield var

def _filter_empty_and_refine(self, regions):
"""
Expand Down Expand Up @@ -505,6 +534,9 @@ def partition_into_regions(
if target_part_size_bytes < 1:
raise ValueError("target_part_size must be positive")

if self.index is None:
return [Region(self.sequence_names[0])]

# Calculate the desired part file boundaries
file_length = os.stat(self.vcf_path).st_size
if num_parts is not None:
Expand Down
41 changes: 37 additions & 4 deletions tests/test_simulated_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@ def assert_ts_ds_equal(ts, ds, ploidy=1):
nt.assert_equal(ds.variant_position, ts.sites_position)


def write_vcf(ts, vcf_path, contig_id="1", indexed=False):
    """Write *ts* to *vcf_path* as VCF, optionally bgzipping and indexing.

    Returns the path of the final file: the ``.gz`` path when
    ``indexed=True``, otherwise the original *vcf_path*.
    """
    with open(vcf_path, "w") as out:
        ts.write_vcf(out, contig_id=contig_id)
    if not indexed:
        return vcf_path
    # tabix_index bgzip-compresses the input file in place as a side effect
    pysam.tabix_index(str(vcf_path), preset="vcf")
    return vcf_path.with_suffix(vcf_path.suffix + ".gz")


# https://github.com/sgkit-dev/bio2zarr/issues/336
Expand Down Expand Up @@ -75,6 +77,9 @@ def test_multi_contig(self, contig_ids, tmp_path):
vcfs.append(vcf_path)
tss[contig_id] = ts

self.validate_tss_vcf_list(contig_ids, tss, vcfs, tmp_path)

def validate_tss_vcf_list(self, contig_ids, tss, vcfs, tmp_path):
out = tmp_path / "example.vcf.zarr"
vcf2zarr.convert(vcfs, out)
ds = sg.load_dataset(out).set_index(
Expand All @@ -93,6 +98,34 @@ def test_multi_contig(self, contig_ids, tmp_path):
dss = ds.sel(variants=(contig, slice(0, None)))
assert_ts_ds_equal(tss[contig_id], dss)

@pytest.mark.parametrize("indexed", [True, False])
def test_indexed(self, indexed, tmp_path):
    # Conversion of a single-contig VCF should produce identical output
    # whether or not the input file has a tabix index.
    ts = run_simulation(num_samples=12, seed=34)
    vcf_path = write_vcf(ts, tmp_path / "sim.vcf", indexed=indexed)
    out = tmp_path / "example.vcf.zarr"
    vcf2zarr.convert([vcf_path], out)
    ds = sg.load_dataset(out)
    assert_ts_ds_equal(ts, ds)

@pytest.mark.parametrize("num_contigs", [2, 3, 6])
def test_mixed_indexed(self, num_contigs, tmp_path):
    # One single-contig VCF per contig, alternating indexed/unindexed
    # (indexed when the 1-based seed is even), all merged in a single
    # conversion run.
    contig_ids = [f"x{j}" for j in range(num_contigs)]

    vcfs = []
    tss = {}
    for seed, contig_id in enumerate(contig_ids, 1):
        ts = run_simulation(num_samples=3, seed=seed)
        vcf_path = write_vcf(
            ts,
            tmp_path / f"{contig_id}.vcf",
            contig_id=contig_id,
            indexed=seed % 2 == 0,
        )
        vcfs.append(vcf_path)
        tss[contig_id] = ts

    self.validate_tss_vcf_list(contig_ids, tss, vcfs, tmp_path)


# https://github.com/sgkit-dev/bio2zarr/issues/336
@pytest.mark.skipif(sys.platform == "darwin", reason="msprime OSX pip packages broken")
Expand Down
85 changes: 75 additions & 10 deletions tests/test_vcf_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pathlib
import shutil

import numpy as np
import pytest
Expand All @@ -13,24 +14,24 @@ def assert_part_counts_non_zero(part_counts, index_file):
assert np.all(part_counts > 0)


class TestIndexedVcf:
class TestVcfFile:
def get_instance(self, index_file):
vcf_path = data_path / (".".join(list(index_file.split("."))[:-1]))
return vcf_utils.IndexedVcf(vcf_path, data_path / index_file)
return vcf_utils.VcfFile(vcf_path, data_path / index_file)

def test_context_manager_success(self):
# Nominal case
with vcf_utils.IndexedVcf(data_path / "sample.bcf") as iv:
with vcf_utils.VcfFile(data_path / "sample.bcf") as iv:
assert iv.vcf is not None
assert iv.vcf is None

def test_context_manager_error(self):
with pytest.raises(FileNotFoundError, match="no-such-file"):
with vcf_utils.IndexedVcf(data_path / "no-such-file.bcf"):
with vcf_utils.VcfFile(data_path / "no-such-file.bcf"):
pass

def test_indels_filtered(self):
with vcf_utils.IndexedVcf(data_path / "chr_m_indels.vcf.gz") as vfile:
with vcf_utils.VcfFile(data_path / "chr_m_indels.vcf.gz") as vfile:
# Hand-picked example that results in filtering
region = vcf_utils.Region("chrM", 300, 314)
pos = [var.POS for var in vfile.variants(region)]
Expand Down Expand Up @@ -159,6 +160,27 @@ def test_partition_into_n_parts(self, index_file, total_records, num_parts):
assert np.sum(part_variant_counts) == total_records
assert_part_counts_non_zero(part_variant_counts, index_file)

@pytest.mark.parametrize(
    ("vcf_file", "total_records"),
    [
        ("1kg_2020_chrM.vcf.gz", 23),
        ("1kg_2020_chr20_annotations.bcf", 21),
    ],
)
@pytest.mark.parametrize("num_parts", [1, 2, 3])
def test_partition_into_n_parts_unindexed(
    self, tmp_path, vcf_file, total_records, num_parts
):
    # Copy only the data file (not its index): partitioning must fall
    # back to a single whole-file region regardless of the requested
    # num_parts, and that region must still cover every record.
    copy_path = tmp_path / vcf_file
    shutil.copyfile(data_path / vcf_file, copy_path)
    indexed_vcf = vcf_utils.VcfFile(copy_path)
    regions = list(indexed_vcf.partition_into_regions(num_parts=num_parts))
    assert len(regions) == 1
    part_variant_counts = np.array(
        [indexed_vcf.count_variants(region) for region in regions]
    )
    assert np.sum(part_variant_counts) == total_records

def test_tabix_multi_chrom_bug(self):
indexed_vcf = self.get_instance("multi_contig.vcf.gz.tbi")
regions = list(indexed_vcf.partition_into_regions(num_parts=10))
Expand Down Expand Up @@ -213,11 +235,54 @@ def test_partition_invalid_arguments(self):
with pytest.raises(ValueError, match=r"target_part_size must be positive"):
indexed_vcf.partition_into_regions(target_part_size=0)

def test_bad_index(self):
with pytest.raises(
ValueError, match=r"Only .tbi or .csi indexes are supported."
):
vcf_utils.IndexedVcf(data_path / "sample.vcf.gz", "y")
@pytest.mark.parametrize("path", ["y", data_path / "xxx", "/x/y.csi"])
def test_missing_index_file(self, path):
    # An explicitly supplied index path that does not exist is an error
    # (unlike the implicit index lookup, which falls back to unindexed
    # mode when no .tbi/.csi is found).
    with pytest.raises(FileNotFoundError, match="Specified index path"):
        vcf_utils.VcfFile(data_path / "sample.vcf.gz", path)

def test_bad_index_format(self):
    # Pass the VCF itself as its own index: the suffix is neither .tbi
    # nor .csi, so it must be rejected.
    vcf_file = data_path / "sample.vcf.gz"
    with pytest.raises(ValueError, match="Only .tbi or .csi indexes"):
        vcf_utils.VcfFile(vcf_file, vcf_file)

@pytest.mark.parametrize(
    "filename",
    [
        "1kg_2020_chrM.vcf.gz",
        "1kg_2020_chrM.bcf",
        "1kg_2020_chr20_annotations.bcf",
        "chr_m_indels.vcf.gz",
        "NA12878.prod.chr20snippet.g.vcf.gz",
    ],
)
def test_unindexed_single_contig(self, tmp_path, filename):
    # f1: indexed view of the original data file.
    f1 = vcf_utils.VcfFile(data_path / filename)
    assert f1.index is not None
    # f2: copying just the data file (without its index) forces
    # unindexed mode.
    copy_path = tmp_path / filename
    shutil.copyfile(data_path / filename, copy_path)
    f2 = vcf_utils.VcfFile(copy_path)
    assert f2.index is None
    crc1 = f1.contig_record_counts()
    assert len(crc1) == 1
    contig = next(iter(crc1.keys()))
    # Without an index the record count is unknown, reported as inf.
    assert f2.contig_record_counts() == {contig: np.inf}
    region = vcf_utils.Region(contig)
    # The full variants returned by cyvcf2 don't compare for equality,
    # so just check the chrom/pos values
    v1 = [(v.CHROM, v.POS) for v in f1.variants(region)]
    v2 = [(v.CHROM, v.POS) for v in f2.variants(region)]
    assert v1 == v2

@pytest.mark.parametrize(
    "filename",
    ["sample.vcf.gz", "sample.bcf", "multi_contig.vcf.gz"],
)
def test_unindexed_multi_contig(self, tmp_path, filename):
    # Without an index, iterating past the first contig must raise
    # rather than silently returning partial data.
    copy_path = tmp_path / filename
    shutil.copyfile(data_path / filename, copy_path)
    f = vcf_utils.VcfFile(copy_path)
    with pytest.raises(ValueError, match="Multi-contig VCFs must be indexed"):
        list(f.variants())


class TestCsiIndex:
Expand Down