Skip to content

Commit 08b22e9

Browse files
committed
nmse and docs
1 parent 2263eb5 commit 08b22e9

File tree

3 files changed

+26
-4
lines changed

3 files changed

+26
-4
lines changed

docs/modules/cost.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ to the cost function.
120120
sigmoid_cross_entropy
121121
binary_cross_entropy
122122
mean_squared_error
123+
normalize_mean_squared_error
123124
dice_coe
124125
dice_hard_coe
125126
iou_coe
@@ -149,6 +150,10 @@ Mean squared error
149150
-------------------------
150151
.. autofunction:: mean_squared_error
151152

153+
Normalized mean squared error
154+
--------------------------------
155+
.. autofunction:: normalize_mean_squared_error
156+
152157
Dice coefficient
153158
-------------------------
154159
.. autofunction:: dice_coe

docs/modules/files.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,12 @@ English-to-French translation data from the WMT'15 Website
9494
Load and save network
9595
----------------------
9696

97-
Save network into list
98-
^^^^^^^^^^^^^^^^^^^^^^^^^^
97+
Save network into list (npz)
98+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
9999
.. autofunction:: save_npz
100100

101-
Save network into dict
102-
^^^^^^^^^^^^^^^^^^^^^^^^
101+
Save network into dict (npz)
102+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
103103
.. autofunction:: save_npz_dict
104104

105105
Load network from save_npz

tensorlayer/cost.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,23 @@ def mean_squared_error(output, target, is_mean=False):
102102
mse = tf.reduce_mean(tf.reduce_sum(tf.squared_difference(output, target), [1, 2, 3]))
103103
return mse
104104

105+
def normalize_mean_squared_error(output, target):
106+
"""Return the TensorFlow expression of normalized mean-squre-error of two distributions.
107+
108+
Parameters
109+
----------
110+
output : 2D or 4D tensor.
111+
target : 2D or 4D tensor.
112+
"""
113+
with tf.name_scope("mean_squared_error_loss"):
114+
if output.get_shape().ndims == 2: # [batch_size, n_feature]
115+
nmse_a = tf.sqrt(tf.reduce_sum(tf.squared_difference(output, target), axis=1))
116+
nmse_b = tf.sqrt(tf.reduce_sum(tf.square(target), axis=1))
117+
elif output.get_shape().ndims == 4: # [batch_size, w, h, c]
118+
nmse_a = tf.sqrt(tf.reduce_sum(tf.squared_difference(output, target), axis=[1,2,3]))
119+
nmse_b = tf.sqrt(tf.reduce_sum(tf.square(target), axis=[1,2,3]))
120+
nmse = tf.reduce_mean(nmse_a / nmse_b)
121+
return nmse
105122

106123

107124
def dice_coe(output, target, epsilon=1e-10):

0 commit comments

Comments
 (0)