Skip to content

Commit c2b7a63

Browse files
Fixup tests
1 parent 19c6ac2 commit c2b7a63

File tree

3 files changed

+45 -12 lines changed

3 files changed

+45 -12 lines changed

bio2zarr/vcf2zarr/icf.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -354,13 +354,13 @@ def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
354354
# Note: this will be infinity here if any of the chunks has an index
355355
# that doesn't keep track of the number of records per-contig
356356
icf_metadata.num_records = total_records
357-
if len(all_partitions) > 1:
358-
# Sort by contig (in the order they appear in the header) first,
359-
# then by start coordinate
360-
contig_index_map = {contig.id: j for j, contig in enumerate(metadata.contigs)}
361-
all_partitions.sort(
362-
key=lambda x: (contig_index_map[x.region.contig], x.region.start)
363-
)
357+
358+
# Sort by contig (in the order they appear in the header) first,
359+
# then by start coordinate
360+
contig_index_map = {contig.id: j for j, contig in enumerate(metadata.contigs)}
361+
all_partitions.sort(
362+
key=lambda x: (contig_index_map[x.region.contig], x.region.start)
363+
)
364364
icf_metadata.partitions = all_partitions
365365
logger.info(f"Scan complete, resulting in {len(all_partitions)} partitions.")
366366
return icf_metadata, header

bio2zarr/vcf_utils.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,9 @@ class Region:
9090

9191
def __post_init__(self):
9292
assert self.contig is not None
93-
if self.start is not None:
93+
if self.start is None:
94+
self.start = 1
95+
else:
9496
self.start = int(self.start)
9597
assert self.start > 0
9698
if self.end is not None:

tests/test_simulated_data.py

Lines changed: 35 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import sys
22

33
import numpy.testing as nt
4+
import pysam
45
import pytest
56
import sgkit as sg
67

@@ -41,12 +42,13 @@ def assert_ts_ds_equal(ts, ds, ploidy=1):
4142
nt.assert_equal(ds.variant_position, ts.sites_position)
4243

4344

44-
def write_vcf(ts, vcf_path, contig_id="1"):
45+
def write_vcf(ts, vcf_path, contig_id="1", indexed=False):
4546
with open(vcf_path, "w") as f:
4647
ts.write_vcf(f, contig_id=contig_id)
47-
# # This also compresses the input file
48-
# pysam.tabix_index(str(vcf_path), preset="vcf")
49-
# return vcf_path.with_suffix(vcf_path.suffix + ".gz")
48+
if indexed:
49+
# This also compresses the input file
50+
pysam.tabix_index(str(vcf_path), preset="vcf")
51+
vcf_path = vcf_path.with_suffix(vcf_path.suffix + ".gz")
5052
return vcf_path
5153

5254

@@ -75,6 +77,7 @@ def test_multi_contig(self, contig_ids, tmp_path):
7577
vcfs.append(vcf_path)
7678
tss[contig_id] = ts
7779

80+
def validate_tss_vcf_list(self, contig_ids, tss, vcfs, tmp_path):
7881
out = tmp_path / "example.vcf.zarr"
7982
vcf2zarr.convert(vcfs, out)
8083
ds = sg.load_dataset(out).set_index(
@@ -93,6 +96,34 @@ def test_multi_contig(self, contig_ids, tmp_path):
9396
dss = ds.sel(variants=(contig, slice(0, None)))
9497
assert_ts_ds_equal(tss[contig_id], dss)
9598

99+
@pytest.mark.parametrize("indexed", [True, False])
100+
def test_indexed(self, indexed, tmp_path):
101+
ts = run_simulation(num_samples=12, seed=34)
102+
vcf_path = write_vcf(ts, tmp_path / "sim.vcf", indexed=indexed)
103+
out = tmp_path / "example.vcf.zarr"
104+
vcf2zarr.convert([vcf_path], out)
105+
ds = sg.load_dataset(out)
106+
assert_ts_ds_equal(ts, ds)
107+
108+
@pytest.mark.parametrize("num_contigs", [2, 3, 6])
109+
def test_mixed_indexed(self, num_contigs, tmp_path):
110+
contig_ids = [f"x{j}" for j in range(num_contigs)]
111+
112+
vcfs = []
113+
tss = {}
114+
for seed, contig_id in enumerate(contig_ids, 1):
115+
ts = run_simulation(num_samples=3, seed=seed)
116+
vcf_path = write_vcf(
117+
ts,
118+
tmp_path / f"{contig_id}.vcf",
119+
contig_id=contig_id,
120+
indexed=seed % 2 == 0,
121+
)
122+
vcfs.append(vcf_path)
123+
tss[contig_id] = ts
124+
125+
self.validate_tss_vcf_list(contig_ids, tss, vcfs, tmp_path)
126+
96127

97128
# https://github.com/sgkit-dev/bio2zarr/issues/336
98129
@pytest.mark.skipif(sys.platform == "darwin", reason="msprime OSX pip packages broken")

0 commit comments

Comments
 (0)