
Commit d325ab6

Merge pull request #1075 from Laicheng0830/Add_load_ckpt
add load ckpt weights
2 parents a1755a3 + 21329d2

File tree

4 files changed: +191 -0 lines changed

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
#! /usr/bin/python
# -*- coding: utf-8 -*-

import tensorlayer as tl
from tensorlayer.layers import (Input, Conv2d, Flatten, Dense, MaxPool2d)
from tensorlayer.models import Model
from tensorlayer.files import maybe_download_and_extract
import numpy as np
import tensorflow as tf

filename = 'ckpt_parameters.zip'
url_source = 'https://media.githubusercontent.com/media/tensorlayer/pretrained-models/master/models/'

# download weights
down_file = tl.files.maybe_download_and_extract(
    filename=filename, working_directory='model/', url_source=url_source, extract=True
)

model_file = 'model/ckpt_parameters'

# ckpt to npz; rename_key is used to match the TL naming rule
tl.files.ckpt_to_npz_dict(model_file, rename_key=True)
weights = np.load('model.npz', allow_pickle=True)

# view the parameter names and weight shapes
for key in weights.keys():
    print(key, weights[key].shape)


# build model
def create_model(inputs_shape):
    W_init = tl.initializers.truncated_normal(stddev=5e-2)
    W_init2 = tl.initializers.truncated_normal(stddev=0.04)
    ni = Input(inputs_shape)
    nn = Conv2d(64, (3, 3), (1, 1), padding='SAME', act=tf.nn.relu, W_init=W_init, name='conv1_1')(ni)
    nn = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool1_1')(nn)
    nn = Conv2d(64, (3, 3), (1, 1), padding='SAME', act=tf.nn.relu, W_init=W_init, b_init=None, name='conv1_2')(nn)
    nn = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool1_2')(nn)

    nn = Conv2d(128, (3, 3), (1, 1), padding='SAME', act=tf.nn.relu, W_init=W_init, b_init=None, name='conv2_1')(nn)
    nn = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool2_1')(nn)
    nn = Conv2d(128, (3, 3), (1, 1), padding='SAME', act=tf.nn.relu, W_init=W_init, b_init=None, name='conv2_2')(nn)
    nn = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool2_2')(nn)

    nn = Conv2d(256, (3, 3), (1, 1), padding='SAME', act=tf.nn.relu, W_init=W_init, b_init=None, name='conv3_1')(nn)
    nn = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool3_1')(nn)
    nn = Conv2d(256, (3, 3), (1, 1), padding='SAME', act=tf.nn.relu, W_init=W_init, b_init=None, name='conv3_2')(nn)
    nn = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool3_2')(nn)

    nn = Conv2d(512, (3, 3), (1, 1), padding='SAME', act=tf.nn.relu, W_init=W_init, b_init=None, name='conv4_1')(nn)
    nn = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool4_1')(nn)
    nn = Conv2d(512, (3, 3), (1, 1), padding='SAME', act=tf.nn.relu, W_init=W_init, b_init=None, name='conv4_2')(nn)
    nn = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool4_2')(nn)

    nn = Flatten(name='flatten')(nn)
    nn = Dense(1000, act=None, W_init=W_init2, name='output')(nn)

    M = Model(inputs=ni, outputs=nn, name='cnn')
    return M


net = create_model([None, 224, 224, 3])
# loaded weights whose names are not found in the network's weights will be skipped
# If the ckpt follows the same naming rule as TL, we can restore the model directly with tl.files.load_and_assign_ckpt(model_dir=, network=, skip=True)
tl.files.load_and_assign_npz_dict(network=net, skip=True)

# you can use the following code to view the restored model parameters
net_weights_name = [w.name for w in net.all_weights]
for i in range(len(net_weights_name)):
    print(net_weights_name[i], net.all_weights[i])
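
The closing comments above point at the direct route: when a checkpoint's variable names already follow the TL convention (conv1_1/filters:0, conv1_1/biases:0, ...), the npz round-trip is unnecessary. Below is a minimal sketch of that path, reusing the create_model builder and the model/ckpt_parameters directory from the example; the argument values are illustrative only, and note that the example's own checkpoint does not follow the TL naming rule, which is why it goes through ckpt_to_npz_dict with rename_key=True instead.

# Hypothetical direct restore, assuming the ckpt variable names already match TL's naming rule.
net = create_model([None, 224, 224, 3])
tl.files.load_and_assign_ckpt(
    model_dir='model/ckpt_parameters',  # directory holding the .data-*, .index and .meta files
    network=net,
    skip=True,  # skip checkpoint variables that have no counterpart in the network
)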

tensorlayer/files/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -72,4 +72,6 @@
     #'load_graph',
     #'save_graph_and_params',
     #'load_graph_and_params',
+    'load_and_assign_ckpt',
+    'ckpt_to_npz_dict'
 ]
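
Adding the two names to __all__ exposes them on the public tl.files namespace, which is how the example script above calls them. A small sanity sketch, assuming a TensorLayer build that includes this commit:

import tensorlayer as tl

# Both helpers are now reachable through the tl.files namespace.
print(callable(tl.files.ckpt_to_npz_dict))      # True
print(callable(tl.files.load_and_assign_ckpt))  # True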

tensorlayer/files/utils.py

Lines changed: 119 additions & 0 deletions
@@ -26,6 +26,7 @@
 from tensorflow.python.platform import gfile
 from tensorflow.python.util import serialization
 from tensorflow.python.util.tf_export import keras_export
+from tensorflow.python import pywrap_tensorflow

 import progressbar
 import tensorlayer as tl
@@ -76,6 +77,8 @@
     'static_graph2net',
     # 'save_pkl_graph',
     # 'load_pkl_graph',
+    'load_and_assign_ckpt',
+    'ckpt_to_npz_dict',
 ]


@@ -2775,3 +2778,119 @@ def load_hdf5_to_weights(filepath, network, skip=False):

     f.close()
     logging.info("[*] Load %s SUCCESS!" % filepath)
+
+
+def check_ckpt_file(model_dir):
+    model_dir = model_dir
+    model_path = None
+    count_extension = 0
+    for root, dirs, files in os.walk(model_dir):
+        for file in files:
+            filename, extension = os.path.splitext(file)
+            if extension in ['.data-00000-of-00001', '.index', '.meta']:
+                count_extension += 1
+        if count_extension == 3:
+            model_path = model_dir + '/' + filename
+        else:
+            raise Exception("Check the file extension for missing .data-00000-of-00001, .index, .meta")
+    if model_path is None:
+        raise Exception('The ckpt file is not found')
+    return model_path, filename
+
+
+def rename_weight_or_biases(variable_name):
+    if variable_name is None:
+        return variable_name
+    split_var = variable_name.split('/')
+
+    str_temp = ''
+    for i in range(len(split_var)):
+        if 'w' in split_var[i]:
+            split_var[i] = 'filters:0'
+        elif 'b' in split_var[i]:
+            split_var[i] = 'biases:0'
+        else:
+            pass
+
+        if i < len(split_var) - 1:
+            str_temp = str_temp + split_var[i] + '/'
+        else:
+            str_temp = str_temp + split_var[i]
+
+    return str_temp
+
+
+def load_and_assign_ckpt(model_dir, network=None, skip=True):
+    """Load weights by name from a given file of ckpt format.
+
+    Parameters
+    ----------
+    model_dir : str
+        Path of the directory containing the ckpt files from which the weights are loaded.
+        Example: model_dir = /root/cnn_model/
+    network : Model
+        TL model.
+    skip : bool
+        If 'skip' is True, loaded weights whose name is not found in the network's weights will be skipped.
+        If 'skip' is False, an error will be raised when a mismatch is found. Default is True.
+
+    Returns
+    -------
+
+    """
+    model_path, filename = check_ckpt_file(model_dir)
+
+    reader = pywrap_tensorflow.NewCheckpointReader(model_path)
+    var_to_shape_map = reader.get_variable_to_shape_map()
+
+    net_weights_name = [w.name for w in network.all_weights]
+
+    for key in var_to_shape_map:
+        if key not in net_weights_name:
+            if skip:
+                logging.warning("Weights named '%s' not found in network. Skip it." % key)
+            else:
+                raise RuntimeError(
+                    "Weights named '%s' not found in network. Hint: set argument skip=True "
+                    "if you want to skip redundant or mismatched weights." % key
+                )
+        else:
+            assign_tf_variable(network.all_weights[net_weights_name.index(key)], reader.get_tensor(key))
+    logging.info("[*] Model restored from ckpt %s" % filename)
+
+
+def ckpt_to_npz_dict(model_dir, save_name='model.npz', rename_key=False):
+    """Save ckpt weights to an npz file.
+
+    Parameters
+    ----------
+    model_dir : str
+        Path of the directory containing the ckpt files from which the weights are read.
+        Example: model_dir = /root/cnn_model/
+    save_name : str
+        The save name of the `.npz` file.
+    rename_key : bool
+        Rename the parameters to match the TL naming rule.
+        Examples: conv1_1/b_b --> conv1_1/biases:0 ; conv1_1/w_w --> conv1_1/filters:0
+
+    Returns
+    -------
+
+    """
+    model_path, _ = check_ckpt_file(model_dir)
+
+    reader = pywrap_tensorflow.NewCheckpointReader(model_path)
+    var_to_shape_map = reader.get_variable_to_shape_map()
+
+    parameters_dict = {}
+    if rename_key is False:
+        for key in sorted(var_to_shape_map):
+            parameters_dict[key] = reader.get_tensor(key)
+    elif rename_key is True:
+        for key in sorted(var_to_shape_map):
+            parameters_dict[rename_weight_or_biases(key)] = reader.get_tensor(key)
+
+    np.savez(save_name, **parameters_dict)
+    parameters_dict = None
+    del parameters_dict
+    logging.info("[*] Ckpt weights saved in npz_dict %s" % save_name)

tl

File mode changed from 100755 to 100644.
