Skip to content

Commit c8a9178

Browse files
fyangftensorflower-gardener
authored andcommitted
Internal change
PiperOrigin-RevId: 381516130
1 parent 6e5cbee commit c8a9178

33 files changed

+3307
-0
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Volumetric Models
2+
3+
**DISCLAIMER**: This implementation is still under development. No support will
4+
be provided during the development phase.
5+
6+
This folder contains an implementation of volumetric models, i.e., the UNet 3D model,
7+
for 3D semantic segmentation.
8+
9+
## Modeling
10+
11+
Following the style of TF-Vision, a UNet 3D model is implemented as a backbone
12+
and a decoder.
13+
14+
## Backbone
15+
16+
The backbone is the left U-shape of the complete UNet model. It takes a batch of
17+
images as input, and outputs a dictionary in the form of `{level: features}`.
18+
`features` in the output is a tensor of feature maps.
19+
20+
## Decoder
21+
22+
The decoder is the right U-shape of the complete UNet model. It takes the output
23+
dictionary from the backbone and connects the feature maps from each level to
24+
the decoder's decoding branches. The final output is the raw segmentation
25+
predictions.
26+
27+
An additional head is attached to the output of the decoder to optionally
28+
perform more operations and then generate the prediction map of logits.
29+
30+
The `factory.py` file builds and connects the backbone, decoder and head
31+
together to form the complete UNet model.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Lint as: python3
16+
"""Backbones configurations."""
17+
from typing import Optional, Sequence
18+
19+
import dataclasses
20+
21+
from official.modeling import hyperparams
22+
23+
24+
@dataclasses.dataclass
class UNet3D(hyperparams.Config):
  """Hyperparameters of the UNet3D backbone.

  Attributes:
    model_id: Integer selecting the UNet3D variant (presumably the depth of
      the network -- confirm against the backbone implementation).
    pool_size: Sequence of ints giving the pooling size; defaults to
      (2, 2, 2).
    kernel_size: Sequence of ints giving the convolution kernel size; defaults
      to (3, 3, 3).
    base_filters: Base number of filters; defaults to 32.
    use_batch_normalization: Whether batch normalization is enabled.
  """
  model_id: int = 4
  pool_size: Sequence[int] = (2, 2, 2)
  kernel_size: Sequence[int] = (3, 3, 3)
  base_filters: int = 32
  use_batch_normalization: bool = True
32+
33+
34+
@dataclasses.dataclass
class Backbone(hyperparams.OneOfConfig):
  """Configuration for backbones.

  Attributes:
    type: 'str', type of backbone to be used, one of the fields below.
    unet_3d: UNet3D backbone config.
  """
  type: Optional[str] = None
  # Built through a factory rather than a plain default: a plain dataclass
  # default is evaluated once at class-definition time, and Python 3.11+
  # rejects unhashable defaults (such as non-frozen dataclass instances) with
  # a ValueError.
  unet_3d: UNet3D = dataclasses.field(default_factory=UNet3D)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Lint as: python3
16+
"""Decoders configurations."""
17+
from typing import Optional, Sequence
18+
19+
import dataclasses
20+
21+
from official.modeling import hyperparams
22+
23+
24+
@dataclasses.dataclass
class UNet3DDecoder(hyperparams.Config):
  """Hyperparameters of the UNet3D decoder.

  Attributes:
    model_id: Integer selecting the decoder variant (presumably it has to
      match the backbone's `model_id` -- confirm against the decoder
      implementation).
    pool_size: Sequence of ints giving the pooling size; defaults to
      (2, 2, 2).
    kernel_size: Sequence of ints giving the convolution kernel size; defaults
      to (3, 3, 3).
    use_batch_normalization: Whether batch normalization is enabled.
    use_deconvolution: Whether deconvolution is enabled (presumably for
      upsampling -- confirm against the decoder implementation).
  """
  model_id: int = 4
  pool_size: Sequence[int] = (2, 2, 2)
  kernel_size: Sequence[int] = (3, 3, 3)
  use_batch_normalization: bool = True
  use_deconvolution: bool = True
