@@ -89,9 +89,7 @@ class Region:
89
89
end : Optional [int ] = None
90
90
91
91
def __post_init__ (self ):
92
- if self .contig is None :
93
- return
94
-
92
+ assert self .contig is not None
95
93
if self .start is not None :
96
94
self .start = int (self .start )
97
95
assert self .start > 0
@@ -399,6 +397,9 @@ class VcfIndexType(Enum):
399
397
class IndexedVcf (contextlib .AbstractContextManager ):
400
398
def __init__ (self , vcf_path , index_path = None ):
401
399
self .vcf = None
400
+ self .file_type = None
401
+ self .index_type = None
402
+
402
403
vcf_path = pathlib .Path (vcf_path )
403
404
if not vcf_path .exists ():
404
405
raise FileNotFoundError (vcf_path )
@@ -411,27 +412,28 @@ def __init__(self, vcf_path, index_path=None):
411
412
vcf_path .suffix + VcfIndexType .CSI .value
412
413
)
413
414
if not index_path .exists ():
414
- # Use this as a proxy for "no index"
415
- index_path = vcf_path
415
+ # No supported index found
416
+ index_path = None
416
417
else :
417
418
index_path = pathlib .Path (index_path )
419
+ if not index_path .exists ():
420
+ raise FileNotFoundError (
421
+ f"Specified index path { index_path } does not exist"
422
+ )
418
423
419
424
self .vcf_path = vcf_path
420
425
self .index_path = index_path
421
- self .file_type = None
422
- self .index_type = None
423
-
424
- if index_path .suffix == VcfIndexType .CSI .value :
425
- self .index_type = VcfIndexType .CSI
426
- elif index_path .suffix == VcfIndexType .TABIX .value :
427
- self .index_type = VcfIndexType .TABIX
428
- self .file_type = VcfFileType .VCF
429
- # else:
430
-
431
- # raise ValueError("Only .tbi or .csi indexes are supported.")
426
+ if index_path is not None :
427
+ if index_path .suffix == VcfIndexType .CSI .value :
428
+ self .index_type = VcfIndexType .CSI
429
+ elif index_path .suffix == VcfIndexType .TABIX .value :
430
+ self .index_type = VcfIndexType .TABIX
431
+ self .file_type = VcfFileType .VCF
432
+ else :
433
+ raise ValueError ("Only .tbi or .csi indexes are supported." )
432
434
433
435
self .vcf = cyvcf2 .VCF (vcf_path )
434
- if self .index_type is not None :
436
+ if self .index_path is not None :
435
437
self .vcf .set_index (str (self .index_path ))
436
438
437
439
logger .debug (f"Loaded { vcf_path } with index { self .index_path } " )
@@ -449,7 +451,15 @@ def __init__(self, vcf_path, index_path=None):
449
451
self .sequence_names = self .vcf .seqnames
450
452
elif self .index_type == VcfIndexType .TABIX :
451
453
self .index = read_tabix (self .index_path )
454
+ self .file_type = VcfFileType .VCF
452
455
self .sequence_names = self .index .sequence_names
456
+ else :
457
+ assert self .index is None
458
+ var = next (self .vcf )
459
+ self .sequence_names = [var .CHROM ]
460
+ self .vcf .close ()
461
+ # There doesn't seem to be a way to reset the iterator
462
+ self .vcf = cyvcf2 .VCF (vcf_path )
453
463
454
464
def __exit__ (self , exc_type , exc_val , exc_tb ):
455
465
if self .vcf is not None :
@@ -459,7 +469,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
459
469
460
470
def contig_record_counts (self ):
461
471
if self .index is None :
462
- return {None : np .inf }
472
+ return {self . sequence_names [ 0 ] : np .inf }
463
473
d = dict (zip (self .sequence_names , self .index .record_counts ))
464
474
if self .file_type == VcfFileType .BCF :
465
475
d = {k : v for k , v in d .items () if v > 0 }
@@ -468,10 +478,15 @@ def contig_record_counts(self):
468
478
def count_variants(self, region):
    """Return how many records ``self.variants(region)`` yields.

    The records are consumed one at a time and only tallied, so nothing
    is accumulated in memory while counting.
    """
    total = 0
    for _record in self.variants(region):
        total += 1
    return total
470
480
471
- def variants (self , region ):
481
+ def variants (self , region = None ):
472
482
if self .index is None :
473
- assert region .contig is None
474
- yield from self .vcf
483
+ contig = self .sequence_names [0 ]
484
+ if region is not None :
485
+ assert region .contig == contig
486
+ for var in self .vcf :
487
+ if var .CHROM != contig :
488
+ raise ValueError ("Multi-contig VCFs must be indexed" )
489
+ yield var
475
490
else :
476
491
start = 1 if region .start is None else region .start
477
492
for var in self .vcf (str (region )):
@@ -498,9 +513,6 @@ def partition_into_regions(
498
513
num_parts : Optional [int ] = None ,
499
514
target_part_size : Union [None , int , str ] = None ,
500
515
):
501
- if self .index is None :
502
- return [Region ()]
503
-
504
516
if num_parts is None and target_part_size is None :
505
517
raise ValueError ("One of num_parts or target_part_size must be specified" )
506
518
@@ -520,6 +532,9 @@ def partition_into_regions(
520
532
if target_part_size_bytes < 1 :
521
533
raise ValueError ("target_part_size must be positive" )
522
534
535
+ if self .index is None :
536
+ return [Region (self .sequence_names [0 ])]
537
+
523
538
# Calculate the desired part file boundaries
524
539
file_length = os .stat (self .vcf_path ).st_size
525
540
if num_parts is not None :
0 commit comments