3333
3434
3535class ObjectDetector ():
36- r'''
36+ r"""
3737 ObjectDetector is a high-level class defined for building model,preproceing the input image,
3838 predicting and postprocessing the prediction output data.
3939
4040 Args:
4141 config (dict): model config parsed from the json file under the app/object_detection/configs dir.
42- '''
42+ """
4343 def __init__ (self , config = None ):
4444 self .config = config
4545
4646 def data_preprocess (self , input ):
47- r'''
47+ r"""
4848 Preprocess the input image.
4949
5050 Args:
@@ -53,7 +53,7 @@ def data_preprocess(self, input):
5353 Returns:
5454 list, the preprocess image shape.
5555 numpy.ndarray, the preprocess image result.
56- '''
56+ """
5757 if not isinstance (input , np .ndarray ):
5858 err_msg = 'The input type should be numpy.ndarray, got {}.' .format (type (input ))
5959 raise TypeError (err_msg )
@@ -69,31 +69,31 @@ def data_preprocess(self, input):
6969 return image_shape , transform_input
7070
7171 def convert2tensor (self , transform_input ):
72- r'''
72+ r"""
7373 Convert the numpy data to the tensor format.
7474
7575 Args:
76- transform_input (numpy.ndarray): the preprocessed image.
76+ transform_input (numpy.ndarray): the preprocessing image.
7777
7878 Returns:
7979 Tensor, the converted image.
80- '''
80+ """
8181 if not isinstance (transform_input , np .ndarray ):
8282 err_msg = 'The transform_input type should be numpy.ndarray, got {}.' .format (type (transform_input ))
8383 raise TypeError (err_msg )
8484 input_tensor = ts .expand_dims (ts .array (list (transform_input )), 0 )
8585 return input_tensor
8686
8787 def model_build (self , is_training = False ):
88- r'''
88+ r"""
8989 Build the object detection model to predict the image.
9090
9191 Args:
9292 is_training (bool): default: False.
9393
9494 Returns:
9595 model.Model, generated object detection model.
96- '''
96+ """
9797 model_net = model_checker .get (self .config .get ('model_net' ))
9898 if not model_net :
9999 err_msg = 'Currently model_net only supports {}!' .format (str (list (model_checker .keys ())))
@@ -109,17 +109,17 @@ def model_build(self, is_training=False):
109109 return serve_model
110110
111111 def model_load_and_predict (self , serve_model , input_tensor ):
112- r'''
112+ r"""
113113 Load the object detection model to predict the image.
114114
115115 Args:
116116 serve_model (model.Model): object detection model.
117- input_tensor(Tensor): the converted input image
117+ input_tensor (Tensor): the converted input image.
118118
119119 Returns:
120120 model.Model, object detection model loaded the checkpoint file.
121121 list, predictions output result.
122- '''
122+ """
123123 ckpt_path = self .config .get ('checkpoint_path' )
124124 if not ckpt_path :
125125 err_msg = 'The ckpt_path {} can not be none.' .format (ckpt_path )
@@ -139,16 +139,16 @@ def model_load_and_predict(self, serve_model, input_tensor):
139139 return serve_model , predictions_output
140140
141141 def data_postprocess (self , predictions_output , image_shape ):
142- r'''
142+ r"""
143143 Postprocessing the predictions output data.
144144
145145 Args:
146146 predictions_output (list): predictions output data.
147- image_shape(list): the shapr of the input image.
147+ image_shape (list): the shape of the input image.
148148
149149 Returns:
150- dict, the postprocess result.
151- '''
150+ dict, the postprocessing result.
151+ """
152152 output_np = (ts .concatenate ((predictions_output [0 ], predictions_output [1 ]), axis = - 1 ).asnumpy ())
153153 transform_func = transform_checker .get (self .config .get ('dataset' ))
154154 if not transform_func :
@@ -158,17 +158,17 @@ def data_postprocess(self, predictions_output, image_shape):
158158
159159
160160def object_detection_predict (input , object_detector , is_training = False ):
161- r'''
161+ r"""
162162 An easy object detection model predicting method for beginning developers to use.
163163
164164 Args:
165165 input (numpy.ndarray): the input image.
166- object_detector (ObjectDetector): the instance of the ObjectDetector class
166+ object_detector (ObjectDetector): the instance of the ObjectDetector class.
167167 is_training (bool): default: False.
168168
169169 Returns:
170- dict, the postprocess result.
171- '''
170+ dict, the postprocessing result.
171+ """
172172 if not isinstance (object_detector , ObjectDetector ):
173173 err_msg = 'The object_detector is not the instance of ObjectDetector'
174174 raise TypeError (err_msg )
0 commit comments