Skip to content

Commit 13c9cf8

Browse files
committed
TinyMS v0.3.0 adapts for MindSpore 1.5.0
1 parent 206ed0e commit 13c9cf8

File tree

8 files changed

+62
-53
lines changed

8 files changed

+62
-53
lines changed

docs/en/source/design/concepts.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ input = image_np.copy()
265265
# 4.Detect the input image
266266
detection_bbox_data = object_detection_predict(input, detector, is_training=False)
267267

268-
# 5.Draw the box for the input image and visualize in the opencv window using OpenCV.
268+
# 5.Draw the box for the input image and and view it using OpenCV.
269269
detection_image_np = visualize_boxes_on_image(image_np, detection_bbox_data, box_color=(0, 255, 0),
270270
box_thickness=3, text_font=cv2.FONT_HERSHEY_PLAIN,
271271
font_scale=3, text_color=(0, 0, 255), font_size=3, show_scores=True)
@@ -298,7 +298,7 @@ while True:
298298
# 4.Detect the input frame image
299299
detection_bbox_data = object_detection_predict(input, detector, is_training=False)
300300

301-
# 5.Draw the box for the input frame image and visualize in the opencv window using OpenCV.
301+
# 5.Draw the box for the input frame image and view it using OpenCV.
302302
detection_image_np = visualize_boxes_on_image(image_np, detection_bbox_data, box_color=(0, 255, 0),
303303
box_thickness=3, text_font=cv2.FONT_HERSHEY_PLAIN,
304304
font_scale=3, text_color=(0, 0, 255), font_size=3, show_scores=True)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _write_version(file):
4747
'scipy>=1.5.2,<1.8.0',
4848
'matplotlib>=3.1.3',
4949
'Pillow>=6.2.0',
50-
'mindspore==1.3.0',
50+
'mindspore==1.5.0',
5151
'requests>=2.22.0',
5252
'flask>=1.1.1',
5353
'python-Levenshtein>=0.10.2',

tests/st/app/object_detection/opencv_camera_app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@
2929

3030
cap = cv2.VideoCapture(0)
3131
while True:
32-
# 3.Read the frame image from the camera
32+
# 3.Read the frame image from the camera using OpenCV
3333
ret, image_np = cap.read()
3434
input = image_np.copy()
3535

3636
# 4.Detect the input frame image
3737
detection_bbox_data = object_detection_predict(input, detector, is_training=False)
3838

39-
# 5.Draw the box for the input frame image and visualize in the opencv window.
39+
# 5.Draw the box for the input frame image and view it using OpenCV.
4040
detection_image_np = visualize_boxes_on_image(image_np, detection_bbox_data, box_color=(0, 255, 0),
4141
box_thickness=3, text_font=cv2.FONT_HERSHEY_PLAIN,
4242
font_scale=3, text_color=(0, 0, 255), font_size=3, show_scores=True)

tests/st/app/object_detection/opencv_image_app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,14 @@ def parse_args():
4343
# 2.Generate the instance of ObjectDetector
4444
detector = ObjectDetector(config=config)
4545

46-
# 3.Read the input image
46+
# 3.Read the input image using OpenCV
4747
image_np = cv2.imread(args_opt.img_path)
4848
input = image_np.copy()
4949

5050
# 4.Detect the input image
5151
detection_bbox_data = object_detection_predict(input, detector, is_training=False)
5252

53-
# 5.Draw the box for the input image and visualize in the opencv window.
53+
# 5.Draw the box for the input image and view it using OpenCV.
5454
detection_image_np = visualize_boxes_on_image(image_np, detection_bbox_data, box_color=(0, 255, 0),
5555
box_thickness=3, text_font=cv2.FONT_HERSHEY_PLAIN,
5656
font_scale=3, text_color=(0, 0, 255), font_size=3, show_scores=True)

tinyms/app/__init__.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,19 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ============================================================================
15-
15+
"""
16+
This module is to support vision visualization with opencv, which can help
17+
developers use pre-trained models to predict and show the reasoning image fast.
18+
Current it only supports object detection model.
19+
"""
1620
from . import object_detection
21+
from .object_detection.object_detector import object_detection_predict, ObjectDetector
22+
from .object_detection.utils.view_util import visualize_boxes_on_image, draw_boxes_on_image, save_image
23+
from .object_detection.utils.config_util import load_and_parse_config
24+
25+
26+
object_detection_utils = ['visualize_boxes_on_image', 'draw_boxes_on_image', 'save_image', 'load_and_parse_config']
1727

18-
__all__ = []
28+
__all__ = ['ObjectDetector', 'object_detection_predict']
29+
__all__.extend(object_detection_utils)
1930
__all__.extend(object_detection.__all__)

