@@ -829,7 +829,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
829
829
return False
830
830
831
831
832
def convert_local_allele_field_types(fields, schema_instance):
    """
    Rewrite a list of array specs so that it uses local-allele encoding.

    A new ``call_LA`` spec is appended to the result. When present, the
    ``call_AD`` and ``call_PL`` specs are modified in place: they are renamed
    to ``call_LAD`` / ``call_LPL``, their ``source`` is cleared, their
    description gains a " (local-alleles)" suffix and their last dimension is
    replaced by a local-allele dimension. The size of every local-allele
    dimension used is registered in ``schema_instance.dimensions``.

    Returns a new list containing the input fields followed by the
    ``call_LA`` spec.

    Raises KeyError if ``fields`` has no ``call_genotype`` entry, and
    ValueError if the schema's "ploidy" dimension is not of size 2.
    """
    specs_by_name = {}
    for spec in fields:
        specs_by_name[spec.name] = spec
    genotype_spec = specs_by_name["call_genotype"]

    ploidy = schema_instance.get_shape(["ploidy"])[0]
    if ploidy != 2:
        raise ValueError("Local alleles only supported on diploid data")

    # Leading dimensions of call_genotype, i.e. everything but its last axis.
    call_dims = genotype_spec.dimensions[:-1]

    la_description = (
        "0-based indices into REF+ALT, indicating which alleles"
        " are relevant (local) for the current sample"
    )
    la_spec = vcz.ZarrArraySpec(
        name="call_LA",
        dtype="i1",
        dimensions=(*call_dims, "local_alleles"),
        description=la_description,
    )
    schema_instance.dimensions["local_alleles"] = {"size": 1}

    ad_spec = specs_by_name.get("call_AD")
    if ad_spec is not None:
        # TODO check if call_LAD is in the list already
        ad_spec.name = "call_LAD"
        ad_spec.source = None
        ad_spec.dimensions = (*call_dims, "local_alleles_AD")
        ad_spec.description += " (local-alleles)"
        schema_instance.dimensions["local_alleles_AD"] = {"size": 2}

    pl_spec = specs_by_name.get("call_PL")
    if pl_spec is not None:
        # TODO check if call_LPL is in the list already
        pl_spec.name = "call_LPL"
        pl_spec.source = None
        pl_spec.description += " (local-alleles)"
        # The local dimension name is "local_" plus the part of the existing
        # last dimension name after its final underscore. That suffix holds no
        # underscore, so the same string serves as spec dimension and schema key.
        local_pl_dim = "local_" + pl_spec.dimensions[-1].split("_")[-1]
        pl_spec.dimensions = (*call_dims, local_pl_dim)
        schema_instance.dimensions[local_pl_dim] = {"size": 3}

    return [*fields, la_spec]
