@@ -64,7 +64,8 @@ def register(self, cls: type[T], qualname: str | None = None) -> None:
64
64
__pipeline_registry : Registry [CodecPipeline ] = Registry ()
65
65
__buffer_registry : Registry [Buffer ] = Registry ()
66
66
__ndbuffer_registry : Registry [NDBuffer ] = Registry ()
67
- __chunk_key_encoding_registry : Registry [ChunkKeyEncoding ] = Registry ()
67
+ # Now a dict[str, Registry[ChunkKeyEncoding]]
68
+ __chunk_key_encoding_registries : dict [str , Registry [ChunkKeyEncoding ]] = defaultdict (Registry )
68
69
69
70
# CHANGE: Consider updating docstring
70
71
"""
@@ -105,12 +106,10 @@ def _collect_entrypoints() -> list[Registry[Any]]:
105
106
data_type_registry ._lazy_load_list .extend (entry_points .select (group = "zarr.data_type" ))
106
107
data_type_registry ._lazy_load_list .extend (entry_points .select (group = "zarr" , name = "data_type" ))
107
108
108
- __chunk_key_encoding_registry .lazy_load_list .extend (
109
- entry_points .select (group = "zarr.chunk_key_encoding" )
110
- )
111
- __chunk_key_encoding_registry .lazy_load_list .extend (
112
- entry_points .select (group = "zarr" , name = "chunk_key_encoding" )
113
- )
109
+ for e in entry_points .select (group = "zarr.chunk_key_encoding" ):
110
+ __chunk_key_encoding_registries [e .name ].lazy_load_list .append (e )
111
+ for e in entry_points .select (group = "zarr" , name = "chunk_key_encoding" ):
112
+ __chunk_key_encoding_registries [e .name ].lazy_load_list .append (e )
114
113
115
114
__pipeline_registry .lazy_load_list .extend (entry_points .select (group = "zarr.codec_pipeline" ))
116
115
__pipeline_registry .lazy_load_list .extend (
@@ -127,7 +126,7 @@ def _collect_entrypoints() -> list[Registry[Any]]:
127
126
__pipeline_registry ,
128
127
__buffer_registry ,
129
128
__ndbuffer_registry ,
130
- __chunk_key_encoding_registry ,
129
+ * ( __chunk_key_encoding_registries . values ()) ,
131
130
]
132
131
133
132
@@ -158,8 +157,10 @@ def register_buffer(cls: type[Buffer], qualname: str | None = None) -> None:
158
157
__buffer_registry .register (cls , qualname )
159
158
160
159
161
- def register_chunk_key_encoding (cls : type , qualname : str | None = None ) -> None :
162
- __chunk_key_encoding_registry .register (cls , qualname )
160
+ def register_chunk_key_encoding (
161
+ key : str , cke_cls : type [ChunkKeyEncoding ], qualname : str | None = None
162
+ ) -> None :
163
+ __chunk_key_encoding_registries [key ].register (cke_cls , qualname )
163
164
164
165
165
166
def get_codec_class (key : str , reload_config : bool = False ) -> type [Codec ]:
@@ -299,14 +300,36 @@ def get_ndbuffer_class(reload_config: bool = False) -> type[NDBuffer]:
299
300
)
300
301
301
302
302
- def get_chunk_key_encoding_class (key : str ) -> type [ChunkKeyEncoding ]:
303
- __chunk_key_encoding_registry .lazy_load ()
304
- if key not in __chunk_key_encoding_registry :
303
+ def get_chunk_key_encoding_class (key : str , reload_config : bool = False ) -> type [ChunkKeyEncoding ]:
304
+ if reload_config :
305
+ _reload_config ()
306
+
307
+ if key in __chunk_key_encoding_registries :
308
+ __chunk_key_encoding_registries [key ].lazy_load ()
309
+ else :
305
310
raise KeyError (
306
- f"Chunk key encoding '{ key } ' not found in registered chunk key encodings: { list (__chunk_key_encoding_registry )} ."
311
+ f"Chunk key encoding '{ key } ' not found in registered chunk key encodings: { list (__chunk_key_encoding_registries )} ."
307
312
)
308
313
309
- return __chunk_key_encoding_registry [key ]
314
+ cke_classes = __chunk_key_encoding_registries [key ]
315
+ if not cke_classes :
316
+ raise KeyError (key )
317
+
318
+ config_entry = config .get ("chunk_key_encodings" , {}).get (key )
319
+ if config_entry is None :
320
+ if len (cke_classes ) == 1 :
321
+ return next (iter (cke_classes .values ()))
322
+ warnings .warn (
323
+ f"Chunk key encoding '{ key } ' not configured in config. Selecting any implementation." ,
324
+ stacklevel = 2 ,
325
+ category = ZarrUserWarning ,
326
+ )
327
+ return list (cke_classes .values ())[- 1 ]
328
+ selected_encoding_cls = cke_classes [config_entry ]
329
+
330
+ if selected_encoding_cls :
331
+ return selected_encoding_cls
332
+ raise KeyError (key )
310
333
311
334
312
335
_collect_entrypoints ()
0 commit comments