#! /usr/bin/python
# -*- coding: utf8 -*-

# tf.data Dataset.map with tf.py_func: https://www.tensorflow.org/programmers_guide/datasets#applying_arbitrary_python_logic_with_tfpy_func
# tf.py_func: https://www.tensorflow.org/api_docs/python/tf/py_func
# tl ref: https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_imagenet_inceptionV3_distributed.py
# Chinese ref: https://blog.csdn.net/dQCFKyQDXYm3F8rB0/article/details/79342369
# Chinese ref: https://zhuanlan.zhihu.com/p/31466173

import numpy as np
import multiprocessing, random, json, time
import tensorflow as tf
import tensorlayer as tl

imgs_file_list, _, _, _, classes, _, _,\
    _, objs_info_list, _ = tl.files.load_voc_dataset(dataset="2007")

ann_list = []
for info in objs_info_list:
    ann = tl.prepro.parse_darknet_ann_str_to_list(info)
    c, b = tl.prepro.parse_darknet_ann_list_to_cls_box(ann)
    ann_list.append([c, b])
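# Each annotation string is in the darknet-style format "class_id x_center y_center w h",
# with coordinates rescaled to [0, 1]. `ann_list` keeps the parsed [classes, boxes] pairs
# for reference; the tf.data pipeline below re-parses the raw strings itself.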

n_epoch = 10
batch_size = 64
im_size = [416, 416]
jitter = 0.2
shuffle_buffer_size = 100
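# `jitter` controls how much each image is randomly enlarged (up to 20% in each dimension)
# before being cropped back to `im_size`; `shuffle_buffer_size` is the number of samples
# held in the shuffle buffer.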


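# tf.data passes strings around as byte tensors, so the file names and annotation strings
# are encoded to bytes here and decoded back inside the augmentation function.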
def generator():
    inputs = imgs_file_list
    targets = objs_info_list
    assert len(inputs) == len(targets)
    for _input, _target in zip(inputs, targets):
        yield _input.encode('utf-8'), _target.encode('utf-8')


def _data_aug_fn(im, ann):
    ## parse annotation
    ann = ann.decode()
    ann = tl.prepro.parse_darknet_ann_str_to_list(ann)
    clas, coords = tl.prepro.parse_darknet_ann_list_to_cls_box(ann)
    ## random brightness (random contrast/saturation via tl.prepro.illumination is disabled below)
    im = tl.prepro.brightness(im, gamma=0.5, gain=1, is_random=True)
    # im = tl.prepro.illumination(im, gamma=(0.5, 1.5),
    #     contrast=(0.5, 1.5), saturation=(0.5, 1.5), is_random=True)  # TypeError: Cannot handle this data type
    ## random horizontal flip
    im, coords = tl.prepro.obj_box_left_right_flip(im, coords, is_rescale=True, is_center=True, is_random=True)
    ## random resize (enlarge by up to `jitter`) and crop back to `im_size`
    tmp0 = random.randint(1, int(im_size[0] * jitter))
    tmp1 = random.randint(1, int(im_size[1] * jitter))
    im, coords = tl.prepro.obj_box_imresize(im, coords, [im_size[0] + tmp0, im_size[1] + tmp1], is_rescale=True, interp='bicubic')
    im, clas, coords = tl.prepro.obj_box_crop(im, clas, coords, wrg=im_size[1], hrg=im_size[0], is_rescale=True, is_center=True, is_random=True)
    ## value [0, 255] to [-1, 1] (optional)
    # im = im / 127.5 - 1
    ## value [0, 255] to [0, 1] (optional)
    im = im / 255
    im = np.array(im, dtype=np.float32)  # important: match the tf.float32 dtype declared in tf.py_func
    return im, str([clas, coords]).encode('utf-8')


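# `tf.py_func` runs `_data_aug_fn` as ordinary Python inside the graph; its output dtypes
# must be declared up front, which is why the class/box lists are serialized into a single
# string and decoded again after `sess.run`. Note that py_func executes in the driver
# process, so parallel map calls still contend for the Python GIL, and the resulting graph
# cannot be serialized without the Python code.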
def _map_fn(filename, annotation):
    ## read image
    image = tf.read_file(filename)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    ## data augmentation
    image, annotation = tf.py_func(_data_aug_fn, [image, annotation], [tf.float32, tf.string])
    return image, annotation


ds = tf.data.Dataset.from_generator(generator, output_types=(tf.string, tf.string))
ds = ds.map(_map_fn, num_parallel_calls=multiprocessing.cpu_count())
ds = ds.repeat(n_epoch)
ds = ds.shuffle(shuffle_buffer_size)
ds = ds.batch(batch_size)
value = ds.make_one_shot_iterator().get_next()
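# The images coming out of `tf.py_func` have no static shape, but batching still works
# because every augmented image is cropped to 416 x 416. If the input pipeline becomes
# the bottleneck, prefetching can overlap preprocessing with consumption, e.g.
# (optional suggestion, not part of the original script):
# ds = ds.prefetch(buffer_size=2)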

sess = tf.InteractiveSession()

## get a batch of images (after data augmentation)
_, _ = sess.run(value)  # the 1st run is slow: it builds the pipeline and fills the shuffle buffer
st = time.time()
im, annbyte = sess.run(value)
print('took {}s'.format(time.time() - st))

ann = []
for a in annbyte:
    a = a.decode()
    ann.append(json.loads(a))
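# `json.loads` works here because `str()` of the nested numeric lists produced in
# `_data_aug_fn` happens to be valid JSON (the lists contain only numbers).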

## draw the augmented boxes on every image in the batch and save it
for i in range(len(im)):
    tl.vis.draw_boxes_and_labels_to_image(im[i] * 255, ann[i][0], ann[i][1], [], classes, True, save_name='_bbox_vis_%d.png' % i)
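# With batch_size = 64 this writes _bbox_vis_0.png ... _bbox_vis_63.png to the current
# directory, with the augmented boxes and class names drawn on each image.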