Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit fa1e35a

Browse files
author
DEKHTIARJonathan
committed
UNet_Medical TF2 replace metric computation from TF to Numpy/Scipy to avoid GPU OOM
1 parent 8cf01e9 commit fa1e35a

File tree

1 file changed

+6
-7
lines changed
  • tftrt/examples/nvidia_examples/unet_medical_tf2

1 file changed

+6
-7
lines changed

tftrt/examples/nvidia_examples/unet_medical_tf2/utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,13 @@ def preproc_samples(inputs, labels, precision):
7878

7979

8080
def dice_coef(predict, target, axis=1, eps=1e-6):
81+
from scipy.special import softmax
8182

82-
predict = tf.convert_to_tensor(predict)
83-
predict = tf.keras.activations.softmax(predict, axis=-1)
83+
predict = softmax(predict, axis=-1)
8484

85-
target = tf.convert_to_tensor(target)
86-
87-
intersection = tf.math.reduce_sum(predict * target, axis=axis)
88-
union = tf.math.reduce_sum(predict*predict + target*target, axis=axis)
85+
intersection = np.sum(predict * target, axis=axis)
86+
union = np.sum(predict*predict + target*target, axis=axis)
8987

9088
dice = (2.*intersection + eps) / (union+eps)
91-
return tf.math.reduce_mean(dice).numpy()
89+
90+
return np.mean(dice)

0 commit comments

Comments
 (0)