Skip to content

Commit f09bb58

Browse files
authored
Add the PyTorch and Paddle model parameters to import TensorLayerX (#19)
1 parent df26565 commit f09bb58

File tree

2 files changed

+332
-0
lines changed

2 files changed

+332
-0
lines changed
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
# From the https://github.com/tensorlayer/TensorLayerX/issues/11
4+
# The author: @qiutzh
5+
6+
import os
7+
os.environ['TL_BACKEND'] = 'paddle'
8+
# os.environ['TL_BACKEND'] = 'tensorflow'
9+
import tensorlayerx.nn as nn
10+
from tensorlayerx import logging
11+
from tensorlayerx.files import assign_weights
12+
from paddle.utils.download import get_weights_path_from_url
13+
import numpy as np
14+
import paddle
15+
from paddle import to_tensor
16+
from PIL import Image
17+
import copy
18+
import tensorlayerx as tlx
19+
from examples.model_zoo.imagenet_classes import class_names
20+
21+
__all__ = []
22+
23+
# Pretrained Paddle checkpoints, keyed by architecture name.
# Each value is the (url, checksum) pair handed to get_weights_path_from_url.
model_urls = dict(
    tlxvgg16=(
        'https://paddle-hapi.bj.bcebos.com/models/vgg16.pdparams',
        '89bbffc0f87d260be9b8cdc169c991c4',
    ),
    tlxvgg19=(
        'https://paddle-hapi.bj.bcebos.com/models/vgg19.pdparams',
        '23b18bb13d8894f60f54e642be79a0dd',
    ),
)
29+
30+
31+
class VGG(nn.Module):
    """VGG model from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_

    Args:
        features (nn.Module): Convolutional feature extractor, e.g. the
            ``nn.Sequential`` returned by ``make_layers``.
        num_classes (int): Output dim of the last fc layer. If num_classes <= 0,
            the classifier is not defined and ``forward`` returns the feature
            map unchanged. Default: 1000.
        with_pool (bool): If True, an ``AdaptiveAvgPool2d((7, 7))`` layer is
            created as ``self.avgpool``. NOTE: ``forward`` does not apply it
            (see the comment there). Default: True.

    Examples:
        .. code-block:: python

            vgg11_cfg = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
            features = make_layers(vgg11_cfg)
            vgg11 = VGG(features)
    """

    def __init__(self, features, num_classes=1000, with_pool=True):
        super(VGG, self).__init__()
        self.features = features
        self.num_classes = num_classes
        self.with_pool = with_pool

        if self.with_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        if num_classes > 0:
            # The first Linear layer fixes the flattened feature size to
            # 512 * 7 * 7, so the feature extractor must produce exactly that.
            self.classifier = nn.Sequential(
                nn.Linear(out_features=4096, act=None, in_features=512 * 7 * 7),
                nn.ReLU(),
                nn.Linear(out_features=4096, act=None, in_features=4096),
                nn.ReLU(),
                nn.Linear(in_features=4096, out_features=num_classes),
            )

    def forward(self, x):
        """Run the feature extractor and, if defined, the classifier head.

        Args:
            x: Input batch fed to ``self.features``.

        Returns:
            Classifier output when ``num_classes > 0``, otherwise the raw
            output of ``self.features``.
        """
        x = self.features(x)
        # NOTE(review): the adaptive average pooling step was disabled in the
        # original code and is intentionally left disabled here; confirm the
        # pooling layer's data format before re-enabling it.
        # if self.with_pool:
        #     x = self.avgpool(x)
        if self.num_classes > 0:
            # Flatten every axis after the batch axis before the Linear layers.
            x = paddle.flatten(x, 1)
            x = self.classifier(x)
        return x
77+
78+
79+
def make_layers(cfg, batch_norm=False):
    """Build the VGG convolutional stack described by ``cfg``.

    Args:
        cfg (list): Sequence whose items are either the string ``'M'``
            (insert a 2x2 max-pooling layer) or an int (insert a 3x3
            convolution with that many output channels, followed by ReLU).
        batch_norm (bool): If True, a ``BatchNorm2d`` layer is inserted
            between every convolution and its ReLU. Default: False.

    Returns:
        nn.Sequential: The layers in the order given by ``cfg``. The first
        convolution takes 3 input channels.
    """
    stack = []
    previous_channels = 3
    for entry in cfg:
        if entry == 'M':
            # padding is passed explicitly; the original note says the
            # default padding is 'SAME'.
            stack.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0, data_format='channels_first'))
            continue

        stack.append(
            nn.Conv2d(out_channels=entry, kernel_size=(3, 3), stride=(1, 1), act=None, padding=1,
                      in_channels=previous_channels, data_format='channels_first')
        )
        if batch_norm:
            stack.append(nn.BatchNorm2d(num_features=entry, data_format='channels_first'))
        stack.append(nn.ReLU())
        previous_channels = entry
    return nn.Sequential(*stack)
