1+ #!/usr/bin/env python
2+
3+ import os
4+ import re
5+ from future .moves .urllib .request import urlretrieve
6+ import unittest
7+
8+ from PIL import Image
9+ import rospkg
10+ from image_recognition_footwear .model import Model
11+ from image_recognition_footwear .process_data import heroPreprocess , detection_RGB
12+ import torch
13+
14+ @unittest .skip
15+ def test_footwear ():
16+ local_path = "~/data/pytorch_models/footwearModel.pth"
17+
18+ if not os .path .exists (local_path ):
19+ print ("File does not exit {}" .format (local_path ))
20+
21+ def is_there_footwear_from_asset_name (asset_name ):
22+ binary_str = re .search ("(\w+)_shoe" , asset_name ).groups ()
23+ return binary_str == "yes"
24+
25+ assets_path = os .path .join (rospkg .RosPack ().get_path ("image_recognition_footwear" ), 'test/assets' )
26+ images_gt = [(Image .open (os .path .join (assets_path , asset )), is_there_footwear_from_asset_name (asset ))
27+ for asset in os .listdir (assets_path )]
28+
29+ device = torch .device ('cuda' )
30+ model = Model (in_channel = 3 , channel_1 = 128 , channel_2 = 256 , channel_3 = 512 , node_1 = 1024 , node_2 = 1024 , num_classes = 2 )
31+ model .load_state_dict (torch .load (local_path ))
32+ model .to (device = device )
33+ detections = detection_RGB ([image for image , _ in images_gt ], model )
34+
35+ estimations = AgeGenderEstimator (local_path , 64 , 16 , 8 ).estimate ([image for image , _ in images_gt ])
36+
37+ for (_ , (is_footwear_gt )), (binary_detection ) in zip (images_gt , detections ):
38+ binary_detection = int (binary_detection )
39+ assert is_footwear_gt == binary_detection , f"{ binary_detection = } , { is_footwear_gt = } "
40+
41+
42+ if __name__ == "__main__" :
43+ test_footwear ()
0 commit comments