885
883
886
884
@@ -1042,36 +1040,36 @@ def generate_schema(
1042
1040
if local_alleles is None :
1043
1041
local_alleles = False
1044
1042
1043
+ dimensions = {
1044
+ "variants" : {"size" : m , "chunk_size" : variants_chunk_size or 1000 },
1045
+ "samples" : {"size" : n , "chunk_size" : samples_chunk_size or 10000 },
1046
+ # ploidy added conditionally below
1047
+ "alleles" : {
1048
+ "size" : max (self .fields ["ALT" ].vcf_field .summary .max_number + 1 , 2 )
1049
+ },
1050
+ "filters" : {"size" : self .metadata .num_filters },
1051
+ }
1052
+
1045
1053
schema_instance = vcz .VcfZarrSchema (
1046
1054
format_version = vcz .ZARR_SCHEMA_FORMAT_VERSION ,
1047
- samples_chunk_size = samples_chunk_size ,
1048
- variants_chunk_size = variants_chunk_size ,
1055
+ dimensions = dimensions ,
1049
1056
fields = [],
1050
1057
)
1051
1058
1052
1059
logger .info (
1053
1060
"Generating schema with chunks="
1054
- f"{ schema_instance .variants_chunk_size , schema_instance .samples_chunk_size } "
1061
+ f"variants={ dimensions ['variants' ]['chunk_size' ]} , "
1062
+ f"samples={ dimensions ['samples' ]['chunk_size' ]} "
1055
1063
)
1056
1064
1057
1065
def spec_from_field (field , array_name = None ):
1058
1066
return vcz .ZarrArraySpec .from_field (
1059
1067
field ,
1060
- num_samples = n ,
1061
- num_variants = m ,
1062
- samples_chunk_size = schema_instance .samples_chunk_size ,
1063
- variants_chunk_size = schema_instance .variants_chunk_size ,
1068
+ schema_instance ,
1064
1069
array_name = array_name ,
1065
1070
)
1066
1071
1067
- def fixed_field_spec (
1068
- name ,
1069
- dtype ,
1070
- source = None ,
1071
- shape = (m ,),
1072
- dimensions = ("variants" ,),
1073
- chunks = None ,
1074
- ):
1072
+ def fixed_field_spec (name , dtype , source = None , dimensions = ("variants" ,)):
1075
1073
compressor = (
1076
1074
vcz .DEFAULT_ZARR_COMPRESSOR_BOOL .get_config ()
1077
1075
if dtype == "bool"
@@ -1081,16 +1079,11 @@ def fixed_field_spec(
1081
1079
source = source ,
1082
1080
name = name ,
1083
1081
dtype = dtype ,
1084
- shape = shape ,
1085
1082
description = "" ,
1086
1083
dimensions = dimensions ,
1087
- chunks = chunks or [schema_instance .variants_chunk_size ],
1088
1084
compressor = compressor ,
1089
1085
)
1090
1086
1091
- alt_field = self .fields ["ALT" ]
1092
- max_alleles = alt_field .vcf_field .summary .max_number + 1
1093
-
1094
1087
array_specs = [
1095
1088
fixed_field_spec (
1096
1089
name = "variant_contig" ,
@@ -1099,16 +1092,12 @@ def fixed_field_spec(
1099
1092
fixed_field_spec (
1100
1093
name = "variant_filter" ,
1101
1094
dtype = "bool" ,
1102
- shape = (m , self .metadata .num_filters ),
1103
1095
dimensions = ["variants" , "filters" ],
1104
- chunks = (schema_instance .variants_chunk_size , self .metadata .num_filters ),
1105
1096
),
1106
1097
fixed_field_spec (
1107
1098
name = "variant_allele" ,
1108
1099
dtype = "O" ,
1109
- shape = (m , max_alleles ),
1110
1100
dimensions = ["variants" , "alleles" ],
1111
- chunks = (schema_instance .variants_chunk_size , max_alleles ),
1112
1101
),
1113
1102
fixed_field_spec (
1114
1103
name = "variant_id" ,
@@ -1142,32 +1131,23 @@ def fixed_field_spec(
1142
1131
1143
1132
if gt_field is not None and n > 0 :
1144
1133
ploidy = max (gt_field .summary .max_number - 1 , 1 )
1145
- shape = [m , n ]
1146
- chunks = [
1147
- schema_instance .variants_chunk_size ,
1148
- schema_instance .samples_chunk_size ,
1149
- ]
1150
- dimensions = ["variants" , "samples" ]
1134
+ # Add ploidy dimension only when needed
1135
+ schema_instance .dimensions ["ploidy" ] = {"size" : ploidy }
1136
+
1151
1137
array_specs .append (
1152
1138
vcz .ZarrArraySpec (
1153
1139
name = "call_genotype_phased" ,
1154
1140
dtype = "bool" ,
1155
- shape = list (shape ),
1156
- chunks = list (chunks ),
1157
- dimensions = list (dimensions ),
1141
+ dimensions = ["variants" , "samples" ],
1158
1142
description = "" ,
1143
+ compressor = vcz .DEFAULT_ZARR_COMPRESSOR_BOOL .get_config (),
1159
1144
)
1160
1145
)
1161
- shape += [ploidy ]
1162
- chunks += [ploidy ]
1163
- dimensions += ["ploidy" ]
1164
1146
array_specs .append (
1165
1147
vcz .ZarrArraySpec (
1166
1148
name = "call_genotype" ,
1167
1149
dtype = gt_field .smallest_dtype (),
1168
- shape = list (shape ),
1169
- chunks = list (chunks ),
1170
- dimensions = list (dimensions ),
1150
+ dimensions = ["variants" , "samples" , "ploidy" ],
1171
1151
description = "" ,
1172
1152
compressor = vcz .DEFAULT_ZARR_COMPRESSOR_GENOTYPES .get_config (),
1173
1153
)
@@ -1176,16 +1156,14 @@ def fixed_field_spec(
1176
1156
vcz .ZarrArraySpec (
1177
1157
name = "call_genotype_mask" ,
1178
1158
dtype = "bool" ,
1179
- shape = list (shape ),
1180
- chunks = list (chunks ),
1181
- dimensions = list (dimensions ),
1159
+ dimensions = ["variants" , "samples" , "ploidy" ],
1182
1160
description = "" ,
1183
1161
compressor = vcz .DEFAULT_ZARR_COMPRESSOR_BOOL .get_config (),
1184
1162
)
1185
1163
)
1186
1164
1187
1165
if local_alleles :
1188
- array_specs = convert_local_allele_field_types (array_specs )
1166
+ array_specs = convert_local_allele_field_types (array_specs , schema_instance )
1189
1167
1190
1168
schema_instance .fields = array_specs
1191
1169
return schema_instance
0 commit comments