Skip to content

Commit 9ad0cf6

Browse files
Switch partition_into_regions to return an iterator
1 parent a1ddde0 commit 9ad0cf6

File tree

4 files changed

+21
-17
lines changed

4 files changed

+21
-17
lines changed

bio2zarr/cli.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -518,10 +518,10 @@ def plink2zarr():
518518
@click.command
519519
@version
520520
@click.argument("vcf_path", type=click.Path())
521-
@click.option("-i", "--index", type=click.Path(), default=None)
522521
@click.option("-n", "--num-parts", type=int, default=None)
523522
# @click.option("-s", "--part-size", type=int, default=None)
524-
def vcfpartition(vcf_path, index, num_parts):
525-
indexed_vcf = vcf_utils.IndexedVcf(vcf_path, index)
523+
def vcfpartition(vcf_path, num_parts):
524+
indexed_vcf = vcf_utils.IndexedVcf(vcf_path)
526525
regions = indexed_vcf.partition_into_regions(num_parts=num_parts)
527-
click.echo("\n".join(map(str, regions)))
526+
for region in regions:
527+
click.echo(region)

bio2zarr/vcf2zarr/icf.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,9 +263,6 @@ def scan_vcf(path, target_num_partitions):
263263
)
264264

265265
regions = indexed_vcf.partition_into_regions(num_parts=target_num_partitions)
266-
logger.info(
267-
f"Split {path} into {len(regions)} regions (target={target_num_partitions})"
268-
)
269266
for region in regions:
270267
metadata.partitions.append(
271268
VcfPartition(
@@ -275,6 +272,10 @@ def scan_vcf(path, target_num_partitions):
275272
region=region,
276273
)
277274
)
275+
logger.info(
276+
f"Split {path} into {len(metadata.partitions)} "
277+
f"partitions target={target_num_partitions})"
278+
)
278279
core.update_progress(1)
279280
return metadata, vcf.raw_header
280281

bio2zarr/vcf_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -450,15 +450,16 @@ def variants(self, region):
450450
def _filter_empty_and_refine(self, regions):
451451
"""
452452
Return all regions in the specified list that have one or more records,
453-
and refine the start coordinate of the region to be the actual first coord
453+
and refine the start coordinate of the region to be the actual first coord.
454+
455+
Because this is a relatively expensive operation requiring seeking around
456+
the file, we return the results as an iterator.
454457
"""
455-
ret = []
456458
for region in regions:
457459
var = next(self.variants(region), None)
458460
if var is not None:
459461
region.start = var.POS
460-
ret.append(region)
461-
return ret
462+
yield region
462463

463464
def partition_into_regions(
464465
self,

tests/test_vcf_utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_contig_record_counts(self, index_file, expected):
9595
)
9696
def test_partition_into_one_part(self, index_file, expected):
9797
indexed_vcf = self.get_instance(index_file)
98-
regions = indexed_vcf.partition_into_regions(num_parts=1)
98+
regions = list(indexed_vcf.partition_into_regions(num_parts=1))
9999
assert all(isinstance(r, vcf_utils.Region) for r in regions)
100100
assert [str(r) for r in regions] == expected
101101

@@ -120,7 +120,7 @@ def test_partition_into_one_part(self, index_file, expected):
120120
)
121121
def test_partition_into_max_parts(self, index_file, num_expected, total_records):
122122
indexed_vcf = self.get_instance(index_file)
123-
regions = indexed_vcf.partition_into_regions(num_parts=1000)
123+
regions = list(indexed_vcf.partition_into_regions(num_parts=1000))
124124
assert all(isinstance(r, vcf_utils.Region) for r in regions)
125125
# print(regions)
126126
assert len(regions) == num_expected
@@ -151,7 +151,7 @@ def test_partition_into_max_parts(self, index_file, num_expected, total_records)
151151
@pytest.mark.parametrize("num_parts", [2, 3, 4, 5, 16, 33])
152152
def test_partition_into_n_parts(self, index_file, total_records, num_parts):
153153
indexed_vcf = self.get_instance(index_file)
154-
regions = indexed_vcf.partition_into_regions(num_parts=num_parts)
154+
regions = list(indexed_vcf.partition_into_regions(num_parts=num_parts))
155155
assert all(isinstance(r, vcf_utils.Region) for r in regions)
156156
part_variant_counts = np.array(
157157
[indexed_vcf.count_variants(region) for region in regions]
@@ -161,7 +161,7 @@ def test_partition_into_n_parts(self, index_file, total_records, num_parts):
161161

162162
def test_tabix_multi_chrom_bug(self):
163163
indexed_vcf = self.get_instance("multi_contig.vcf.gz.tbi")
164-
regions = indexed_vcf.partition_into_regions(num_parts=10)
164+
regions = list(indexed_vcf.partition_into_regions(num_parts=10))
165165
# An earlier version of the code returned this, i.e. with a duplicate
166166
# for 4 with end coord of 0
167167
# ["0:1-", "1", "2", "3", "4:1-0", "4:1-"]
@@ -185,7 +185,9 @@ def test_tabix_multi_chrom_bug(self):
185185
)
186186
def test_target_part_size(self, target_part_size, filename):
187187
indexed_vcf = self.get_instance(filename)
188-
regions = indexed_vcf.partition_into_regions(target_part_size=target_part_size)
188+
regions = list(
189+
indexed_vcf.partition_into_regions(target_part_size=target_part_size)
190+
)
189191
assert len(regions) == 5
190192
part_variant_counts = [indexed_vcf.count_variants(region) for region in regions]
191193
assert part_variant_counts == [3450, 3869, 4525, 7041, 1025]
@@ -197,7 +199,7 @@ def test_partition_invalid_arguments(self):
197199
with pytest.raises(
198200
ValueError, match=r"One of num_parts or target_part_size must be specified"
199201
):
200-
indexed_vcf.partition_into_regions()
202+
list(indexed_vcf.partition_into_regions())
201203

202204
with pytest.raises(
203205
ValueError,

0 commit comments

Comments
 (0)