Skip to content

Commit 21329d2

Browse files
committed
add tutorial load ckpt
1 parent 677cbde commit 21329d2

File tree

2 files changed

+123
-23
lines changed

2 files changed

+123
-23
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
4+
import tensorlayer as tl
5+
from tensorlayer.layers import (Input, Conv2d, Flatten, Dense, MaxPool2d)
6+
from tensorlayer.models import Model
7+
from tensorlayer.files import maybe_download_and_extract
8+
import numpy as np
9+
import tensorflow as tf
10+
11+
# Pretrained checkpoint archive hosted in the tensorlayer pretrained-models repo.
CKPT_ARCHIVE = 'ckpt_parameters.zip'
CKPT_URL = 'https://media.githubusercontent.com/media/tensorlayer/pretrained-models/master/models/'

# Fetch and unpack the checkpoint files under model/.
tl.files.maybe_download_and_extract(
    filename=CKPT_ARCHIVE, working_directory='model/', url_source=CKPT_URL, extract=True
)

model_file = 'model/ckpt_parameters'

# Convert the ckpt to an npz dict; rename_key rewrites variable names to
# match the TL naming rule (e.g. conv1_1/w_w -> conv1_1/filters:0).
tl.files.ckpt_to_npz_dict(model_file, rename_key=True)
weights = np.load('model.npz', allow_pickle=True)

# Inspect the converted parameter names and their shapes.
for key in weights:
    print(key, weights[key].shape)
28+
29+
30+
# build model
31+
def create_model(inputs_shape):
    """Build the VGG-like CNN used in this tutorial.

    Four blocks of two (Conv2d + MaxPool2d) stages with 64/128/256/512
    filters, followed by Flatten and a 1000-way Dense output. Layer names
    (conv1_1, pool1_1, ...) match the checkpoint's TL-renamed keys so the
    weights can be restored by name.

    Parameters
    ----------
    inputs_shape : list
        Shape of the network input, e.g. [None, 224, 224, 3].

    Returns
    -------
    Model
        The assembled tensorlayer model named 'cnn'.
    """
    conv_init = tl.initializers.truncated_normal(stddev=5e-2)
    dense_init = tl.initializers.truncated_normal(stddev=0.04)

    ni = Input(inputs_shape)
    nn = ni
    for block_idx, n_filter in enumerate((64, 128, 256, 512), start=1):
        for stage_idx in (1, 2):
            # Only the very first conv keeps the default bias initializer;
            # every other conv is built without biases (b_init=None), matching
            # the original hand-unrolled definition.
            extra = {} if (block_idx == 1 and stage_idx == 1) else {'b_init': None}
            nn = Conv2d(
                n_filter, (3, 3), (1, 1), padding='SAME', act=tf.nn.relu, W_init=conv_init,
                name='conv%d_%d' % (block_idx, stage_idx), **extra
            )(nn)
            nn = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool%d_%d' % (block_idx, stage_idx))(nn)

    nn = Flatten(name='flatten')(nn)
    nn = Dense(1000, act=None, W_init=dense_init, name='output')(nn)

    return Model(inputs=ni, outputs=nn, name='cnn')
60+
61+
62+
net = create_model([None, 224, 224, 3])
# Loaded weights whose names are not found among the network's weights are skipped.
# If the ckpt already follows the TL naming rule, the model can instead be restored
# with tl.files.load_and_assign_ckpt(model_dir=..., network=..., skip=True)
tl.files.load_and_assign_npz_dict(network=net, skip=True)

# View the restored model parameters. Iterate the weights directly instead of
# looking each name up with list.index(), which is redundant and returns the
# FIRST occurrence — wrong whenever two weights share a name.
for weight in net.all_weights:
    print(weight.name, weight)

tensorlayer/files/utils.py

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2780,6 +2780,46 @@ def load_hdf5_to_weights(filepath, network, skip=False):
27802780
logging.info("[*] Load %s SUCCESS!" % filepath)
27812781

27822782