94+
95+
96+
# Layer layouts consumed by make_layers: an int is the output channel count
# of a 3x3 convolution, 'M' is a max-pooling layer. Each row below is one
# stage ending in a pooling layer. 'D' is the layout used by tlxvgg16.
cfgs = {
    'A': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'B': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'D': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'E': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M',
    ],
}
109+
#################### Added: pd2tlx #################
# Maps Paddle VGG16 parameter names to the names used by the TensorLayerX
# model in this file ('weight' -> 'W', 'bias' -> 'b').

# Positions of the convolution layers inside ``features``.
_PD2TLX_CONV_POSITIONS = (0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28)
# (Paddle index, TensorLayerX index) of the Linear layers in ``classifier``.
# The indices differ between the two sides of the mapping.
_PD2TLX_LINEAR_POSITIONS = ((0, 0), (3, 2), (6, 4))

# Insertion order: feature weights, feature biases, classifier weights,
# classifier biases.
pd2tlx = {}
pd2tlx.update(
    ('features.{}.weight'.format(pos), 'features.{}.W'.format(pos))
    for pos in _PD2TLX_CONV_POSITIONS
)
pd2tlx.update(
    ('features.{}.bias'.format(pos), 'features.{}.b'.format(pos))
    for pos in _PD2TLX_CONV_POSITIONS
)
pd2tlx.update(
    ('classifier.{}.weight'.format(pd_pos), 'classifier.{}.W'.format(tlx_pos))
    for pd_pos, tlx_pos in _PD2TLX_LINEAR_POSITIONS
)
pd2tlx.update(
    ('classifier.{}.bias'.format(pd_pos), 'classifier.{}.b'.format(tlx_pos))
    for pd_pos, tlx_pos in _PD2TLX_LINEAR_POSITIONS
)
142+
143+
144+
def get_new_weight(param):
    """Added helper: re-key a Paddle parameter dict with TensorLayerX names.

    Args:
        param (dict): Maps Paddle parameter names to values exposing
            ``.shape``. Every key must be present in ``pd2tlx``.

    Returns:
        dict: The same values stored under the names given by ``pd2tlx``.

    Raises:
        KeyError: If a key of ``param`` has no entry in ``pd2tlx``.
    """
    renamed = {}
    for paddle_name, value in param.items():
        tlx_name = pd2tlx[paddle_name]
        renamed[tlx_name] = value
        # Log the old and new name side by side with their shapes.
        print(paddle_name, ":", value.shape, "vs", tlx_name, ":", renamed[tlx_name].shape)
    return renamed
151+
152+
153+
def restore_model(param, model, model_type='vgg16'):
    """Restore directly: assign the values of ``param`` to ``model`` by position.

    Args:
        param (dict): Loaded parameters; only the values are used, in the
            dict's iteration order.
        model: Model exposing ``all_weights``; passed to ``assign_weights``.
        model_type (str): Only ``'vgg16'`` collects any weights. For any other
            value (including ``'vgg19'``) an empty list is assigned.
    """
    collected = []
    if model_type == 'vgg16':
        for _name, value in param.items():
            # for val in sorted(param.items()):
            collected.append(value)
            # Stop once there is one value per model weight.
            if len(model.all_weights) == len(collected):
                break
    # assign weight values
    assign_weights(collected, model)
167+
168+
169+
def _tlxvgg(arch, cfg, batch_norm, pretrained, **kwargs):
    """Build a VGG model and optionally load pretrained Paddle weights.

    Args:
        arch (str): Key into ``model_urls`` used when ``pretrained`` is True.
        cfg (str): Key into ``cfgs`` selecting the layer layout.
        batch_norm (bool): Forwarded to ``make_layers``.
        pretrained (bool): If True, download the checkpoint for ``arch`` and
            assign it to the model through ``restore_model``.
        **kwargs: Forwarded to the ``VGG`` constructor.

    Returns:
        VGG: The constructed (and possibly weight-initialised) model.
    """
    backbone = make_layers(cfgs[cfg], batch_norm=batch_norm)
    model = VGG(backbone, **kwargs)
    if not pretrained:
        return model

    assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
        arch)
    url, checksum = model_urls[arch]
    weight_path = get_weights_path_from_url(url, checksum)
    loaded = paddle.load(weight_path)
    # Alternatives left disabled by the original author:
    # model.load_dict(loaded)
    # model.load_dict(get_new_weight(loaded))
    restore_model(loaded, model)
    return model