tinyms/app/object_detection/object_detector.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,18 @@
3333

3434

3535
class ObjectDetector():
36-
r'''
36+
r"""
3737
ObjectDetector is a high-level class defined for building model,preproceing the input image,
3838
predicting and postprocessing the prediction output data.
3939
4040
Args:
4141
config (dict): model config parsed from the json file under the app/object_detection/configs dir.
42-
'''
42+
"""
4343
def __init__(self, config=None):
4444
self.config = config
4545

4646
def data_preprocess(self, input):
47-
r'''
47+
r"""
4848
Preprocess the input image.
4949
5050
Args:
@@ -53,7 +53,7 @@ def data_preprocess(self, input):
5353
Returns:
5454
list, the preprocess image shape.
5555
numpy.ndarray, the preprocess image result.
56-
'''
56+
"""
5757
if not isinstance(input, np.ndarray):
5858
err_msg = 'The input type should be numpy.ndarray, got {}.'.format(type(input))
5959
raise TypeError(err_msg)
@@ -69,31 +69,31 @@ def data_preprocess(self, input):
6969
return image_shape, transform_input
7070

7171
def convert2tensor(self, transform_input):
72-
r'''
72+
r"""
7373
Convert the numpy data to the tensor format.
7474
7575
Args:
76-
transform_input (numpy.ndarray): the preprocessed image.
76+
transform_input (numpy.ndarray): the preprocessing image.
7777
7878
Returns:
7979
Tensor, the converted image.
80-
'''
80+
"""
8181
if not isinstance(transform_input, np.ndarray):
8282
err_msg = 'The transform_input type should be numpy.ndarray, got {}.'.format(type(transform_input))
8383
raise TypeError(err_msg)
8484
input_tensor = ts.expand_dims(ts.array(list(transform_input)), 0)
8585
return input_tensor
8686

8787
def model_build(self, is_training=False):
88-
r'''
88+
r"""
8989
Build the object detection model to predict the image.
9090
9191
Args:
9292
is_training (bool): default: False.
9393
9494
Returns:
9595
model.Model, generated object detection model.
96-
'''
96+
"""
9797
model_net = model_checker.get(self.config.get('model_net'))
9898
if not model_net:
9999
err_msg = 'Currently model_net only supports {}!'.format(str(list(model_checker.keys())))
@@ -109,17 +109,17 @@ def model_build(self, is_training=False):
109109
return serve_model
110110

111111
def model_load_and_predict(self, serve_model, input_tensor):
112-
r'''
112+
r"""
113113
Load the object detection model to predict the image.
114114
115115
Args:
116116
serve_model (model.Model): object detection model.
117-
input_tensor(Tensor): the converted input image
117+
input_tensor (Tensor): the converted input image.
118118
119119
Returns:
120120
model.Model, object detection model loaded the checkpoint file.
121121
list, predictions output result.
122-
'''
122+
"""
123123
ckpt_path = self.config.get('checkpoint_path')
124124
if not ckpt_path:
125125
err_msg = 'The ckpt_path {} can not be none.'.format(ckpt_path)
@@ -139,16 +139,16 @@ def model_load_and_predict(self, serve_model, input_tensor):
139139
return serve_model, predictions_output
140140

141141
def data_postprocess(self, predictions_output, image_shape):
142-
r'''
142+
r"""
143143
Postprocessing the predictions output data.
144144
145145
Args:
146146
predictions_output (list): predictions output data.
147-
image_shape(list): the shapr of the input image.
147+
image_shape (list): the shape of the input image.
148148
149149
Returns:
150-
dict, the postprocess result.
151-
'''
150+
dict, the postprocessing result.
151+
"""
152152
output_np = (ts.concatenate((predictions_output[0], predictions_output[1]), axis=-1).asnumpy())
153153
transform_func = transform_checker.get(self.config.get('dataset'))
154154
if not transform_func:
@@ -158,17 +158,17 @@ def data_postprocess(self, predictions_output, image_shape):
158158

159159

