@@ -829,7 +829,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
829829 return False
830830
831831
832- def convert_local_allele_field_types (fields ):
832+ def convert_local_allele_field_types (fields , schema_instance ):
833833 """
834834 Update the specified list of fields to include the LAA field, and to convert
835835 any supported localisable fields to the L* counterpart.
@@ -842,45 +842,45 @@ def convert_local_allele_field_types(fields):
842842 """
843843 fields_by_name = {field .name : field for field in fields }
844844 gt = fields_by_name ["call_genotype" ]
845- if gt .shape [- 1 ] != 2 :
846- raise ValueError ("Local alleles only supported on diploid data" )
847845
848- # TODO check if LA is already in here
846+ if schema_instance .get_shape (["ploidy" ])[0 ] != 2 :
847+ raise ValueError ("Local alleles only supported on diploid data" )
849848
850- shape = gt .shape [:- 1 ]
851- chunks = gt .chunks [:- 1 ]
852849 dimensions = gt .dimensions [:- 1 ]
853850
854851 la = vcz .ZarrArraySpec (
855852 name = "call_LA" ,
856853 dtype = "i1" ,
857- shape = gt .shape ,
858- chunks = gt .chunks ,
859854 dimensions = (* dimensions , "local_alleles" ),
860855 description = (
861856 "0-based indices into REF+ALT, indicating which alleles"
862857 " are relevant (local) for the current sample"
863858 ),
864859 )
860+ schema_instance .dimensions ["local_alleles" ] = vcz .VcfZarrDimension (
861+ size = schema_instance .dimensions ["ploidy" ].size
862+ )
863+
865864 ad = fields_by_name .get ("call_AD" , None )
866865 if ad is not None :
867866 # TODO check if call_LAD is in the list already
868867 ad .name = "call_LAD"
869868 ad .source = None
870- ad .shape = (* shape , 2 )
871- ad .chunks = (* chunks , 2 )
872- ad .dimensions = (* dimensions , "local_alleles" )
869+ ad .dimensions = (* dimensions , "local_alleles_AD" )
873870 ad .description += " (local-alleles)"
871+ schema_instance .dimensions ["local_alleles_AD" ] = vcz .VcfZarrDimension (size = 2 )
874872
875873 pl = fields_by_name .get ("call_PL" , None )
876874 if pl is not None :
877875 # TODO check if call_LPL is in the list already
878876 pl .name = "call_LPL"
879877 pl .source = None
880- pl .shape = (* shape , 3 )
881- pl .chunks = (* chunks , 3 )
882878 pl .description += " (local-alleles)"
883- pl .dimensions = (* dimensions , "local_" + pl .dimensions [- 1 ])
879+ pl .dimensions = (* dimensions , "local_" + pl .dimensions [- 1 ].split ("_" )[- 1 ])
880+ schema_instance .dimensions ["local_" + pl .dimensions [- 1 ].split ("_" )[- 1 ]] = (
881+ vcz .VcfZarrDimension (size = 3 )
882+ )
883+
884884 return [* fields , la ]
885885
886886
@@ -1042,36 +1042,40 @@ def generate_schema(
10421042 if local_alleles is None :
10431043 local_alleles = False
10441044
1045+ dimensions = {
1046+ "variants" : vcz .VcfZarrDimension (
1047+ size = m , chunk_size = variants_chunk_size or vcz .DEFAULT_VARIANT_CHUNK_SIZE
1048+ ),
1049+ "samples" : vcz .VcfZarrDimension (
1050+ size = n , chunk_size = samples_chunk_size or vcz .DEFAULT_SAMPLE_CHUNK_SIZE
1051+ ),
1052+ # ploidy added conditionally below
1053+ "alleles" : vcz .VcfZarrDimension (
1054+ size = max (self .fields ["ALT" ].vcf_field .summary .max_number + 1 , 2 )
1055+ ),
1056+ "filters" : vcz .VcfZarrDimension (size = self .metadata .num_filters ),
1057+ }
1058+
10451059 schema_instance = vcz .VcfZarrSchema (
10461060 format_version = vcz .ZARR_SCHEMA_FORMAT_VERSION ,
1047- samples_chunk_size = samples_chunk_size ,
1048- variants_chunk_size = variants_chunk_size ,
1061+ dimensions = dimensions ,
10491062 fields = [],
10501063 )
10511064
10521065 logger .info (
10531066 "Generating schema with chunks="
1054- f"{ schema_instance .variants_chunk_size , schema_instance .samples_chunk_size } "
1067+ f"variants={ dimensions ['variants' ].chunk_size } , "
1068+ f"samples={ dimensions ['samples' ].chunk_size } "
10551069 )
10561070
10571071 def spec_from_field (field , array_name = None ):
10581072 return vcz .ZarrArraySpec .from_field (
10591073 field ,
1060- num_samples = n ,
1061- num_variants = m ,
1062- samples_chunk_size = schema_instance .samples_chunk_size ,
1063- variants_chunk_size = schema_instance .variants_chunk_size ,
1074+ schema_instance ,
10641075 array_name = array_name ,
10651076 )
10661077
1067- def fixed_field_spec (
1068- name ,
1069- dtype ,
1070- source = None ,
1071- shape = (m ,),
1072- dimensions = ("variants" ,),
1073- chunks = None ,
1074- ):
1078+ def fixed_field_spec (name , dtype , source = None , dimensions = ("variants" ,)):
10751079 compressor = (
10761080 vcz .DEFAULT_ZARR_COMPRESSOR_BOOL .get_config ()
10771081 if dtype == "bool"
@@ -1081,16 +1085,11 @@ def fixed_field_spec(
10811085 source = source ,
10821086 name = name ,
10831087 dtype = dtype ,
1084- shape = shape ,
10851088 description = "" ,
10861089 dimensions = dimensions ,
1087- chunks = chunks or [schema_instance .variants_chunk_size ],
10881090 compressor = compressor ,
10891091 )
10901092
1091- alt_field = self .fields ["ALT" ]
1092- max_alleles = alt_field .vcf_field .summary .max_number + 1
1093-
10941093 array_specs = [
10951094 fixed_field_spec (
10961095 name = "variant_contig" ,
@@ -1099,16 +1098,12 @@ def fixed_field_spec(
10991098 fixed_field_spec (
11001099 name = "variant_filter" ,
11011100 dtype = "bool" ,
1102- shape = (m , self .metadata .num_filters ),
11031101 dimensions = ["variants" , "filters" ],
1104- chunks = (schema_instance .variants_chunk_size , self .metadata .num_filters ),
11051102 ),
11061103 fixed_field_spec (
11071104 name = "variant_allele" ,
11081105 dtype = "O" ,
1109- shape = (m , max_alleles ),
11101106 dimensions = ["variants" , "alleles" ],
1111- chunks = (schema_instance .variants_chunk_size , max_alleles ),
11121107 ),
11131108 fixed_field_spec (
11141109 name = "variant_id" ,
@@ -1142,32 +1137,23 @@ def fixed_field_spec(
11421137
11431138 if gt_field is not None and n > 0 :
11441139 ploidy = max (gt_field .summary .max_number - 1 , 1 )
1145- shape = [m , n ]
1146- chunks = [
1147- schema_instance .variants_chunk_size ,
1148- schema_instance .samples_chunk_size ,
1149- ]
1150- dimensions = ["variants" , "samples" ]
1140+ # Add ploidy dimension only when needed
1141+ schema_instance .dimensions ["ploidy" ] = vcz .VcfZarrDimension (size = ploidy )
1142+
11511143 array_specs .append (
11521144 vcz .ZarrArraySpec (
11531145 name = "call_genotype_phased" ,
11541146 dtype = "bool" ,
1155- shape = list (shape ),
1156- chunks = list (chunks ),
1157- dimensions = list (dimensions ),
1147+ dimensions = ["variants" , "samples" ],
11581148 description = "" ,
1149+ compressor = vcz .DEFAULT_ZARR_COMPRESSOR_BOOL .get_config (),
11591150 )
11601151 )
1161- shape += [ploidy ]
1162- chunks += [ploidy ]
1163- dimensions += ["ploidy" ]
11641152 array_specs .append (
11651153 vcz .ZarrArraySpec (
11661154 name = "call_genotype" ,
11671155 dtype = gt_field .smallest_dtype (),
1168- shape = list (shape ),
1169- chunks = list (chunks ),
1170- dimensions = list (dimensions ),
1156+ dimensions = ["variants" , "samples" , "ploidy" ],
11711157 description = "" ,
11721158 compressor = vcz .DEFAULT_ZARR_COMPRESSOR_GENOTYPES .get_config (),
11731159 )
@@ -1176,16 +1162,14 @@ def fixed_field_spec(
11761162 vcz .ZarrArraySpec (
11771163 name = "call_genotype_mask" ,
11781164 dtype = "bool" ,
1179- shape = list (shape ),
1180- chunks = list (chunks ),
1181- dimensions = list (dimensions ),
1165+ dimensions = ["variants" , "samples" , "ploidy" ],
11821166 description = "" ,
11831167 compressor = vcz .DEFAULT_ZARR_COMPRESSOR_BOOL .get_config (),
11841168 )
11851169 )
11861170
11871171 if local_alleles :
1188- array_specs = convert_local_allele_field_types (array_specs )
1172+ array_specs = convert_local_allele_field_types (array_specs , schema_instance )
11891173
11901174 schema_instance .fields = array_specs
11911175 return schema_instance
0 commit comments