2783+
def check_ckpt_file(model_dir):
    """Locate the TensorFlow checkpoint stored under `model_dir`.

    A complete ckpt consists of three files sharing one prefix:
    `<prefix>.data-00000-of-00001`, `<prefix>.index` and `<prefix>.meta`.

    Parameters
    ----------
    model_dir : str
        The folder to search (searched recursively with os.walk).

    Returns
    -------
    tuple of (str, str)
        (path of the checkpoint prefix inside model_dir, the prefix itself).

    Raises
    ------
    Exception
        If no ckpt files exist, or one of the three parts is missing.
    """
    required = {'.data-00000-of-00001', '.index', '.meta'}
    # Group the extensions found by checkpoint prefix, so unrelated files in
    # the directory cannot be mistaken for the checkpoint name, and partially
    # scanned directories do not trigger a premature failure.
    found = {}
    for _root, _dirs, files in os.walk(model_dir):
        for file in files:
            prefix, extension = os.path.splitext(file)
            if extension in required:
                found.setdefault(prefix, set()).add(extension)

    for prefix, extensions in found.items():
        if extensions == required:
            return os.path.join(model_dir, prefix), prefix

    if found:
        # Some ckpt parts exist but the triple is incomplete.
        raise Exception("Check the file extension for missing .data-00000-of-00001, .index, .meta")
    raise Exception('The ckpt file is not found')
2799+
2800+
2801+
def rename_weight_or_biases(variable_name):
    """Map a ckpt variable name onto the TL naming rule.

    Examples: conv1_1/w_w --> conv1_1/filters:0 ; conv1_1/b_b --> conv1_1/biases:0

    Parameters
    ----------
    variable_name : str or None
        Slash-separated ckpt variable name; returned unchanged when None.

    Returns
    -------
    str or None
        The name with its leaf component rewritten to 'filters:0'/'biases:0'.
    """
    if variable_name is None:
        return variable_name

    parts = variable_name.split('/')
    # Only rewrite the LEAF component: scope names such as 'batchnorm' or
    # 'wide_net' also contain 'w'/'b' and must not be clobbered.
    leaf = parts[-1]
    if 'w' in leaf:
        parts[-1] = 'filters:0'
    elif 'b' in leaf:
        parts[-1] = 'biases:0'

    return '/'.join(parts)
2821+
2822+
27832823
def load_and_assign_ckpt(model_dir, network=None, skip=True):
27842824
"""Load weights by name from a given file of ckpt format
27852825
@@ -2798,16 +2838,7 @@ def load_and_assign_ckpt(model_dir, network=None, skip=True):
27982838
-------
27992839
28002840
"""
2801-
model_dir = model_dir
2802-
model_path = None
2803-
for root, dirs, files in os.walk(model_dir):
2804-
for file in files:
2805-
filename, extension = os.path.splitext(file)
2806-
if extension in ['.data-00000-of-00001', '.index', '.meta']:
2807-
model_path = model_dir + '/' + filename
2808-
break
2809-
if model_path == None:
2810-
raise Exception('The ckpt file is not found')
2841+
model_path, filename = check_ckpt_file(model_dir)
28112842

28122843
reader = pywrap_tensorflow.NewCheckpointReader(model_path)
28132844
var_to_shape_map = reader.get_variable_to_shape_map()
@@ -2828,7 +2859,7 @@ def load_and_assign_ckpt(model_dir, network=None, skip=True):
28282859
logging.info("[*] Model restored from ckpt %s" % filename)
28292860

28302861

2831-
def ckpt_to_npz_dict(model_dir, save_name='model.npz'):
2862+
def ckpt_to_npz_dict(model_dir, save_name='model.npz', rename_key=False):
    """Save ckpt weights to npz file.

    Parameters
    ----------
    model_dir : str
        The folder containing the ckpt files (.data-00000-of-00001, .index, .meta).
        Examples: model_dir = /root/cnn_model/
    save_name : str
        The save_name of the `.npz` file.
    rename_key : bool
        Modify parameter naming, used to match TL naming rule.
        Examples: conv1_1/b_b --> conv1_1/biases:0 ; conv1_1/w_w --> conv1_1/filters:0

    Returns
    -------

    """
    model_path, _ = check_ckpt_file(model_dir)

    reader = pywrap_tensorflow.NewCheckpointReader(model_path)
    var_to_shape_map = reader.get_variable_to_shape_map()

    parameters_dict = {}
    # Treat rename_key as a plain boolean. The previous
    # `if rename_key is False: ... elif rename_key is True: ...` took neither
    # branch for truthy non-bool values (e.g. rename_key=1) and silently
    # saved an empty npz.
    for key in sorted(var_to_shape_map):
        npz_key = rename_weight_or_biases(key) if rename_key else key
        parameters_dict[npz_key] = reader.get_tensor(key)

    np.savez(save_name, **parameters_dict)
    # Drop references to the (potentially large) tensors as soon as possible.
    del parameters_dict

0 commit comments

Comments
 (0)