182+
183+
184+
def tlxvgg16(pretrained=False, batch_norm=False, **kwargs):
    """VGG 16-layer model (layout ``'D'`` of ``cfgs``).

    Args:
        pretrained (bool): If True, pretrained weights are downloaded and
            assigned to the model. Default: False.
        batch_norm (bool): If True, returns a model with batch_norm layers.
            The architecture name then becomes ``'tlxvgg16_bn'``, which has no
            entry in ``model_urls``. Default: False.
        **kwargs: Forwarded to the ``VGG`` constructor.

    Examples:
        .. code-block:: python

            # build model
            model = tlxvgg16()
            # build vgg16 model with batch_norm
            model = tlxvgg16(batch_norm=True)
    """
    arch = 'tlxvgg16_bn' if batch_norm else 'tlxvgg16'
    return _tlxvgg(arch, 'D', batch_norm, pretrained, **kwargs)
202+
203+
204+
if __name__ == "__main__":
    # Build the pretrained network and switch it to inference mode.
    net = tlxvgg16(pretrained=True, batch_norm=False)
    net.set_eval()
    for weight in net.trainable_weights:
        print(weight.name, weight.shape)

    # Load the sample image, resize it to 224x224 and divide by 255.
    image = tlx.vision.load_image('data/tiger.jpeg')
    resize = tlx.vision.transforms.Resize((224, 224))
    image = resize(image).astype(np.float32) / 255

    # Add a batch axis, then convert NHWC -> NCHW for the channels_first model.
    batch = paddle.unsqueeze(paddle.Tensor(image), 0)
    batch = tlx.ops.nhwc_to_nchw(batch)

    logits = net(batch)
    probabilities = tlx.ops.softmax(logits)[0].numpy()

    # Print the five most probable classes, most probable first.
    ranked = np.argsort(probabilities)[::-1]
    for class_index in ranked[0:5]:
        print(class_names[class_index], probabilities[class_index])
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
import os
5+
# os.environ['TL_BACKEND'] = 'tensorflow'
6+
os.environ['TL_BACKEND'] = 'paddle'
7+
import numpy as np
8+
9+
import torch
10+
import tensorlayerx as tlx
11+
from tensorlayerx import nn
12+
13+
"""
14+
Save PyTorch parameters to a.pth
15+
"""
16+
from torch import nn as th_nn
17+
class B(th_nn.Module):
    """Reference PyTorch network: conv1 -> conv2 -> bn1 -> ReLU.

    Both convolutions use 1x1 kernels; the network maps 3 input channels to
    16 output channels.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = th_nn.Conv2d(3, 16, kernel_size=1)
        self.conv2 = th_nn.Conv2d(16, 16, kernel_size=1)
        self.bn1 = th_nn.BatchNorm2d(16)
        self.act = th_nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the two convolutions, batch norm and ReLU in sequence."""
        hidden = self.conv1(x)
        hidden = self.conv2(hidden)
        hidden = self.bn1(hidden)
        return self.act(hidden)
27+
28+
"""
29+
Load the PyTorch parameters stored in a.pth into TensorLayerX
30+
"""
31+
class A(nn.Module):
    """TensorLayerX counterpart of ``B``: conv1 -> conv2 -> bn1 -> ReLU.

    All layers use the ``channels_first`` data format so that inputs have the
    same layout as in the PyTorch model.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as in the original code.
        self.conv1 = nn.Conv2d(16, kernel_size=1, in_channels=3, data_format='channels_first')
        self.conv2 = nn.Conv2d(16, kernel_size=1, in_channels=16, data_format='channels_first')
        self.bn1 = nn.BatchNorm2d(num_features=16, data_format='channels_first')
        self.act = nn.activation.ReLU()

    def forward(self, x):
        """Apply the two convolutions, batch norm and ReLU in sequence."""
        hidden = self.conv1(x)
        hidden = self.conv2(hidden)
        hidden = self.bn1(hidden)
        return self.act(hidden)
41+
42+
43+
def pth2npz(pth_path, npz_path):
    """Convert a PyTorch state dict file into a TensorLayerX ``.npz`` file.

    Every entry is renamed with ``def_rename_torch_key`` and re-laid-out with
    ``def_torch_weight_reshape`` before being written with ``np.savez``.
    The names and shapes before and after conversion are printed.

    Args:
        pth_path (str): Path of the file written by ``torch.save``.
        npz_path (str): Destination path for the ``.npz`` archive.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    # map_location='cpu' lets checkpoints saved on a GPU be converted on any
    # host; np.savez needs CPU data anyway.
    state_dict = torch.load(pth_path, map_location='cpu')

    print("Pytorch parameter names and parameter shapes:")
    for name, tensor in state_dict.items():
        print(name, tensor.shape)

    print("Parameter names and parameter shapes of the renamed PyTorch:")
    tl_npz = {}
    for name, tensor in state_dict.items():
        # Rename and reshape once per entry, then reuse for storing and logging.
        new_name = def_rename_torch_key(name)
        new_weight = def_torch_weight_reshape(tensor)
        tl_npz[new_name] = new_weight
        print(new_name, new_weight.shape)
    np.savez(npz_path, **tl_npz)
