@@ -89,7 +89,10 @@ class Region:
8989 end : Optional [int ] = None
9090
9191 def __post_init__ (self ):
92- if self .start is not None :
92+ assert self .contig is not None
93+ if self .start is None :
94+ self .start = 1
95+ else :
9396 self .start = int (self .start )
9497 assert self .start > 0
9598 if self .end is not None :
@@ -396,6 +399,9 @@ class VcfIndexType(Enum):
396399class IndexedVcf (contextlib .AbstractContextManager ):
397400 def __init__ (self , vcf_path , index_path = None ):
398401 self .vcf = None
402+ self .file_type = None
403+ self .index_type = None
404+
399405 vcf_path = pathlib .Path (vcf_path )
400406 if not vcf_path .exists ():
401407 raise FileNotFoundError (vcf_path )
@@ -408,30 +414,34 @@ def __init__(self, vcf_path, index_path=None):
408414 vcf_path .suffix + VcfIndexType .CSI .value
409415 )
410416 if not index_path .exists ():
411- raise FileNotFoundError (
412- f"Cannot find .tbi or .csi file for { vcf_path } "
413- )
417+ # No supported index found
418+ index_path = None
414419 else :
415420 index_path = pathlib .Path (index_path )
421+ if not index_path .exists ():
422+ raise FileNotFoundError (
423+ f"Specified index path { index_path } does not exist"
424+ )
416425
417426 self .vcf_path = vcf_path
418427 self .index_path = index_path
419- self .file_type = None
420- self .index_type = None
421-
422- if index_path .suffix == VcfIndexType .CSI .value :
423- self .index_type = VcfIndexType .CSI
424- elif index_path .suffix == VcfIndexType .TABIX .value :
425- self .index_type = VcfIndexType .TABIX
426- self .file_type = VcfFileType .VCF
427- else :
428- raise ValueError ("Only .tbi or .csi indexes are supported." )
428+ if index_path is not None :
429+ if index_path .suffix == VcfIndexType .CSI .value :
430+ self .index_type = VcfIndexType .CSI
431+ elif index_path .suffix == VcfIndexType .TABIX .value :
432+ self .index_type = VcfIndexType .TABIX
433+ self .file_type = VcfFileType .VCF
434+ else :
435+ raise ValueError ("Only .tbi or .csi indexes are supported." )
429436
430437 self .vcf = cyvcf2 .VCF (vcf_path )
431- self .vcf .set_index (str (self .index_path ))
438+ if self .index_path is not None :
439+ self .vcf .set_index (str (self .index_path ))
440+
432441 logger .debug (f"Loaded { vcf_path } with index { self .index_path } " )
433442 self .sequence_names = None
434443
444+ self .index = None
435445 if self .index_type == VcfIndexType .CSI :
436446 # Determine the file-type based on the "aux" field.
437447 self .index = read_csi (self .index_path )
@@ -441,9 +451,17 @@ def __init__(self, vcf_path, index_path=None):
441451 self .sequence_names = self .index .parse_vcf_aux ()
442452 else :
443453 self .sequence_names = self .vcf .seqnames
444- else :
454+ elif self . index_type == VcfIndexType . TABIX :
445455 self .index = read_tabix (self .index_path )
456+ self .file_type = VcfFileType .VCF
446457 self .sequence_names = self .index .sequence_names
458+ else :
459+ assert self .index is None
460+ var = next (self .vcf )
461+ self .sequence_names = [var .CHROM ]
462+ self .vcf .close ()
463+ # There doesn't seem to be a way to reset the iterator
464+ self .vcf = cyvcf2 .VCF (vcf_path )
447465
448466 def __exit__ (self , exc_type , exc_val , exc_tb ):
449467 if self .vcf is not None :
@@ -452,6 +470,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
452470 return False
453471
454472 def contig_record_counts (self ):
473+ if self .index is None :
474+ return {self .sequence_names [0 ]: np .inf }
455475 d = dict (zip (self .sequence_names , self .index .record_counts ))
456476 if self .file_type == VcfFileType .BCF :
457477 d = {k : v for k , v in d .items () if v > 0 }
@@ -460,12 +480,21 @@ def contig_record_counts(self):
def count_variants(self, region):
    """
    Return how many variant records ``self.variants(region)`` yields.

    The records are consumed one at a time and only counted, so nothing
    is held in memory.
    """
    total = 0
    for _variant in self.variants(region):
        total += 1
    return total
462482
def variants(self, region=None):
    """
    Yield the variant records for ``region``, or every record in the
    file when ``region`` is None.

    Without an index (``self.index is None``) the file is read
    sequentially. Only the region's contig is checked against the
    single known contig; the region's start/end are not applied in
    this mode. A record on any other contig raises
    ``ValueError("Multi-contig VCFs must be indexed")``.

    With an index, ``self.vcf`` is queried with ``str(region)`` and
    records whose ``POS`` lies before the region start are dropped.
    """
    if self.index is None:
        contig = self.sequence_names[0]
        if region is not None:
            assert region.contig == contig
        for var in self.vcf:
            if var.CHROM != contig:
                raise ValueError("Multi-contig VCFs must be indexed")
            yield var
    elif region is None:
        # The signature allows region=None; with an index present there
        # is nothing to query by, so iterate the whole file instead of
        # dereferencing None.
        yield from self.vcf
    else:
        start = 1 if region.start is None else region.start
        for var in self.vcf(str(region)):
            # Need to filter because of indels overlapping the region
            if var.POS >= start:
                yield var
469498
470499 def _filter_empty_and_refine (self , regions ):
471500 """
@@ -505,6 +534,9 @@ def partition_into_regions(
505534 if target_part_size_bytes < 1 :
506535 raise ValueError ("target_part_size must be positive" )
507536
537+ if self .index is None :
538+ return [Region (self .sequence_names [0 ])]
539+
508540 # Calculate the desired part file boundaries
509541 file_length = os .stat (self .vcf_path ).st_size
510542 if num_parts is not None :
0 commit comments