
Commit 1856837

[layers] Compatible with Keras !!!
1 parent 35917ee commit 1856837


3 files changed: 124 additions, 1 deletion


docs/modules/layers.rst

Lines changed: 9 additions & 0 deletions
@@ -318,6 +318,7 @@ Layer list
    TileLayer

    SlimNetsLayer
+   KerasLayer

    PReluLayer

@@ -610,6 +611,14 @@ see `Slim-model <https://github.com/tensorflow/models/tree/master/slim#Install>`

 .. autoclass:: SlimNetsLayer

+Connect Keras
+------------------
+
+Yes! Keras models can be connected into TensorLayer; see `tutorial_keras.py <https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_keras.py>`_.
+
+.. autoclass:: KerasLayer
+
+
 Parametric activation layer
 ---------------------------
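The docs promise a one-line bridge between the two libraries. Before the full MNIST example below, here is a minimal sketch of the pattern the new section describes. It is grounded in the `KerasLayer` code in this commit, which simply calls `keras_layer(self.inputs, **keras_args)`; the `Dense` wrapper is illustrative, since any callable that maps a tensor to a tensor would do:

    import tensorflow as tf
    from keras.layers import Dense
    from tensorlayer.layers import InputLayer, KerasLayer

    x = tf.placeholder(tf.float32, shape=[None, 784])
    net = InputLayer(x, name='input')
    # A Keras layer instance is itself a tensor -> tensor callable,
    # so it can be passed directly as keras_layer.
    net = KerasLayer(net, keras_layer=Dense(800, activation='relu'), name='keras_dense')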

example/tutorial_keras.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+#! /usr/bin/python
+# -*- coding: utf8 -*-
+
+import numpy as np
+import tensorflow as tf
+import tensorlayer as tl
+import time
+from keras import backend as K
+from keras.layers import *
+from tensorlayer.layers import *
+
+X_train, y_train, X_val, y_val, X_test, y_test = \
+    tl.files.load_mnist_dataset(shape=(-1, 784))
+
+sess = tf.InteractiveSession()
+
+batch_size = 128
+x = tf.placeholder(tf.float32, shape=[None, 784])
+y_ = tf.placeholder(tf.int64, shape=[None])
+
+def keras_block(x, is_train=True):
+    x = Dropout(0.8)(x)
+    x = Dense(800, activation='relu')(x)
+    x = Dropout(0.5)(x)
+    x = Dense(800, activation='relu')(x)
+    x = Dropout(0.5)(x)
+    logits = Dense(10, activation='linear')(x)
+    return logits
+
+network = InputLayer(x, name='input')
+network = KerasLayer(network, keras_layer=keras_block, name='keras')
+
+y = network.outputs
+network.print_params(False)
+network.print_layers()
+
+cost = tl.cost.cross_entropy(y, y_, 'cost')
+correct_prediction = tf.equal(tf.argmax(y, 1), y_)
+acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+
+n_epoch = 200
+learning_rate = 0.0001
+
+train_params = network.all_params
+train_op = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.999,
+    epsilon=1e-08, use_locking=False).minimize(cost, var_list=train_params)
+
+tl.layers.initialize_global_variables(sess)
+
+for epoch in range(n_epoch):
+    start_time = time.time()
+    ## Training
+    for X_train_a, y_train_a in tl.iterate.minibatches(
+            X_train, y_train, batch_size, shuffle=True):
+        _, _ = sess.run([cost, train_op], feed_dict={x: X_train_a, y_: y_train_a,
+                        K.learning_phase(): 1})
+
+    print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
+    ## Evaluation
+    train_loss, train_acc, n_batch = 0, 0, 0
+    for X_train_a, y_train_a in tl.iterate.minibatches(
+            X_train, y_train, batch_size, shuffle=False):
+        err, ac = sess.run([cost, acc], feed_dict={x: X_train_a, y_: y_train_a,
+                           K.learning_phase(): 0})
+        train_loss += err; train_acc += ac; n_batch += 1
+    print("   train loss: %f" % (train_loss / n_batch))
+    print("   train acc: %f" % (train_acc / n_batch))
+    val_loss, val_acc, n_batch = 0, 0, 0
+    for X_val_a, y_val_a in tl.iterate.minibatches(
+            X_val, y_val, batch_size, shuffle=False):
+        err, ac = sess.run([cost, acc], feed_dict={x: X_val_a, y_: y_val_a,
+                           K.learning_phase(): 0})
+        val_loss += err; val_acc += ac; n_batch += 1
+    print("   val loss: %f" % (val_loss / n_batch))
+    print("   val acc: %f" % (val_acc / n_batch))

tensorlayer/layers.py

Lines changed: 40 additions & 1 deletion
@@ -4413,9 +4413,11 @@ def __init__(
         layer = None,
         slim_layer = None,
         slim_args = {},
-        name ='InceptionV3',
+        name ='tfslim_layer',
     ):
         Layer.__init__(self, name=name)
+        assert slim_layer is not None
+        assert slim_args is not None
         self.inputs = layer.outputs
         print("  [TL] SlimNetsLayer %s: %s" % (self.name, slim_layer.__name__))

@@ -4444,6 +4446,43 @@ def __init__(
         self.all_layers.extend( slim_layers )
         self.all_params.extend( slim_variables )

+## Keras layer
+class KerasLayer(Layer):
+    """
+    The :class:`KerasLayer` class can be used to merge all Keras layers into
+    TensorLayer. An example can be found in `tutorial_keras.py <https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_keras.py>`_.
+
+    Parameters
+    ----------
+    layer : a :class:`Layer` instance
+        The `Layer` class feeding into this layer.
+    keras_layer : a Keras network function
+    keras_args : dictionary
+        The arguments for the Keras model.
+    name : a string or None
+        An optional name to attach to this layer.
+    """
+    def __init__(
+        self,
+        layer = None,
+        keras_layer = None,
+        keras_args = {},
+        name ='keras_layer',
+    ):
+        Layer.__init__(self, name=name)
+        assert layer is not None
+        assert keras_layer is not None
+        self.inputs = layer.outputs
+        print("  [TL] KerasLayer %s: %s" % (self.name, keras_layer))
+        with tf.variable_scope(name) as vs:
+            self.outputs = keras_layer(self.inputs, **keras_args)
+            variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
+        self.all_layers = list(layer.all_layers)
+        self.all_params = list(layer.all_params)
+        self.all_drop = dict(layer.all_drop)
+        self.all_layers.extend( [self.outputs] )
+        self.all_params.extend( variables )
+
 ## Special activation
 class PReluLayer(Layer):
     """