55+
56+
57+
def def_rename_torch_key(key):
    """Translate a PyTorch state-dict key into a TensorLayerX parameter name.

    Only the naming of Conv2d and BatchNorm2d parameters used in this example
    is covered. Other code styles may need their own rules, so customise the
    rule table below as required.

    The layer index is the run of digits at the end of the first dotted
    component (``conv12.weight`` -> ``conv2d_12/filters``). If that component
    has no trailing digits, its last character is used instead.

    Args:
        key (str): PyTorch key such as ``'conv1.weight'`` or
            ``'bn1.running_mean'``.

    Returns:
        str: The renamed key, or ``key`` unchanged when no rule matches
        (including keys that contain no ``'.'``).
    """
    parts = key.split('.')
    if len(parts) < 2:
        # Nothing to inspect after the module name; leave the key untouched.
        return key

    module_name, param_name = parts[0], parts[1]
    # Keep the whole trailing number, not just its last digit, so that layers
    # numbered 10 and above do not collide with layers 0-9.
    trailing_digits = module_name[len(module_name.rstrip('0123456789')):]
    index = trailing_digits or module_name[-1:]

    # (marker looked up in the full key, marker looked up in the parameter
    # name, new prefix, new parameter name); the first matching rule wins.
    rules = (
        ('conv', 'weight', 'conv2d_', 'filters'),
        ('conv', 'bias', 'conv2d_', 'biases'),
        ('bn', 'weight', 'batchnorm2d_', 'gamma'),
        ('bn', 'bias', 'batchnorm2d_', 'beta'),
        ('bn', 'running_mean', 'batchnorm2d_', 'moving_mean'),
        ('bn', 'running_var', 'batchnorm2d_', 'moving_var'),
    )
    for layer_marker, param_marker, prefix, new_param in rules:
        if layer_marker in key and param_marker in param_name:
            return prefix + index + '/' + new_param
    return key
75+
76+
def def_torch_weight_reshape(weight):
    """Re-lay-out a PyTorch weight for the active TensorLayerX backend.

    With the TensorFlow backend the kernel layout is
    [ksize, ksize, in_channel, out_channel], so 4-D and 5-D weights have
    their first two axes (out_channel, in_channel) moved to the end. For any
    other backend, and for weights of any other rank, the input is returned
    unchanged.

    Args:
        weight: A ``torch.Tensor``, a NumPy array, or an ``int``.

    Returns:
        The re-ordered weight (same type as the input) or ``weight`` itself.
    """
    if tlx.BACKEND != 'tensorflow':
        return weight
    if isinstance(weight, int):
        return weight

    rank = len(weight.shape)
    if rank not in (4, 5):
        return weight

    # Axis 1 (in_channel) goes second-to-last, axis 0 (out_channel) goes last.
    destination = (rank - 2, rank - 1)
    if isinstance(weight, torch.Tensor):
        # Use the torch routine for tensors of either rank: np.moveaxis relies
        # on ndarray-style ``transpose``, which torch tensors do not provide.
        return torch.moveaxis(weight, (1, 0), destination)
    return np.moveaxis(weight, (1, 0), destination)
86+
87+
if __name__ == '__main__':
    # Step1: save pytorch model parameters to a.pth.
    # Done automatically on the first run (when a.pth does not exist yet), so
    # the script no longer needs to be edited by hand before it can run.
    if not os.path.exists('a.pth'):
        torch.save(B().state_dict(), 'a.pth')

    a = A()
    # Step2: Converts pytorch a.pth to the model parameter format of tensorlayerx
    pth2npz('a.pth', 'a.npz')
    # View the parameter name and size of the tensorlayerx
    print("TensorLayer parameter names and parameter shapes:")
    for w in a.all_weights:
        print(w.name, w.shape)

    # Step3: Load model parameters to tensorlayerx
    tlx.files.load_and_assign_npz_dict('a.npz', a, skip=True)
    a.set_eval()

    # Perform tensorlayerx inference to output the value at position [0][0]
    print("TensorLayerX outputs[0][0]:", a(tlx.nn.Input(shape=(5, 3, 3, 3)))[0][0])

    # load torch parameters
    b = B()
    b.eval()
    weights = torch.load('a.pth')
    b.load_state_dict(weights)
    # Perform pytorch inference to output the value at position [0][0]
    print("PyTorch outputs[0][0]:", b(torch.ones((5, 3, 3, 3)))[0][0])

0 commit comments

Comments
 (0)