160160
def object_detection_predict(input, object_detector, is_training=False):
161-
r'''
161+
r"""
162162
An easy object detection model predicting method for beginning developers to use.
163163
164164
Args:
165165
input (numpy.ndarray): the input image.
166-
object_detector (ObjectDetector): the instance of the ObjectDetector class
166+
object_detector (ObjectDetector): the instance of the ObjectDetector class.
167167
is_training (bool): default: False.
168168
169169
Returns:
170-
dict, the postprocess result.
171-
'''
170+
dict, the postprocessing result.
171+
"""
172172
if not isinstance(object_detector, ObjectDetector):
173173
err_msg = 'The object_detector is not the instance of ObjectDetector'
174174
raise TypeError(err_msg)

tinyms/app/object_detection/utils/config_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,15 @@ def _download_ckeckpoint(checkpoint_url, sha256, checkpoint_path):
3737

3838

3939
def load_and_parse_config(config_path):
40-
r'''
40+
r"""
4141
Load and parse the json file the object detection model.
4242
4343
Args:
4444
config_path (numpy.ndarray): the config json file path.
4545
4646
Returns:
4747
dict, the model configuration.
48-
'''
48+
"""
4949
# Check if config_path existed
5050
if not os.path.exists(config_path):
5151
raise FileNotFoundError("The config file path {} does not exist!".format(config_path))

tinyms/app/object_detection/utils/view_util.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@
2121

2222

2323
def save_image(img, save_dir='./', img_name='no_name', img_format='jpg'):
24-
r'''
24+
r"""
2525
Save the prediction image.
2626
2727
Args:
2828
img (numpy.ndarray): the input image.
2929
save_dir (str): the dir to save the prediction image.
30-
img_name (str): the name of the prediction image.
31-
img_format (str): the format of the prediction image.
32-
'''
30+
img_name (str): the name of the prediction image. Default: 'no_name'.
31+
img_format (str): the format of the prediction image. Default: 'jpg'.
32+
"""
3333
if img_format.lower() not in IMG_FORMAT:
3434
raise Exception("当前图片格式仅支持", IMG_FORMAT)
3535
output_image = os.path.join(save_dir, '{}.{}'.format(img_name, img_format))
@@ -39,25 +39,23 @@ def save_image(img, save_dir='./', img_name='no_name', img_format='jpg'):
3939
def draw_boxes_on_image(img, boxes, box_scores, box_classes, box_color=(0, 255, 0),
4040
box_thickness=3, text_font=cv2.FONT_HERSHEY_PLAIN,
4141
font_scale=3, text_color=(0, 0, 255), font_size=3, show_scores=True):
42-
r'''
42+
r"""
4343
Draw the prediction box for the input image.
4444
4545
Args:
4646
img (numpy.ndarray): the input image.
4747
boxes (list): the box coordinates.
48-
box_scores (int): the prediction score.
49-
box_classes: the prediction category.
50-
box_color (list): the box color.
51-
box_thickness (int): box thickness.
52-
text_font (Enum): text font.
53-
font_scale (int): font scale.
54-
text_color (list): text color.
55-
font_size (int): font size.
48+
box_color (list): the box color. Default: (0, 255, 0).
49+
box_thickness (int): box thickness. Default: 3.
50+
text_font (Enum): text font. Default: cv2.FONT_HERSHEY_PLAIN.
51+
font_scale (int): font scale. Default: 3.
52+
text_color (list): text color. Default: (0, 0, 255).
53+
font_size (int): font size. Default: 3.
5654
show_scores (bool): whether to show scores. Default: True.
5755
5856
Returns:
5957
numpy.ndarray, the output image drawed the prediction box.
60-
'''
58+
"""
6159
x = int(boxes[0])
6260
y = int(boxes[1])
6361
w = int(boxes[2])
@@ -71,23 +69,23 @@ def draw_boxes_on_image(img, boxes, box_scores, box_classes, box_color=(0, 255,
7169
def visualize_boxes_on_image(img, bbox_data, box_color=(0, 255, 0),
7270
box_thickness=3, text_font=cv2.FONT_HERSHEY_PLAIN,
7371
font_scale=3, text_color=(0, 0, 255), font_size=3, show_scores=True):
74-
r'''
72+
r"""
7573
Visualize the prediction image.
7674
7775
Args:
7876
img (numpy.ndarray): the input image.
7977
bbox_data (dict): the predictions box data.
80-
box_color (list): the box color.
81-
box_thickness (int): box thickness.
82-
text_font (Enum): text font.
83-
font_scale (int): font scale.
84-
text_color (list): text color.
85-
font_size (int): font size.
78+
box_color (list): the box color. Default: (0, 255, 0).
79+
box_thickness (int): box thickness. Default: 3.
80+
text_font (Enum): text font. Default: cv2.FONT_HERSHEY_PLAIN.
81+
font_scale (int): font scale. Default: 3.
82+
text_color (list): text color. Default: (0, 0, 255).
83+
font_size (int): font size. Default: 3.
8684
show_scores (bool): whether to show scores. Default: True.
8785
8886
Returns:
8987
numpy.ndarray, the output image drawed the prediction box.
90-
'''
88+
"""
9189
bbox_num = len(bbox_data)
9290
if bbox_num:
9391
for i in range(bbox_num):

0 commit comments

Comments
 (0)