diff --git a/transform.py b/transform.py index 9479ae3..860a72f 100644 --- a/transform.py +++ b/transform.py @@ -47,12 +47,22 @@ def __call__(self, im_lb): # 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat'] flip_lb = np.array(lb) - flip_lb[lb == 2] = 3 - flip_lb[lb == 3] = 2 - flip_lb[lb == 4] = 5 - flip_lb[lb == 5] = 4 - flip_lb[lb == 7] = 8 - flip_lb[lb == 8] = 7 + # flip_lb[lb == 2] = 3 + # flip_lb[lb == 3] = 2 + # flip_lb[lb == 4] = 5 + # flip_lb[lb == 5] = 4 + # flip_lb[lb == 7] = 8 + # flip_lb[lb == 8] = 7 + + right_idx = [3, 5, 8] + left_idx = [2, 4, 7] + + for i in range(3): + right_pos = np.where(flip_lb == right_idx[i]) + left_pos = np.where(flip_lb == left_idx[i]) + flip_lb[right_pos[0], right_pos[1]] = left_idx[i] + flip_lb[left_pos[0], left_pos[1]] = right_idx[i] + flip_lb = Image.fromarray(flip_lb) return dict(im = im.transpose(Image.FLIP_LEFT_RIGHT), lb = flip_lb.transpose(Image.FLIP_LEFT_RIGHT),