Skip to content

Commit e08622d

Browse files
committed
add polygon result
1 parent e4bf4b2 commit e08622d

File tree

2 files changed

+170
-7
lines changed

2 files changed

+170
-7
lines changed

craft_utils.py

Lines changed: 159 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@
88
import cv2
99
import math
1010

11+
""" auxiliary functions """
12+
# unwarp coordinates
13+
def warpCoord(Minv, pt):
    """Map the 2-D point ``pt`` through the matrix ``Minv``.

    The point is extended with a trailing 1, multiplied by ``Minv``, and the
    first two components of the product are divided by the third.

    Returns:
        A NumPy array holding the two resulting coordinates.
    """
    projected = np.matmul(Minv, (pt[0], pt[1], 1))
    # Perspective divide: normalise by the third (homogeneous) component.
    return projected[:2] / projected[2]
16+
""" end of auxiliary functions """
17+
1118

1219
def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
1320
# prepare data
@@ -71,11 +78,161 @@ def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)
7178

7279
return det, labels, mapper
7380

81+
def getPoly_core(boxes, labels, mapper, linkmap):
    """Build one polygon per box, or ``None`` where no polygon is produced.

    For each box the label image is warped into an upright ``w x h`` patch,
    binarised to the label selected by ``mapper[k]``, and scanned column by
    column to find the top/bottom extent of the region. Pivot points are taken
    from the odd-numbered horizontal segments, tilted by the local gradient of
    the segment centres, extended at both ends, and mapped back to the source
    coordinates with the inverse perspective matrix.

    Args:
        boxes: Sequence of 4-point boxes; each one is handed to
            ``cv2.getPerspectiveTransform`` (presumably float32, shape (4, 2)
            -- confirm against the caller).
        labels: Label image that is warped with nearest-neighbour sampling.
        mapper: ``mapper[k]`` is the label value kept for ``boxes[k]``.
        linkmap: Not used by this function; kept for interface compatibility.

    Returns:
        A list with one entry per box: a NumPy array of ``2 * num_cp + 4``
        polygon points, or ``None`` when the box is skipped (too small,
        singular transform, region nearly as tall as the box, or too few
        pivots).
    """
    # configs
    num_cp = 5
    max_len_ratio = 0.7
    expand_ratio = 1.45
    max_r = 2.0
    step_r = 0.2

    polys = []
    for k, box in enumerate(boxes):
        # size filter for small instance
        w = int(np.linalg.norm(box[0] - box[1]) + 1)
        h = int(np.linalg.norm(box[1] - box[2]) + 1)
        if w < 30 or h < 30:
            polys.append(None)
            continue

        # warp image
        tar = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
        M = cv2.getPerspectiveTransform(box, tar)
        word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)
        try:
            Minv = np.linalg.inv(M)
        except np.linalg.LinAlgError:
            # Singular transform: the polygon cannot be mapped back.
            # Only this error is caught so interrupts and real bugs propagate.
            polys.append(None)
            continue

        # binarization for selected label
        cur_label = mapper[k]
        word_label[word_label != cur_label] = 0
        word_label[word_label > 0] = 1

        # --- Polygon generation ---
        # find top/bottom contours: one (x, top, bottom) triple per column
        cp = []
        max_len = -1
        for i in range(w):
            region = np.where(word_label[:, i] != 0)[0]
            if len(region) < 2:
                continue
            cp.append((i, region[0], region[-1]))
            length = region[-1] - region[0] + 1
            if length > max_len:
                max_len = length

        # pass if max_len is similar to h
        if h * max_len_ratio < max_len:
            polys.append(None)
            continue

        # get pivot points with fixed length
        tot_seg = num_cp * 2 + 1
        seg_w = w / tot_seg     # segment width
        pp = [None] * num_cp    # init pivot points
        # Entries are always rebound to fresh lists below, never mutated in
        # place, so sharing one initial [0, 0] object is harmless here.
        cp_section = [[0, 0]] * tot_seg
        seg_height = [0] * num_cp
        seg_num = 0
        num_sec = 0
        prev_h = -1
        for x, sy, ey in cp:
            if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
                # average previous segment
                if num_sec == 0:
                    break
                cp_section[seg_num] = [cp_section[seg_num][0] / num_sec,
                                       cp_section[seg_num][1] / num_sec]
                num_sec = 0

                # reset variables
                seg_num += 1
                prev_h = -1

            # accumulate center points
            cy = (sy + ey) * 0.5
            cur_h = ey - sy + 1
            cp_section[seg_num] = [cp_section[seg_num][0] + x,
                                   cp_section[seg_num][1] + cy]
            num_sec += 1

            if seg_num % 2 == 0:
                continue    # No polygon area

            # Odd segments carry pivots: keep the tallest column of the segment.
            if prev_h < cur_h:
                pivot_idx = (seg_num - 1) // 2
                pp[pivot_idx] = (x, cy)
                seg_height[pivot_idx] = cur_h
                prev_h = cur_h

        # processing last segment
        if num_sec != 0:
            cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]

        # pass if num of pivots is not sufficient or segment width is smaller than character height
        if None in pp or seg_w < np.max(seg_height) * 0.25:
            polys.append(None)
            continue

        # calc median maximum of pivot points
        half_char_h = np.median(seg_height) * expand_ratio / 2

        # calc gradient and apply to make horizontal pivots
        new_pp = []
        for i, (x, cy) in enumerate(pp):
            dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
            dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
            if dx == 0:     # gradient is zero
                new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
                continue
            rad = -math.atan2(dy, dx)
            c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
            new_pp.append([x - s, cy - c, x + s, cy + c])

        # get edge points to cover character heatmaps
        isSppFound, isEppFound = False, False
        grad_s = ((pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0])
                  + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0]))
        grad_e = ((pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0])
                  + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0]))
        for r in np.arange(0.5, max_r, step_r):
            dx = 2 * half_char_h * r
            if not isSppFound:
                line_img = np.zeros(word_label.shape, dtype=np.uint8)
                dy = grad_s * dx
                p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
                # Accept when the edge line no longer touches the region, or
                # when the search range is about to run out.
                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
                    spp = p
                    isSppFound = True
            if not isEppFound:
                line_img = np.zeros(word_label.shape, dtype=np.uint8)
                dy = grad_e * dx
                p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
                    epp = p
                    isEppFound = True
            if isSppFound and isEppFound:
                break

        # pass if boundary of polygon is not found
        if not (isSppFound and isEppFound):
            polys.append(None)
            continue

        # make final polygon: one side left-to-right, then the other side back
        poly = []
        poly.append(warpCoord(Minv, (spp[0], spp[1])))
        for p in new_pp:
            poly.append(warpCoord(Minv, (p[0], p[1])))
        poly.append(warpCoord(Minv, (epp[0], epp[1])))
        poly.append(warpCoord(Minv, (epp[2], epp[3])))
        for p in reversed(new_pp):
            poly.append(warpCoord(Minv, (p[2], p[3])))
        poly.append(warpCoord(Minv, (spp[2], spp[3])))

        # add to final result
        polys.append(np.array(poly))

    return polys
