Skip to content

Commit 44faa91

Browse files
authored
VOC data augmentation with TF dataset API (#481)
* tl.iterate.minibatch support list with shuffle * release example for image aug - VOC with dataset API
1 parent 09a739d commit 44faa91

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

docs/modules/prepro.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,10 @@ In practice, you may want to use threading method to process a batch of images a
357357
b_ann[i][0], b_ann[i][1], [], classes, True,
358358
save_name='_bbox_vis_%d.png' % i)
359359
360+
Image Aug with TF Dataset API
361+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
362+
363+
- Example code for VOC `here <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_tf_dataset_voc.py>`__.
360364

361365
Coordinate pixel unit to percentage
362366
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

example/tutorial_tf_dataset_voc.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf8 -*-
3+
4+
# tf import data dataset.map https://www.tensorflow.org/programmers_guide/datasets#applying_arbitrary_python_logic_with_tfpy_func
5+
# tf.py_func https://www.tensorflow.org/api_docs/python/tf/py_func
6+
# tl ref: https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_imagenet_inceptionV3_distributed.py
7+
# cn ref: https://blog.csdn.net/dQCFKyQDXYm3F8rB0/article/details/79342369
8+
# cn ref: https://zhuanlan.zhihu.com/p/31466173
9+
10+
import numpy as np
11+
import multiprocessing, random, json, time
12+
import tensorflow as tf
13+
import tensorlayer as tl
14+
15+
imgs_file_list, _, _, _, classes, _, _,\
16+
_, objs_info_list, _ = tl.files.load_voc_dataset(dataset="2007")
17+
18+
ann_list = []
19+
for info in objs_info_list:
20+
ann = tl.prepro.parse_darknet_ann_str_to_list(info)
21+
c, b = tl.prepro.parse_darknet_ann_list_to_cls_box(ann)
22+
ann_list.append([c, b])
23+
24+
n_epoch = 10
25+
batch_size = 64
26+
im_size = [416, 416]
27+
jitter = 0.2
28+
shuffle_buffer_size = 100
29+
30+
31+
def generator():
32+
inputs = imgs_file_list
33+
targets = objs_info_list
34+
assert len(inputs) == len(targets)
35+
for _input, _target in zip(inputs, targets):
36+
yield _input.encode('utf-8'), _target.encode('utf-8')
37+
38+
39+
def _data_aug_fn(im, ann):
40+
## parse annotation
41+
ann = ann.decode()
42+
ann = tl.prepro.parse_darknet_ann_str_to_list(ann)
43+
clas, coords = tl.prepro.parse_darknet_ann_list_to_cls_box(ann)
44+
## random brightness, contrast and saturation
45+
im = tl.prepro.brightness(im, gamma=0.5, gain=1, is_random=True)
46+
# im = tl.prepro.illumination(im, gamma=(0.5, 1.5),
47+
# contrast=(0.5, 1.5), saturation=(0.5, 1.5), is_random=True) # TypeError: Cannot handle this data type
48+
## random horizontal flip
49+
im, coords = tl.prepro.obj_box_left_right_flip(im, coords, is_rescale=True, is_center=True, is_random=True)
50+
## random resize and crop
51+
tmp0 = random.randint(1, int(im_size[0] * jitter))
52+
tmp1 = random.randint(1, int(im_size[1] * jitter))
53+
im, coords = tl.prepro.obj_box_imresize(im, coords, [im_size[0] + tmp0, im_size[1] + tmp1], is_rescale=True, interp='bicubic')
54+
im, clas, coords = tl.prepro.obj_box_crop(im, clas, coords, wrg=im_size[1], hrg=im_size[0], is_rescale=True, is_center=True, is_random=True)
55+
## value [0, 255] to [-1, 1] (optional)
56+
# im = im / 127.5 - 1
57+
## value [0, 255] to [0, 1] (optional)
58+
im = im / 255
59+
im = np.array(im, dtype=np.float32) # important
60+
return im, str([clas, coords]).encode('utf-8')
61+
62+
63+
def _map_fn(filename, annotation):
64+
## read image
65+
image = tf.read_file(filename)
66+
image = tf.image.decode_jpeg(image, channels=3)
67+
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
68+
## data augmentation
69+
image, annotation = tf.py_func(_data_aug_fn, [image, annotation], [tf.float32, tf.string])
70+
return image, annotation
71+
72+
73+
ds = tf.data.Dataset().from_generator(generator, output_types=(tf.string, tf.string))
74+
ds = ds.map(_map_fn, num_parallel_calls=multiprocessing.cpu_count())
75+
ds = ds.repeat(n_epoch)
76+
ds = ds.shuffle(shuffle_buffer_size)
77+
ds = ds.batch(batch_size)
78+
value = ds.make_one_shot_iterator().get_next()
79+
80+
sess = tf.InteractiveSession()
81+
82+
## get a batch of images (after data augmentation)
83+
_, _ = sess.run(value) # 1st time takes time to compile
84+
st = time.time()
85+
im, annbyte = sess.run(value)
86+
print('took {}s'.format(time.time() - st))
87+
88+
ann = []
89+
for a in annbyte:
90+
a = a.decode()
91+
ann.append(json.loads(a))
92+
93+
## save all images
94+
for i in range(len(im)):
95+
tl.vis.draw_boxes_and_labels_to_image(im[i] * 255, ann[i][0], ann[i][1], [], classes, True, save_name='_bbox_vis_%d.png' % i)

0 commit comments

Comments
 (0)