|
12 | 12 | import zipfile |
13 | 13 | from . import visualize |
14 | 14 | from . import nlp |
| 15 | +from . import utils |
15 | 16 | import pickle |
16 | 17 | from six.moves import urllib |
17 | 18 | from six.moves import cPickle |
@@ -706,6 +707,217 @@ def load_celebA_dataset(dirpath='data'): |
706 | 707 | data_files[i] = os.path.join(image_path, data_files[i]) |
707 | 708 | return data_files |
708 | 709 |
|
def load_voc_dataset(path='data/VOC', dataset='2012', contain_classes_in_person=False):
    """Pascal VOC 2012 Dataset has 20 objects ``"aeroplane", "bicycle", "bird",
    "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable",
    "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa",
    "train", "tvmonitor"`` and additional 3 classes ``"head", "hand", "foot"``
    for person.

    Parameters
    -----------
    path : string
        The path that the data is downloaded to, defaults is ``data/VOC``.
    dataset : string, 2012 or 2007
        The VOC dataset version.
    contain_classes_in_person : If True, dataset will contains labels of head, hand and foot.

    Returns
    ---------
    imgs_file_list : list of string.
        Full paths of all images.
    imgs_semseg_file_list : list of string.
        Full paths of all maps for semantic segmentation. Note that not all images have this map!
    imgs_insseg_file_list : list of string.
        Full paths of all maps for instance segmentation. Note that not all images have this map!
    imgs_ann_file_list : list of string.
        Full paths of all annotations for bounding box and object class, all images have this annotations.
    classes : list of string.
        Classes in order.
    classes_in_person : list of string.
        Classes in person.
    classes_dict : dictionary.
        Class label to integer.
    n_objs_list : list of integer
        Number of objects in all images in ``imgs_file_list`` in order.
    objs_info_list : list of string.
        Darknet format for the annotation of all images in ``imgs_file_list`` in order. ``[class_id x_centre y_centre width height]`` in ratio format.
    objs_info_dicts : dictionary.
        ``{imgs_file_list : dictionary for annotation}``, the annotation of all images in ``imgs_file_list``,
        format from `TensorFlow/Models/object-detection <https://github.com/tensorflow/models/blob/master/object_detection/create_pascal_tf_record.py>`_.

    References
    -------------
    - `Pascal VOC2012 Website <http://host.robots.ox.ac.uk/pascal/VOC/voc2012/#devkit>`_.
    - `Pascal VOC2007 Website <http://host.robots.ox.ac.uk/pascal/VOC/voc2007/>`_.
    - `TensorFlow/Models/object-detection <https://github.com/zsdonghao/object-detection/blob/master/g3doc/preparing_inputs.md>`_.
    """

    from lxml import etree  # pip install lxml
    import xml.etree.ElementTree as ET

    def _recursive_parse_xml_to_dict(xml):
        """Recursively parses XML contents to python dict.

        We assume that `object` tags are the only ones that can appear
        multiple times at the same level of a tree.

        Args:
            xml: xml tree obtained by parsing XML file contents using lxml.etree

        Returns:
            Python dictionary holding XML contents.
        """
        # A leaf element (no children) maps its tag to its text content.
        # Explicit len() check avoids relying on deprecated Element truthiness.
        if len(xml) == 0:
            return {xml.tag: xml.text}
        result = {}
        for child in xml:
            child_result = _recursive_parse_xml_to_dict(child)
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                # 'object' may repeat at the same level; collect into a list.
                result.setdefault(child.tag, []).append(child_result[child.tag])
        return {xml.tag: result}

    ##======== select dataset version
    if dataset == "2012":
        url = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/"
        tar_filename = "VOCtrainval_11-May-2012.tar"
        extracted_filename = "VOC2012"  # "VOCdevkit/VOC2012"
        print(" [============= VOC 2012 =============]")
    elif dataset == "2007":
        url = "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/"
        tar_filename = "VOCtrainval_06-Nov-2007.tar"
        extracted_filename = "VOC2007"
        print(" [============= VOC 2007 =============]")
    else:
        raise Exception("Please set the dataset arg to either 2012 or 2007.")

    ##======== download dataset if it is not present yet
    if folder_exists(path + "/" + extracted_filename) is False:
        print("[VOC] {} is nonexistent in {}".format(extracted_filename, path))
        maybe_download_and_extract(tar_filename, path, url, extract=True)
        del_file(path + '/' + tar_filename)
        # The tarball extracts to <path>/VOCdevkit/VOC20xx; flatten it to
        # <path>/VOC20xx so the folder_exists() check above finds it next time.
        os.system("mv {}/VOCdevkit/{} {}/{}".format(path, extracted_filename, path, extracted_filename))
        del_folder(path + '/VOCdevkit')

    ##======== object classes(labels) NOTE: YOU CAN CUSTOMIZE THIS LIST
    classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car",
               "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike",
               "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    if contain_classes_in_person:
        classes_in_person = ["head", "hand", "foot"]
    else:
        classes_in_person = []

    classes += classes_in_person  # use extra 3 classes for person

    classes_dict = utils.list_string_to_dict(classes)
    print("[VOC] object classes {}".format(classes_dict))

    def _file_id(s):
        # Numeric sort key: "2007_000027.jpg" --> 2007000027
        return int(s.replace('.', ' ').replace('_', '').split(' ')[-2])

    ##======== 1. image path list
    folder_imgs = path + "/" + extracted_filename + "/JPEGImages/"
    imgs_file_list = load_file_list(path=folder_imgs, regx='\\.jpg', printable=False)
    print("[VOC] {} images found".format(len(imgs_file_list)))
    imgs_file_list.sort(key=_file_id)
    imgs_file_list = [folder_imgs + s for s in imgs_file_list]
    # print('IM',imgs_file_list[0::3333], imgs_file_list[-1])

    ##======== 2. semantic segmentation maps path list
    folder_semseg = path + "/" + extracted_filename + "/SegmentationClass/"
    imgs_semseg_file_list = load_file_list(path=folder_semseg, regx='\\.png', printable=False)
    print("[VOC] {} maps for semantic segmentation found".format(len(imgs_semseg_file_list)))
    imgs_semseg_file_list.sort(key=_file_id)
    imgs_semseg_file_list = [folder_semseg + s for s in imgs_semseg_file_list]
    # print('Semantic Seg IM',imgs_semseg_file_list[0::333], imgs_semseg_file_list[-1])

    ##======== 3. instance segmentation maps path list
    folder_insseg = path + "/" + extracted_filename + "/SegmentationObject/"
    imgs_insseg_file_list = load_file_list(path=folder_insseg, regx='\\.png', printable=False)
    # FIX: report the instance list itself (was len(imgs_semseg_file_list))
    print("[VOC] {} maps for instance segmentation found".format(len(imgs_insseg_file_list)))
    imgs_insseg_file_list.sort(key=_file_id)
    # FIX: prefix with folder_insseg (was folder_semseg, which produced paths
    # pointing at the semantic-segmentation directory)
    imgs_insseg_file_list = [folder_insseg + s for s in imgs_insseg_file_list]
    # print('Instance Seg IM',imgs_insseg_file_list[0::333], imgs_insseg_file_list[-1])

    ##======== 4. annotations for bounding box and object class
    folder_ann = path + "/" + extracted_filename + "/Annotations/"
    imgs_ann_file_list = load_file_list(path=folder_ann, regx='\\.xml', printable=False)
    print("[VOC] {} XML annotation files for bounding box and object class found".format(len(imgs_ann_file_list)))
    imgs_ann_file_list.sort(key=_file_id)
    imgs_ann_file_list = [folder_ann + s for s in imgs_ann_file_list]
    # print('ANN',imgs_ann_file_list[0::3333], imgs_ann_file_list[-1])

    ##======== parse XML annotations
    def convert(size, box):
        """Convert a VOC box ``(xmin, xmax, ymin, ymax)`` in pixels to
        Darknet ``(x_centre, y_centre, width, height)`` in [0, 1] ratios of
        the image ``size = (width, height)``."""
        dw = 1. / size[0]
        dh = 1. / size[1]
        x = (box[0] + box[1]) / 2.0
        y = (box[2] + box[3]) / 2.0
        w = box[1] - box[0]
        h = box[3] - box[2]
        return (x * dw, y * dh, w * dw, h * dh)

    def convert_annotation(file_name):
        """Given a VOC XML annotation file, return the number of (non-difficult)
        objects and their Darknet-format description string."""
        # ET.parse opens and closes the file itself; no handle to leak.
        tree = ET.parse(file_name)
        root = tree.getroot()
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)
        out_file = ""
        n_objs = 0

        for obj in root.iter('object'):
            difficult = obj.find('difficult').text
            cls = obj.find('name').text
            # Skip classes we don't track and objects flagged "difficult".
            if cls not in classes or int(difficult) == 1:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
                 float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
            bb = convert((w, h), b)
            out_file += str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n'
            n_objs += 1
            # FIX: exact comparison (was `cls in "person"`, a substring test)
            if cls == "person":
                for part in obj.iter('part'):
                    cls = part.find('name').text
                    if cls not in classes_in_person:
                        continue
                    cls_id = classes.index(cls)
                    xmlbox = part.find('bndbox')
                    b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
                         float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
                    bb = convert((w, h), b)
                    out_file += str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n'
                    n_objs += 1
        return n_objs, out_file

    print("[VOC] Parsing xml annotations files")
    n_objs_list = []
    objs_info_list = []  # Darknet Format list of string
    objs_info_dicts = {}
    for idx, ann_file in enumerate(imgs_ann_file_list):
        n_objs, objs_info = convert_annotation(ann_file)
        n_objs_list.append(n_objs)
        objs_info_list.append(objs_info)
        # tf is imported at module level in this file (TensorLayer) — TODO confirm
        with tf.gfile.GFile(ann_file, 'r') as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = _recursive_parse_xml_to_dict(xml)['annotation']
        objs_info_dicts.update({imgs_file_list[idx]: data})

    return imgs_file_list, imgs_semseg_file_list, imgs_insseg_file_list, imgs_ann_file_list, \
           classes, classes_in_person, classes_dict, \
           n_objs_list, objs_info_list, objs_info_dicts
| 920 | + |
709 | 921 | ## Load and save network list npz |
710 | 922 | def save_npz(save_list=[], name='model.npz', sess=None): |
711 | 923 | """Input parameters and the file name, save parameters into .npz file. Use tl.utils.load_npz() to restore. |
|
0 commit comments