Skip to content

Commit 0ede665

Browse files
committed
update mse for dim=3
1 parent 68cee00 commit 0ede665

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

tensorlayer/cost.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,18 @@ def mean_squared_error(output, target, is_mean=False):
9494
mse = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(output, target), 1))
9595
else:
9696
mse = tf.reduce_mean(tf.reduce_sum(tf.squared_difference(output, target), 1))
97+
elif output.get_shape().ndims == 3: # [batch_size, w, h]
98+
if is_mean:
99+
mse = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(output, target), [1, 2]))
100+
else:
101+
mse = tf.reduce_mean(tf.reduce_sum(tf.squared_difference(output, target), [1, 2]))
97102
elif output.get_shape().ndims == 4: # [batch_size, w, h, c]
98103
if is_mean:
99104
mse = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(output, target), [1, 2, 3]))
100105
else:
101106
mse = tf.reduce_mean(tf.reduce_sum(tf.squared_difference(output, target), [1, 2, 3]))
107+
else:
108+
raise Exception("Unknow dimension")
102109
return mse
103110

104111
def normalized_mean_square_error(output, target):

0 commit comments

Comments
 (0)