Commit d7a59c6

Merge pull request #482 from tensorlayer/mpii-file
Load and visualise MPII dataset in 1 line of code
2 parents 44faa91 + f0992a3 commit d7a59c6

File tree: 5 files changed, +344 -0 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ TensorLayer is a deep learning and reinforcement learning library on top of [Ten
- Useful links: [Documentation](http://tensorlayer.readthedocs.io), [Examples](http://tensorlayer.readthedocs.io/en/latest/user/example.html), [中文文档](https://tensorlayercn.readthedocs.io), [中文书](http://www.broadview.com.cn/book/5059)

# News
* [10 Apr] Load and visualize MPII dataset in one line of code.
* [05 Apr] Release [models APIs](http://tensorlayer.readthedocs.io/en/latest/modules/models.html#) for well-known pretrained networks.
* [18 Mar] Release experimental APIs for binary networks.
* [18 Jan] [《深度学习:一起玩转TensorLayer》](http://www.broadview.com.cn/book/5059) (Deep Learning using TensorLayer)
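For reference, the advertised one-liner is the call added in this commit's docstring example (a minimal sketch; it assumes the MPII archives can be downloaded to the default ./data folder):

>>> import tensorlayer as tl
>>> img_train_list, ann_train_list, img_test_list, ann_test_list = tl.files.load_mpii_pose_dataset()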

docs/modules/files.rst

Lines changed: 5 additions & 0 deletions
@@ -20,6 +20,7 @@ API - Files
load_cyclegan_dataset
load_celebA_dataset
load_voc_dataset
load_mpii_pose_dataset
download_file_from_google_drive

save_npz
@@ -108,6 +109,10 @@ VOC 2007/2012
^^^^^^^^^^^^^^^^
.. autofunction:: load_voc_dataset

MPII
^^^^^^^^^^^^^^^^
.. autofunction:: load_mpii_pose_dataset

Google Drive
^^^^^^^^^^^^^^^^
.. autofunction:: download_file_from_google_drive

docs/modules/visualize.rst

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,7 @@ to visualize the model, activations etc. Here we provide more functions for data
save_image
save_images
draw_boxes_and_labels_to_image
draw_mpii_people_to_image
draw_weights
CNN2d
frame
@@ -44,6 +45,9 @@ Save image for object detection
.. autofunction:: draw_boxes_and_labels_to_image


Save image for pose estimation (MPII)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: draw_mpii_people_to_image

Visualize model parameters
------------------------------
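For context, a minimal usage sketch of the new drawing helper, mirroring the load_mpii_pose_dataset docstring example added below (the output filename 'image.png' is an arbitrary choice):

>>> import tensorlayer as tl
>>> img_train_list, ann_train_list, _, _ = tl.files.load_mpii_pose_dataset()
>>> image = tl.vis.read_image(img_train_list[0])
>>> tl.vis.draw_mpii_people_to_image(image, ann_train_list[0], 'image.png')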

tensorlayer/files.py

Lines changed: 227 additions & 0 deletions
@@ -65,6 +65,7 @@
    'download_file_from_google_drive',
    'load_celebA_dataset',
    'load_voc_dataset',
    'load_mpii_pose_dataset',
    'save_npz',
    'load_npz',
    'assign_params',
@@ -1317,6 +1318,232 @@ def convert_annotation(file_name):
        n_objs_list, objs_info_list, objs_info_dicts

def load_mpii_pose_dataset(path='data', is_16_pos_only=False):
    """Load MPII Human Pose Dataset.

    Parameters
    -----------
    path : str
        The path that the data is downloaded to.
    is_16_pos_only : boolean
        If True, return only the people annotated with 16 pose keypoints. (Usually used for single-person pose estimation.)

    Returns
    ----------
    img_train_list : list of str
        The image file paths of the training data.
    ann_train_list : list of dict
        The annotations of the training data.
    img_test_list : list of str
        The image file paths of the testing data.
    ann_test_list : list of dict
        The annotations of the testing data.

    Examples
    --------
    >>> import pprint
    >>> import tensorlayer as tl
    >>> img_train_list, ann_train_list, img_test_list, ann_test_list = tl.files.load_mpii_pose_dataset()
    >>> image = tl.vis.read_image(img_train_list[0])
    >>> tl.vis.draw_mpii_people_to_image(image, ann_train_list[0], 'image.png')
    >>> pprint.pprint(ann_train_list[0])

    References
    -----------
    - `MPII Human Pose Dataset. CVPR 14 <http://human-pose.mpi-inf.mpg.de>`__
    - `MPII Human Pose Models. CVPR 16 <http://pose.mpi-inf.mpg.de>`__
    - `MPII Human Shape, Poselet Conditioned Pictorial Structures, etc. <http://pose.mpi-inf.mpg.de/#related>`__
    - `MPII Keypoints and ID <http://human-pose.mpi-inf.mpg.de/#download>`__
    """
    path = os.path.join(path, 'mpii_human_pose')
    logging.info("Load or Download MPII Human Pose > {}".format(path))

    # annotation
    url = "http://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/"
    tar_filename = "mpii_human_pose_v1_u12_2.zip"
    extracted_filename = "mpii_human_pose_v1_u12_2"
    if folder_exists(os.path.join(path, extracted_filename)) is False:
        logging.info("[MPII] (annotation) {} is nonexistent in {}".format(extracted_filename, path))
        maybe_download_and_extract(tar_filename, path, url, extract=True)
        del_file(os.path.join(path, tar_filename))

    # images
    url = "http://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/"
    tar_filename = "mpii_human_pose_v1.tar.gz"
    extracted_filename2 = "images"
    if folder_exists(os.path.join(path, extracted_filename2)) is False:
        logging.info("[MPII] (images) {} is nonexistent in {}".format(extracted_filename2, path))
        maybe_download_and_extract(tar_filename, path, url, extract=True)
        del_file(os.path.join(path, tar_filename))

    # parse annotation, format see http://human-pose.mpi-inf.mpg.de/#download
    import scipy.io as sio
    logging.info("reading annotations from mat file ...")
    # mat = sio.loadmat(os.path.join(path, extracted_filename, "mpii_human_pose_v1_u12_1.mat"))

    # def fix_wrong_joints(joint):  # https://github.com/mitmul/deeppose/blob/master/datasets/mpii_dataset.py
    #     if '12' in joint and '13' in joint and '2' in joint and '3' in joint:
    #         if ((joint['12'][0] < joint['13'][0]) and
    #                 (joint['3'][0] < joint['2'][0])):
    #             joint['2'], joint['3'] = joint['3'], joint['2']
    #         if ((joint['12'][0] > joint['13'][0]) and
    #                 (joint['3'][0] > joint['2'][0])):
    #             joint['2'], joint['3'] = joint['3'], joint['2']
    #     return joint

    ann_train_list = []
    ann_test_list = []
    img_train_list = []
    img_test_list = []

    def save_joints():
        # joint_data_fn = os.path.join(path, 'data.json')
        # fp = open(joint_data_fn, 'w')
        mat = sio.loadmat(os.path.join(path, extracted_filename, "mpii_human_pose_v1_u12_1.mat"))

        for _, (anno, train_flag) in enumerate(  # all images
                zip(mat['RELEASE']['annolist'][0, 0][0], mat['RELEASE']['img_train'][0, 0][0])):

            img_fn = anno['image']['name'][0, 0][0]
            train_flag = int(train_flag)

            # print(i, img_fn, train_flag)  # DEBUG print all images

            # every image gets one (possibly empty) annotation list
            if train_flag:
                img_train_list.append(img_fn)
                ann_train_list.append([])
            else:
                img_test_list.append(img_fn)
                ann_test_list.append([])

            head_rect = []
            if 'x1' in str(anno['annorect'].dtype):
                head_rect = zip(
                    [x1[0, 0] for x1 in anno['annorect']['x1'][0]], [y1[0, 0] for y1 in anno['annorect']['y1'][0]],
                    [x2[0, 0] for x2 in anno['annorect']['x2'][0]], [y2[0, 0] for y2 in anno['annorect']['y2'][0]])
            else:
                head_rect = []  # TODO

            if 'annopoints' in str(anno['annorect'].dtype):
                annopoints = anno['annorect']['annopoints'][0]
                head_x1s = anno['annorect']['x1'][0]
                head_y1s = anno['annorect']['y1'][0]
                head_x2s = anno['annorect']['x2'][0]
                head_y2s = anno['annorect']['y2'][0]
                # one entry per annotated person in the image
                for annopoint, head_x1, head_y1, head_x2, head_y2 in zip(annopoints, head_x1s, head_y1s, head_x2s, head_y2s):
                    if annopoint != []:
                        head_rect = [float(head_x1[0, 0]), float(head_y1[0, 0]), float(head_x2[0, 0]), float(head_y2[0, 0])]

                        # joint coordinates
                        annopoint = annopoint['point'][0, 0]
                        j_id = [str(j_i[0, 0]) for j_i in annopoint['id'][0]]
                        x = [x[0, 0] for x in annopoint['x'][0]]
                        y = [y[0, 0] for y in annopoint['y'][0]]
                        joint_pos = {}
                        for _j_id, (_x, _y) in zip(j_id, zip(x, y)):
                            joint_pos[int(_j_id)] = [float(_x), float(_y)]
                        # joint_pos = fix_wrong_joints(joint_pos)

                        # visibility list
                        if 'is_visible' in str(annopoint.dtype):
                            vis = [v[0] if v else [0] for v in annopoint['is_visible'][0]]
                            vis = dict([(k, int(v[0])) if len(v) > 0 else v for k, v in zip(j_id, vis)])
                        else:
                            vis = None

                        # only use people with 16 keypoints, or use all
                        if (not is_16_pos_only) or (len(joint_pos) == 16):
                            data = {'filename': img_fn, 'train': train_flag, 'head_rect': head_rect, 'is_visible': vis, 'joint_pos': joint_pos}
                            # print(json.dumps(data), file=fp)  # py3
                            if train_flag:
                                ann_train_list[-1].append(data)
                            else:
                                ann_test_list[-1].append(data)
    # def write_line(datum, fp):
    #     joints = sorted([[int(k), v] for k, v in datum['joint_pos'].items()])
    #     joints = np.array([j for i, j in joints]).flatten()
    #
    #     out = [datum['filename']]
    #     out.extend(joints)
    #     out = [str(o) for o in out]
    #     out = ','.join(out)
    #
    #     print(out, file=fp)

    # def split_train_test():
    #     # fp_test = open('data/mpii/test_joints.csv', 'w')
    #     fp_test = open(os.path.join(path, 'test_joints.csv'), 'w')
    #     # fp_train = open('data/mpii/train_joints.csv', 'w')
    #     fp_train = open(os.path.join(path, 'train_joints.csv'), 'w')
    #     # all_data = open('data/mpii/data.json').readlines()
    #     all_data = open(os.path.join(path, 'data.json')).readlines()
    #     N = len(all_data)
    #     N_test = int(N * 0.1)
    #     N_train = N - N_test
    #
    #     print('N:{}'.format(N))
    #     print('N_train:{}'.format(N_train))
    #     print('N_test:{}'.format(N_test))
    #
    #     np.random.seed(1701)
    #     perm = np.random.permutation(N)
    #     test_indices = perm[:N_test]
    #     train_indices = perm[N_test:]
    #
    #     print('train_indices:{}'.format(len(train_indices)))
    #     print('test_indices:{}'.format(len(test_indices)))
    #
    #     for i in train_indices:
    #         datum = json.loads(all_data[i].strip())
    #         write_line(datum, fp_train)
    #
    #     for i in test_indices:
    #         datum = json.loads(all_data[i].strip())
    #         write_line(datum, fp_test)

    save_joints()
    # split_train_test()  #

    ## read images dir
    logging.info("reading images list ...")
    img_dir = os.path.join(path, extracted_filename2)
    _img_list = load_file_list(path=os.path.join(path, extracted_filename2), regx='\\.jpg', printable=False)
    # ann_list = json.load(open(os.path.join(path, 'data.json')))
    # drop entries whose image file is missing; iterate over a reversed snapshot so deleting by index stays valid
    for i, im in reversed(list(enumerate(img_train_list))):
        if im not in _img_list:
            print('missing training image {} in {} (remove from img(ann)_train_list)'.format(im, img_dir))
            # img_train_list.remove(im)
            del img_train_list[i]
            del ann_train_list[i]
    for i, im in reversed(list(enumerate(img_test_list))):
        if im not in _img_list:
            print('missing testing image {} in {} (remove from img(ann)_test_list)'.format(im, img_dir))
            # img_test_list.remove(im)
            del img_test_list[i]
            del ann_test_list[i]

    ## check annotation and images
    n_train_images = len(img_train_list)
    n_test_images = len(img_test_list)
    n_images = n_train_images + n_test_images
    logging.info("n_images: {} n_train_images: {} n_test_images: {}".format(n_images, n_train_images, n_test_images))
    n_train_ann = len(ann_train_list)
    n_test_ann = len(ann_test_list)
    n_ann = n_train_ann + n_test_ann
    logging.info("n_ann: {} n_train_ann: {} n_test_ann: {}".format(n_ann, n_train_ann, n_test_ann))
    n_train_people = len(sum(ann_train_list, []))
    n_test_people = len(sum(ann_test_list, []))
    n_people = n_train_people + n_test_people
    logging.info("n_people: {} n_train_people: {} n_test_people: {}".format(n_people, n_train_people, n_test_people))

    # add path to all image file names
    for i, value in enumerate(img_train_list):
        img_train_list[i] = os.path.join(img_dir, value)
    for i, value in enumerate(img_test_list):
        img_test_list[i] = os.path.join(img_dir, value)
    return img_train_list, ann_train_list, img_test_list, ann_test_list


def save_npz(save_list=None, name='model.npz', sess=None):
    """Input parameters and the file name, save parameters into .npz file. Use tl.utils.load_npz() to restore.
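For orientation, a sketch of what the returned annotations look like, based on the data dict assembled in save_joints() above (it assumes the download succeeds and that the first training image contains at least one annotated person):

>>> import tensorlayer as tl
>>> # is_16_pos_only=True keeps only people annotated with all 16 keypoints
>>> img_train_list, ann_train_list, _, _ = tl.files.load_mpii_pose_dataset(is_16_pos_only=True)
>>> person = ann_train_list[0][0]  # first annotated person of the first training image
>>> sorted(person.keys())
['filename', 'head_rect', 'is_visible', 'joint_pos', 'train']

Here 'joint_pos' maps each joint id to its [x, y] coordinates, 'is_visible' maps joint ids to 0/1 visibility flags (or is None when the .mat file has no visibility field), 'head_rect' is the [x1, y1, x2, y2] head bounding box, and 'train' is the original train/test flag.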