-
Notifications
You must be signed in to change notification settings - Fork 42
Expand file tree
/
Copy pathA2RL.py
More file actions
59 lines (44 loc) · 2.33 KB
/
A2RL.py
File metadata and controls
59 lines (44 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from __future__ import absolute_import
import pickle
import argparse
import numpy as np
import tensorflow as tf
import skimage.io as io
import network
from actions import command2action, generate_bbox, crop_input
# Dtype shared by every placeholder in the inference graph below.
global_dtype = tf.float32

# Pretrained network weights.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# ever point this at the trusted 'vfn_rl.pkl' that ships with the model.
with open('vfn_rl.pkl', 'rb') as weights_file:
    var_dict = pickle.load(weights_file)

# Batch of 227x227, 3-channel image crops fed to the network.
image_placeholder = tf.placeholder(global_dtype, [None, 227, 227, 3])

# Graph 1: network.vfn_rl without recurrent inputs. auto_cropping evaluates
# this once per batch and later feeds the result back in under the same key
# (TF1 allows feeding a non-placeholder tensor).
global_feature_placeholder = network.vfn_rl(image_placeholder, var_dict)

# Recurrent state inputs (named h / c; presumably LSTM hidden and cell
# state - confirm against network.vfn_rl). Width 1024 must match the zeros
# that auto_cropping initialises them with.
h_placeholder = tf.placeholder(global_dtype, [None, 1024])
c_placeholder = tf.placeholder(global_dtype, [None, 1024])

# Graph 2: one policy step -> (action, next h, next c).
action, h, c = network.vfn_rl(
    image_placeholder,
    var_dict,
    global_feature=global_feature_placeholder,
    h=h_placeholder,
    c=c_placeholder,
)

# Single session reused by every auto_cropping call; never explicitly closed.
sess = tf.Session()
def auto_cropping(origin_image, max_steps=None):
    """Iteratively crop a batch of images with the A2RL policy network.

    Every image starts from the window ``[0, 0, 20, 20]`` (presumably the
    full image in ``command2action``'s ratio units - confirm). Each step runs
    the policy, converts its action into an updated window with
    ``command2action`` and re-crops the image. The loop stops once every
    image in the batch is flagged terminal.

    Args:
        origin_image: sequence of input images, one per batch entry. The CLI
            passes a one-element list holding a float array shifted by -0.5.
        max_steps: optional cap on the number of policy steps. ``None`` (the
            default) keeps the original unbounded behaviour. When set, the
            boxes from the last executed step are returned once the cap is
            reached, even if some images never signalled termination. At
            least one step always runs.

    Returns:
        The boxes produced by ``generate_bbox`` at the final step. The CLI
        unpacks each entry as ``(xmin, ymin, xmax, ymax)``.
    """
    batch_size = len(origin_image)
    terminals = np.zeros(batch_size)
    ratios = np.repeat([[0, 0, 20, 20]], batch_size, axis=0)
    img = crop_input(origin_image, generate_bbox(origin_image, ratios))

    # The global feature is computed once from the initial crop and then
    # held fixed for every subsequent policy step.
    global_feature = sess.run(global_feature_placeholder,
                              feed_dict={image_placeholder: img})

    # Recurrent state starts at zero; width must match h_/c_placeholder.
    h_np = np.zeros([batch_size, 1024])
    c_np = np.zeros([batch_size, 1024])

    steps_taken = 0
    while True:
        action_np, h_np, c_np = sess.run(
            (action, h, c),
            feed_dict={image_placeholder: img,
                       global_feature_placeholder: global_feature,
                       h_placeholder: h_np,
                       c_placeholder: c_np})
        ratios, terminals = command2action(action_np, ratios, terminals)
        bbox = generate_bbox(origin_image, ratios)
        steps_taken += 1

        if np.sum(terminals) == batch_size:
            return bbox
        # Guard against a policy that never emits "terminate" for some image,
        # which would otherwise spin forever.
        if max_steps is not None and steps_taken >= max_steps:
            return bbox

        img = crop_input(origin_image, bbox)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='A2RL: Auto Image Cropping')
parser.add_argument('--image_path', required=True, help='Path for the image to be cropped')
parser.add_argument('--save_path', required=True, help='Path for saving cropped image')
args = parser.parse_args()
im = io.imread(args.image_path).astype(np.float32) / 255
xmin, ymin, xmax, ymax = auto_cropping([im - 0.5])[0]
io.imsave(args.save_path, im[ymin:ymax, xmin:xmax])