Skip to content

Commit c096b69

Browse files
committed
from easyocr
* fix typo * add estimate_num_chars option * return mapper
1 parent e332dd8 commit c096b69

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

craft.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""
1+
"""
22
Copyright (c) 2019-present NAVER Corp.
33
MIT License
44
"""
@@ -54,7 +54,7 @@ def __init__(self, pretrained=False, freeze=False):
5454
init_weights(self.upconv3.modules())
5555
init_weights(self.upconv4.modules())
5656
init_weights(self.conv_cls.modules())
57-
57+
5858
def forward(self, x):
5959
""" Base network """
6060
sources = self.basenet(x)
@@ -78,8 +78,3 @@ def forward(self, x):
7878
y = self.conv_cls(feature)
7979

8080
return y.permute(0,2,3,1), feature
81-
82-
if __name__ == '__main__':
83-
model = CRAFT(pretrained=True).cuda()
84-
output, _ = model(torch.randn(1, 3, 768, 768).cuda())
85-
print(output.shape)

craft_utils.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,17 @@
77
import numpy as np
88
import cv2
99
import math
10+
from scipy.ndimage import label
1011

11-
""" auxilary functions """
12+
""" auxiliary functions """
1213
# unwarp coordinates
1314
def warpCoord(Minv, pt):
1415
out = np.matmul(Minv, (pt[0], pt[1], 1))
1516
return np.array([out[0]/out[2], out[1]/out[2]])
16-
""" end of auxilary functions """
17+
""" end of auxiliary functions """
1718

1819

19-
def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
20+
def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text, estimate_num_chars=False):
2021
# prepare data
2122
linkmap = linkmap.copy()
2223
textmap = textmap.copy()
@@ -42,6 +43,12 @@ def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)
4243
# make segmentation map
4344
segmap = np.zeros(textmap.shape, dtype=np.uint8)
4445
segmap[labels==k] = 255
46+
if estimate_num_chars:
47+
_, character_locs = cv2.threshold((textmap - linkmap) * segmap /255., text_threshold, 1, 0)
48+
_, n_chars = label(character_locs)
49+
mapper.append(n_chars)
50+
else:
51+
mapper.append(k)
4552
segmap[np.logical_and(link_score==1, text_score==0)] = 0 # remove link area
4653
x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
4754
w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
@@ -74,7 +81,6 @@ def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)
7481
box = np.array(box)
7582

7683
det.append(box)
77-
mapper.append(k)
7884

7985
return det, labels, mapper
8086

@@ -160,7 +166,7 @@ def getPoly_core(boxes, labels, mapper, linkmap):
160166
if num_sec != 0:
161167
cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]
162168

163-
# pass if num of pivots is not sufficient or segment widh is smaller than character height
169+
# pass if num of pivots is not sufficient or segment width is smaller than character height
164170
if None in pp or seg_w < np.max(seg_height) * 0.25:
165171
polys.append(None); continue
166172

@@ -224,15 +230,17 @@ def getPoly_core(boxes, labels, mapper, linkmap):
224230

225231
return polys
226232

227-
def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):
228-
boxes, labels, mapper = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)
233+
def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False, estimate_num_chars=False):
234+
if poly and estimate_num_chars:
235+
raise Exception("Estimating the number of characters not currently supported with poly.")
236+
boxes, labels, mapper = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text, estimate_num_chars)
229237

230238
if poly:
231239
polys = getPoly_core(boxes, labels, mapper, linkmap)
232240
else:
233241
polys = [None] * len(boxes)
234242

235-
return boxes, polys
243+
return boxes, polys, mapper
236244

237245
def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2):
238246
if len(polys) > 0:

0 commit comments

Comments
 (0)