1- from typing import Tuple , Union , List , Dict , Generator , Optional
1+ # mypy: allow-untyped-defs
22import json
3+ from typing import Dict , Generator , Iterable , List , Optional , Tuple , Union , NamedTuple
4+
35import numpy as np
46from wkcuber .mag import Mag
57
68Shape3D = Union [List [int ], Tuple [int , int , int ], np .ndarray ]
79
810
11+ class BoundingBoxNamedTuple (NamedTuple ):
12+ topleft : Tuple [int , int , int ]
13+ size : Tuple [int , int , int ]
14+
15+
916class BoundingBox :
1017 def __init__ (self , topleft : Shape3D , size : Shape3D ):
1118
1219 self .topleft = np .array (topleft , dtype = np .int )
1320 self .size = np .array (size , dtype = np .int )
1421
1522 @property
16- def bottomright (self ):
23+ def bottomright (self ) -> np . ndarray :
1724
1825 return self .topleft + self .size
1926
2027 @staticmethod
21- def from_wkw (bbox : Dict ):
28+ def from_wkw (bbox : Dict ) -> "BoundingBox" :
2229 return BoundingBox (
2330 bbox ["topLeft" ], [bbox ["width" ], bbox ["height" ], bbox ["depth" ]]
2431 )
2532
2633 @staticmethod
27- def from_config (bbox : Dict ):
34+ def from_config (bbox : Dict ) -> "BoundingBox" :
2835 return BoundingBox (bbox ["topleft" ], bbox ["size" ])
2936
3037 @staticmethod
31- def from_tuple6 (tuple6 : Tuple ) :
38+ def from_tuple6 (tuple6 : Tuple [ int , int , int , int , int , int ]) -> "BoundingBox" :
3239 return BoundingBox (tuple6 [0 :3 ], tuple6 [3 :6 ])
3340
3441 @staticmethod
35- def from_tuple2 (tuple2 : Tuple ) :
42+ def from_tuple2 (tuple2 : Tuple [ Shape3D , Shape3D ]) -> "BoundingBox" :
3643 return BoundingBox (tuple2 [0 ], tuple2 [1 ])
3744
3845 @staticmethod
39- def from_auto (obj ):
46+ def from_points (points : Iterable [Shape3D ]) -> "BoundingBox" :
47+
48+ all_points = np .array (points )
49+ topleft = all_points .min (axis = 0 )
50+ bottomright = all_points .max (axis = 0 )
51+
52+ # bottomright is exclusive
53+ bottomright += 1
54+
55+ return BoundingBox (topleft , bottomright - topleft )
56+
57+ @staticmethod
58+ def from_named_tuple (bb_named_tuple : BoundingBoxNamedTuple ):
59+
60+ return BoundingBox (bb_named_tuple .topleft , bb_named_tuple .size )
61+
62+ @staticmethod
63+ def from_auto (obj ) -> "BoundingBox" :
4064 if isinstance (obj , BoundingBox ):
4165 return obj
4266 elif isinstance (obj , str ):
4367 return BoundingBox .from_auto (json .loads (obj ))
4468 elif isinstance (obj , dict ):
4569 return BoundingBox .from_wkw (obj )
70+ elif isinstance (obj , BoundingBoxNamedTuple ):
71+ return BoundingBox .from_named_tuple (obj )
4672 elif isinstance (obj , list ) or isinstance (obj , tuple ):
4773 if len (obj ) == 2 :
48- return BoundingBox .from_tuple2 (obj )
74+ return BoundingBox .from_tuple2 (obj ) # type: ignore
4975 elif len (obj ) == 6 :
50- return BoundingBox .from_tuple6 (obj )
76+ return BoundingBox .from_tuple6 (obj ) # type: ignore
5177
5278 raise Exception ("Unknown bounding box format." )
5379
54- def as_tuple2_string (self ):
55- return str ([self .topleft , self .size ])
56-
57- def as_wkw (self ):
80+ def as_wkw (self ) -> dict :
5881
5982 width , height , depth = self .size .tolist ()
6083
@@ -65,29 +88,33 @@ def as_wkw(self):
6588 "depth" : depth ,
6689 }
6790
68- def as_config (self ):
91+ def as_config (self ) -> dict :
6992
7093 return {"topleft" : self .topleft .tolist (), "size" : self .size .tolist ()}
7194
72- def as_checkpoint_name (self ):
95+ def as_checkpoint_name (self ) -> str :
7396
7497 x , y , z = self .topleft
7598 width , height , depth = self .size
7699 return "{x}_{y}_{z}_{width}_{height}_{depth}" .format (
77100 x = x , y = y , z = z , width = width , height = height , depth = depth
78101 )
79102
80- def __repr__ (self ):
103+ def as_tuple6 (self ) -> Tuple [int , int , int , int , int , int ]:
104+
105+ return tuple (self .topleft .tolist () + self .size .tolist ()) # type: ignore
106+
107+ def __repr__ (self ) -> str :
81108
82109 return "BoundingBox(topleft={}, size={})" .format (
83110 str (tuple (self .topleft )), str (tuple (self .size ))
84111 )
85112
86- def __str__ (self ):
113+ def __str__ (self ) -> str :
87114
88115 return self .__repr__ ()
89116
90- def __eq__ (self , other ):
117+ def __eq__ (self , other ) -> bool :
91118
92119 return np .array_equal (self .topleft , other .topleft ) and np .array_equal (
93120 self .size , other .size
@@ -122,11 +149,11 @@ def intersected_with(
122149 if not dont_assert :
123150 assert (
124151 not intersection .is_empty ()
125- ), "No intersection between bounding boxes {} and {}." . format ( self , other )
152+ ), f "No intersection between bounding boxes { self } and { other } ."
126153
127154 return intersection
128155
129- def extended_by (self , other : "BoundingBox" ):
156+ def extended_by (self , other : "BoundingBox" ) -> "BoundingBox" :
130157
131158 topleft = np .minimum (self .topleft , other .topleft )
132159 bottomright = np .maximum (self .bottomright , other .bottomright )
@@ -138,13 +165,18 @@ def is_empty(self) -> bool:
138165
139166 return not all (self .size > 0 )
140167
141- def in_mag (self , mag : Mag ) -> "BoundingBox" :
168+ def in_mag (self , mag : Mag , ceil : bool = False ) -> "BoundingBox" :
142169
143170 np_mag = np .array (mag .to_array ())
144171
172+ def ceil_maybe (array : np .ndarray ) -> np .ndarray :
173+ if ceil :
174+ return np .ceil (array )
175+ return array
176+
145177 return BoundingBox (
146- topleft = (self .topleft / np_mag ).astype (np .int ),
147- size = (self .size / np_mag ).astype (np .int ),
178+ topleft = ceil_maybe (self .topleft / np_mag ).astype (np .int ),
179+ size = ceil_maybe (self .size / np_mag ).astype (np .int ),
148180 )
149181
150182 def contains (self , coord : Shape3D ) -> bool :
@@ -155,6 +187,9 @@ def contains(self, coord: Shape3D) -> bool:
155187 coord < self .topleft + self .size
156188 )
157189
190+ def contains_bbox (self , inner_bbox : "BoundingBox" ) -> bool :
191+ return inner_bbox .intersected_with (self ) == inner_bbox
192+
158193 def chunk (
159194 self , chunk_size : Shape3D , chunk_border_alignments : Optional [List [int ]] = None
160195 ) -> Generator ["BoundingBox" , None , None ]:
@@ -174,7 +209,7 @@ def chunk(
174209 chunk_border_alignments = np .array (chunk_border_alignments )
175210 assert np .all (
176211 chunk_size % chunk_border_alignments == 0
177- ), "{ } not divisible by {}" . format ( chunk_size , chunk_border_alignments )
212+ ), f" { chunk_size } not divisible by { chunk_border_alignments } "
178213
179214 # Move the start to be aligned correctly. This doesn't actually change
180215 # the start of the first chunk, because we'll intersect with `self`,
0 commit comments