@@ -640,22 +640,53 @@ def read_timestamp(self, schema: "Schema") -> datetime.datetime:
640640 return self ._value .as_timestamp ()
641641
642642
643- # A registry for on-demand deserialization of types by using a mapping of shape IDs to their deserializers.
644643class TypeRegistry :
644+ """A registry for on-demand deserialization of types by using a mapping of shape IDs
645+ to their deserializers."""
646+
645647 def __init__ (
646648 self ,
647649 types : dict [ShapeID , type [DeserializeableShape ]],
648650 sub_registry : "TypeRegistry | None" = None ,
649651 ):
652+ """Initialize a TypeRegistry.
653+
654+ :param types: A mapping of ShapeID to the shapes they deserialize to.
655+ :param sub_registry: A registry to delegate to if an ID is not found in types.
656+ """
650657 self ._types = types
651658 self ._sub_registry = sub_registry
652659
653660 def get (self , shape : ShapeID ) -> type [DeserializeableShape ]:
661+ """Get the deserializable shape for the given shape ID.
662+
663+ :param shape: The shape ID to get from the registry.
664+ :returns: The corresponding deserializable shape.
665+ :raises KeyError: If the shape ID is not found in the registry.
666+ """
654667 if shape in self ._types :
655668 return self ._types [shape ]
656669 if self ._sub_registry is not None :
657670 return self ._sub_registry .get (shape )
658671 raise KeyError (f"Unknown shape: { shape } " )
659672
673+ def __getitem__ (self , shape : ShapeID ):
674+ """Get the deserializable shape for the given shape ID.
675+
676+ :param shape: The shape ID to get from the registry.
677+ :returns: The corresponding deserializable shape.
678+ :raises KeyError: If the shape ID is not found in the registry.
679+ """
680+ return self .get (shape )
681+
682+ def __contains__ (self , shape : ShapeID ):
683+ """Check if the registry contains the given shape.
684+
685+ :param shape: The shape ID to check for.
686+ """
687+ return shape in self ._types or (
688+ self ._sub_registry is not None and shape in self ._sub_registry
689+ )
690+
660691 def deserialize (self , document : Document ) -> DeserializeableShape :
661692 return document .as_shape (self .get (document .discriminator ))
0 commit comments