74226

75-
def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text):
227+
def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):
    """Return the detected boxes together with a per-box polygon list.

    The boxes, label map and mapper come from ``getDetBoxes_core``. When
    ``poly`` is true the polygons are produced by ``getPoly_core``; otherwise
    the polygon list holds ``None`` for every box.

    Returns:
        A ``(boxes, polys)`` tuple.
    """
    boxes, labels, mapper = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)

    polys = getPoly_core(boxes, labels, mapper, linkmap) if poly else [None] * len(boxes)

    return boxes, polys
79236

80237
def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2):
81238
if len(polys) > 0:

test.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def str2bool(v):
5050
parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda to train model')
5151
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
5252
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
53+
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
5354
parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
5455
parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images')
5556

@@ -63,7 +64,7 @@ def str2bool(v):
6364
if not os.path.isdir(result_folder):
6465
os.mkdir(result_folder)
6566

66-
def test_net(net, image, text_threshold, link_threshold, low_text, cuda):
67+
def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly):
6768
t0 = time.time()
6869

6970
# resize
@@ -88,8 +89,13 @@ def test_net(net, image, text_threshold, link_threshold, low_text, cuda):
8889
t1 = time.time()
8990

9091
# Post-processing
91-
boxes = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text)
92+
boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)
93+
94+
# coordinate adjustment
9295
boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
96+
polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
97+
for k in range(len(polys)):
98+
if polys[k] is None: polys[k] = boxes[k]
9399

94100
t1 = time.time() - t1
95101

@@ -100,7 +106,7 @@ def test_net(net, image, text_threshold, link_threshold, low_text, cuda):
100106

101107
if args.show_time : print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))
102108

103-
return boxes, ret_score_text
109+
return boxes, polys, ret_score_text
104110

105111

106112

@@ -128,13 +134,13 @@ def test_net(net, image, text_threshold, link_threshold, low_text, cuda):
128134
print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
129135
image = imgproc.loadImage(image_path)
130136

131-
bboxes, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda)
137+
bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly)
132138

133139
# save score text
134140
filename, file_ext = os.path.splitext(os.path.basename(image_path))
135141
mask_file = result_folder + "/res_" + filename + '_mask.jpg'
136142
cv2.imwrite(mask_file, score_text)
137143

138-
file_utils.saveResult(image_path, image[:,:,::-1], bboxes, dirname=result_folder)
144+
file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)
139145

140146
print("elapsed time : {}s".format(time.time() - t))

0 commit comments

Comments
 (0)