32+
33+
34+
@dataclasses.dataclass
class Decoder(hyperparams.OneOfConfig):
  """Configuration for decoders.

  Attributes:
    type: 'str', type of decoder to be used, one of the fields below.
    unet_3d_decoder: UNet3D decoder config.
  """
  type: Optional[str] = None
  # Built through a factory rather than a plain default: a plain dataclass
  # default is evaluated once at class-definition time, and Python 3.11+
  # rejects unhashable defaults (such as non-frozen dataclass instances) with
  # a ValueError.
  unet_3d_decoder: UNet3DDecoder = dataclasses.field(
      default_factory=UNet3DDecoder)
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Lint as: python3
16+
"""Semantic segmentation configuration definition."""
17+
from typing import List, Optional, Union
18+
19+
import dataclasses
20+
21+
from official.core import exp_factory
22+
from official.modeling import hyperparams
23+
from official.modeling import optimization
24+
from official.modeling.hyperparams import config_definitions as cfg
25+
from official.vision.beta.configs import common
26+
from official.vision.beta.projects.volumetric_models.configs import backbones
27+
from official.vision.beta.projects.volumetric_models.configs import decoders
28+
29+
30+
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input data config for the 3D semantic segmentation task.

  `SemanticSegmentation3DTask` uses this type for both its `train_data` and
  its `validation_data` fields.
  """
  # Sizes and label/channel counts of the examples.
  output_size: List[int] = dataclasses.field(default_factory=list)
  input_size: List[int] = dataclasses.field(default_factory=list)
  num_classes: int = 0
  num_channels: int = 1
  # Data source and batching.
  input_path: str = ''
  global_batch_size: int = 0
  is_training: bool = True
  # Data types, given by name.
  dtype: str = 'float32'
  label_dtype: str = 'float32'
  # Keys of the image and label fields (presumably feature keys in the
  # serialized examples -- confirm against the input reader).
  image_field_key: str = 'image/encoded'
  label_field_key: str = 'image/class/label'
  # Reading and shuffling options.
  shuffle_buffer_size: int = 1000
  cycle_length: int = 10
  drop_remainder: bool = False
  file_type: str = 'tfrecord'
48+
49+
50+
@dataclasses.dataclass
class SegmentationHead3D(hyperparams.Config):
  """Hyperparameters of the 3D segmentation head.

  Attributes:
    num_classes: Number of segmentation classes; defaults to 0.
    level: Integer level setting of the head (presumably the feature level
      consumed from the decoder -- confirm against the head implementation).
    num_convs: Number of convolutions in the head; defaults to 0.
    num_filters: Number of filters; defaults to 256.
    upsample_factor: Integer upsampling factor; defaults to 1.
    output_logits: Whether the head outputs logits.
  """
  num_classes: int = 0
  level: int = 1
  num_convs: int = 0
  num_filters: int = 256
  upsample_factor: int = 1
  output_logits: bool = True
59+
60+
61+
@dataclasses.dataclass
class SemanticSegmentationModel3D(hyperparams.Config):
  """Semantic segmentation model config.

  Attributes:
    num_classes: Number of segmentation classes.
    num_channels: Number of input channels.
    input_size: List of ints describing the input size.
    min_level: Minimum level setting; defaults to 3.
    max_level: Maximum level setting; defaults to 6.
    head: Segmentation head config.
    backbone: Backbone config; defaults to the `unet_3d` backbone.
    decoder: Decoder config; defaults to the `unet_3d_decoder` decoder.
    norm_activation: Normalization and activation config.
  """
  num_classes: int = 0
  num_channels: int = 1
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 3
  max_level: int = 6
  # Nested configs are built through factories rather than plain defaults: a
  # plain dataclass default is evaluated once at class-definition time, and
  # Python 3.11+ rejects unhashable defaults (such as non-frozen dataclass
  # instances) with a ValueError.
  head: SegmentationHead3D = dataclasses.field(
      default_factory=SegmentationHead3D)
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(  # pylint: disable=g-long-lambda
          type='unet_3d', unet_3d=backbones.UNet3D()))
  decoder: decoders.Decoder = dataclasses.field(
      default_factory=lambda: decoders.Decoder(  # pylint: disable=g-long-lambda
          type='unet_3d_decoder', unet_3d_decoder=decoders.UNet3DDecoder()))
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=common.NormActivation)
75+
76+
77+
@dataclasses.dataclass
class Losses(hyperparams.Config):
  """Loss settings.

  Attributes:
    loss_type: Name of the loss. Supported `loss_type` are `adaptive` and
      `generalized`.
    l2_weight_decay: L2 weight decay value; defaults to 0.0.
  """
  loss_type: str = 'adaptive'
  l2_weight_decay: float = 0.0
82+
83+
84+
@dataclasses.dataclass
class Evaluation(hyperparams.Config):
  """Evaluation settings.

  Attributes:
    report_per_class_metric: Whether to report per-class metrics.
  """
  report_per_class_metric: bool = False
87+
88+
89+
@dataclasses.dataclass
class SemanticSegmentation3DTask(cfg.TaskConfig):
  """Task config for 3D semantic segmentation.

  Attributes:
    model: Model config.
    train_data: Input config for training (`is_training=True`).
    validation_data: Input config for validation (`is_training=False`).
    losses: Loss config.
    evaluation: Evaluation config.
    train_input_partition_dims: List of ints; empty by default.
    eval_input_partition_dims: List of ints; empty by default.
    init_checkpoint: Optional checkpoint path to initialize from.
    init_checkpoint_modules: Which modules to initialize from the checkpoint:
      all, backbone, and/or decoder.
  """
  # Nested configs are built through factories rather than plain defaults: a
  # plain dataclass default is evaluated once at class-definition time, and
  # Python 3.11+ rejects unhashable defaults (such as non-frozen dataclass
  # instances) with a ValueError.
  model: SemanticSegmentationModel3D = dataclasses.field(
      default_factory=SemanticSegmentationModel3D)
  train_data: DataConfig = dataclasses.field(
      default_factory=lambda: DataConfig(is_training=True))
  validation_data: DataConfig = dataclasses.field(
      default_factory=lambda: DataConfig(is_training=False))
  losses: Losses = dataclasses.field(default_factory=Losses)
  evaluation: Evaluation = dataclasses.field(default_factory=Evaluation)
  train_input_partition_dims: List[int] = dataclasses.field(
      default_factory=list)
  eval_input_partition_dims: List[int] = dataclasses.field(default_factory=list)
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: Union[
      str, List[str]] = 'all'  # all, backbone, and/or decoder
103+
104+
105+
@exp_factory.register_config_factory('seg_unet3d_test')
def seg_unet3d_test() -> cfg.ExperimentConfig:
  """Builds a small 3D UNet segmentation experiment on dummy data for tests.

  Returns:
    The experiment config registered under the name `seg_unet3d_test`.
  """
  train_batch_size = 2
  eval_batch_size = 2
  steps_per_epoch = 10

  # Model: `unet_3d` backbone and `unet_3d_decoder` decoder, both with
  # model_id=2, and a head without extra convolutions.
  model = SemanticSegmentationModel3D(
      num_classes=2,
      input_size=[32, 32, 32],
      num_channels=2,
      backbone=backbones.Backbone(
          type='unet_3d', unet_3d=backbones.UNet3D(model_id=2)),
      decoder=decoders.Decoder(
          type='unet_3d_decoder',
          unet_3d_decoder=decoders.UNet3DDecoder(model_id=2)),
      head=SegmentationHead3D(num_convs=0, num_classes=2),
      norm_activation=common.NormActivation(
          activation='relu', use_sync_bn=False))

  # Input pipelines; each config receives its own `input_size` list.
  train_data = DataConfig(
      input_path='train.tfrecord',
      num_classes=2,
      input_size=[32, 32, 32],
      num_channels=2,
      is_training=True,
      global_batch_size=train_batch_size)
  validation_data = DataConfig(
      input_path='val.tfrecord',
      num_classes=2,
      input_size=[32, 32, 32],
      num_channels=2,
      is_training=False,
      global_batch_size=eval_batch_size)

  task = SemanticSegmentation3DTask(
      model=model,
      train_data=train_data,
      validation_data=validation_data,
      losses=Losses(loss_type='adaptive'))

  # SGD with a constant learning rate.
  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'sgd',
      },
      'learning_rate': {
          'type': 'constant',
          'constant': {
              'learning_rate': 0.000001
          }
      }
  })
  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=10,
      validation_steps=10,
      validation_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Lint as: python3
16+
"""Tests for semantic_segmentation."""
17+
18+
# pylint: disable=unused-import
19+
from absl.testing import parameterized
20+
import tensorflow as tf
21+
22+
from official.core import exp_factory
23+
from official.modeling.hyperparams import config_definitions as cfg
24+
from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
25+
26+
27+
class ImageSegmentationConfigTest(tf.test.TestCase, parameterized.TestCase):
  """Checks the registered 3D semantic segmentation experiment configs."""

  @parameterized.parameters(
      ('seg_unet3d_test',),)
  def test_semantic_segmentation_configs(self, config_name):
    # The registered factory must produce the expected config types.
    experiment = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(experiment, cfg.ExperimentConfig)
    self.assertIsInstance(experiment.task, exp_cfg.SemanticSegmentation3DTask)
    self.assertIsInstance(
        experiment.task.model, exp_cfg.SemanticSegmentationModel3D)
    self.assertIsInstance(experiment.task.train_data, exp_cfg.DataConfig)

    # Unsetting `is_training` is expected to make validation raise KeyError.
    experiment.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      experiment.validate()
41+
42+
43+
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()

0 commit comments

Comments
 (0)