From 912eb8052e84462c04b67fc9e32f99ef322f120b Mon Sep 17 00:00:00 2001 From: idhamari Date: Fri, 12 Jan 2024 22:14:56 +0100 Subject: [PATCH 1/2] update to recent version of tensorflow and keras 2.13 --- segmentation_models/__init__.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/segmentation_models/__init__.py b/segmentation_models/__init__.py index d9d55222..cd821ad7 100644 --- a/segmentation_models/__init__.py +++ b/segmentation_models/__init__.py @@ -1,9 +1,10 @@ import os import functools +from tensorflow import keras from .__version__ import __version__ from . import base -_KERAS_FRAMEWORK_NAME = 'keras' +_KERAS_FRAMEWORK_NAME = 'tf.keras' _TF_KERAS_FRAMEWORK_NAME = 'tf.keras' _DEFAULT_KERAS_FRAMEWORK = _KERAS_FRAMEWORK_NAME @@ -64,14 +65,12 @@ def set_framework(name): name = name.lower() if name == _KERAS_FRAMEWORK_NAME: - import keras - import efficientnet.keras # init custom objects - elif name == _TF_KERAS_FRAMEWORK_NAME: - from tensorflow import keras + from tensorflow.keras import backend as K + from tensorflow.keras import layers, models, utils, losses import efficientnet.tfkeras # init custom objects else: - raise ValueError('Not correct module name `{}`, use `{}` or `{}`'.format( - name, _KERAS_FRAMEWORK_NAME, _TF_KERAS_FRAMEWORK_NAME)) + raise ValueError('Not correct module name `{}`, use `{}`'.format( + name, _TF_KERAS_FRAMEWORK_NAME)) global _KERAS_BACKEND, _KERAS_LAYERS, _KERAS_MODELS global _KERAS_UTILS, _KERAS_LOSSES, _KERAS_FRAMEWORK From 6c9a4cd0193291157e90df1945f4b0714a5d55fb Mon Sep 17 00:00:00 2001 From: idhamari Date: Sat, 13 Jan 2024 09:14:41 +0100 Subject: [PATCH 2/2] add dice score --- .../backbones/backbones_factory.py | 1 - segmentation_models/base/functional.py | 31 +++++++++ segmentation_models/metrics.py | 69 +++++++++++++++++-- 3 files changed, 95 insertions(+), 6 deletions(-) diff --git a/segmentation_models/backbones/backbones_factory.py b/segmentation_models/backbones/backbones_factory.py index 7d2a3b9e..716a3199 100644 --- a/segmentation_models/backbones/backbones_factory.py +++ b/segmentation_models/backbones/backbones_factory.py @@ -71,7 +71,6 @@ class BackbonesFactory(ModelsFactory): 'block3a_expand_activation', 'block2a_expand_activation'), } - _models_update = { 'inceptionresnetv2': [irv2.InceptionResNetV2, irv2.preprocess_input], 'inceptionv3': [iv3.InceptionV3, iv3.preprocess_input], diff --git a/segmentation_models/base/functional.py b/segmentation_models/base/functional.py index afd8e73f..cd613cf1 100644 --- a/segmentation_models/base/functional.py +++ b/segmentation_models/base/functional.py @@ -98,6 +98,37 @@ def iou_score(gt, pr, class_weights=1., class_indexes=None, smooth=SMOOTH, per_i return score +def dice_score(gt, pr, class_weights=1., class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None, **kwargs): + """ + Calculate the Dice coefficient, a measure of set similarity. + + Args: + gt: Ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W). + pr: Prediction 4D keras tensor (B, H, W, C) or (B, C, H, W). + class_weights: 1. or list of class weights, len(weights) = C. + class_indexes: Optional integer or list of integers, classes to consider, if `None` all classes are used. + smooth: Value to avoid division by zero. + per_image: If `True`, metric is calculated as mean over images in batch (B), else over whole batch. + threshold: Value to round predictions (use `>` comparison), if `None` prediction will not be rounded. + + Returns: + Dice score in range [0, 1]. + """ + + backend = kwargs['backend'] + + gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs) + pr = round_if_needed(pr, threshold, **kwargs) + axes = get_reduce_axes(per_image, **kwargs) + + # Adjusted score calculation for Dice coefficient + intersection = backend.sum(gt * pr, axis=axes) + sum_gt_pr = backend.sum(gt, axis=axes) + backend.sum(pr, axis=axes) + + score = (2 * intersection + smooth) / (sum_gt_pr + smooth) + score = average(score, per_image, class_weights, **kwargs) + + return score def f_score(gt, pr, beta=1, class_weights=1, class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None, **kwargs): diff --git a/segmentation_models/metrics.py b/segmentation_models/metrics.py index 831dd462..28f06de5 100644 --- a/segmentation_models/metrics.py +++ b/segmentation_models/metrics.py @@ -62,6 +62,64 @@ def __call__(self, gt, pr): **self.submodules ) +class DICEScore(Metric): + r""" The `Dice coefficient`_, also known as the Sørensen–Dice coefficient or Dice similarity, + is a statistic used for gauging the similarity of two samples. It's often used in the context of + binary and multiclass segmentation problems. The Dice coefficient is defined as twice the size + of the intersection divided by the sum of the sizes of the two sample sets: + + .. math:: DSC(A, B) = \frac{2 |A \cap B|}{|A| + |B|} + + Args: + class_weights: 1. or ``np.array`` of class weights (``len(weights) = num_classes``). + class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used. + smooth: Value to avoid division by zero. + per_image: If ``True``, metric is calculated as mean over images in batch (B), + else over whole batch. + threshold: Value to round predictions (use ``>`` comparison), if ``None`` prediction will not be rounded. + + Returns: + A callable ``dice_score`` instance. Can be used in the ``model.compile(...)`` function. + + .. _`Dice coefficient`: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient + + Example: + + .. code:: python + + metric = DICEScore() + model.compile('SGD', loss=loss, metrics=[metric]) + """ + + def __init__( + self, + class_weights=None, + class_indexes=None, + threshold=None, + per_image=False, + smooth=SMOOTH, # SMOOTH should be defined elsewhere in your code + name=None, + ): + name = name or 'dice_score' + super().__init__(name=name) + self.class_weights = class_weights if class_weights is not None else 1 + self.class_indexes = class_indexes + self.threshold = threshold + self.per_image = per_image + self.smooth = smooth + + def __call__(self, gt, pr): + return F.dice_score( # Assuming F.dice_score is your implementation of Dice score + gt, + pr, + class_weights=self.class_weights, + class_indexes=self.class_indexes, + smooth=self.smooth, + per_image=self.per_image, + threshold=self.threshold, + **self.submodules + ) + class FScore(Metric): r"""The F-score (Dice coefficient) can be interpreted as a weighted average of the precision and recall, @@ -256,8 +314,9 @@ def __call__(self, gt, pr): # aliases -iou_score = IOUScore() -f1_score = FScore(beta=1) -f2_score = FScore(beta=2) -precision = Precision() -recall = Recall() +iou_score = IOUScore() +dice_score = DICEScore() +f1_score = FScore(beta=1) +f2_score = FScore(beta=2) +precision = Precision() +recall = Recall()