
Commit ca8c44d

Internal change
PiperOrigin-RevId: 424422082
1 parent a7894f9 commit ca8c44d

File tree

17 files changed: +3174, -0 lines changed

official/projects/detr/README.md

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# End-to-End Object Detection with Transformers (DETR)

[![DETR](https://img.shields.io/badge/DETR-arXiv.2005.12872-B3181B?)](https://arxiv.org/abs/2005.12872)

TensorFlow 2 implementation of End-to-End Object Detection with Transformers.

⚠️ Disclaimer: Datasets hyperlinked from this page are neither owned nor
distributed by Google; they are made available by third parties.
Please review the terms and conditions made available by those third parties
before using the data.

## Scripts

The scripts to reproduce the experiments below are in detr/experiments.

## DETR [COCO](https://cocodataset.org) ([ImageNet](https://www.image-net.org) pretrained)

| Model | Resolution | Batch size | Epochs | LR decay @ epoch | Params (M) | Box AP | Dashboard | Checkpoint | Experiment |
| ------------------ | :--------: | ---------: | -----: | ---------------: | ---------: | -----: | --------- | ---------- | ---------- |
| DETR-ResNet-50 | 1333x1333 | 64 | 300 | 200 | 41 | 40.6 | [tensorboard](https://tensorboard.dev/experiment/o2IEZnniRYu6pqViBeopIg/#scalars) | [ckpt](https://storage.googleapis.com/tf_model_garden/vision/detr/detr_resnet_50_300.tar.gz) | detr_r50_300epochs.sh |
| DETR-ResNet-50 | 1333x1333 | 64 | 500 | 400 | 41 | 42.0 | [tensorboard](https://tensorboard.dev/experiment/YFMDKpESR4yjocPh5HgfRw/) | [ckpt](https://storage.googleapis.com/tf_model_garden/vision/detr/detr_resnet_50_500.tar.gz) | detr_r50_500epochs.sh |
| DETR-ResNet-50 | 1333x1333 | 64 | 300 | 200 | 41 | 40.6 | paper | NA | NA |
| DETR-ResNet-50 | 1333x1333 | 64 | 500 | 400 | 41 | 42.0 | paper | NA | NA |
| DETR-DC5-ResNet-50 | 1333x1333 | 64 | 500 | 400 | 41 | 43.3 | paper | NA | NA |

## Needs contribution

* Add DC5 support and update the experiment table.

## Citing TensorFlow Model Garden

If you find this codebase helpful in your research, please cite this repository.

```
@misc{tensorflowmodelgarden2020,
  author = {Hongkun Yu and Chen Chen and Xianzhi Du and Yeqing Li and
            Abdullah Rashwan and Le Hou and Pengchong Jin and Fan Yang and
            Frederick Liu and Jaeyoun Kim and Jing Li},
  title = {{TensorFlow Model Garden}},
  howpublished = {\url{https://github.com/tensorflow/models}},
  year = {2020}
}
```
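
For reference, the 500-epoch row in the table above corresponds to the `detr_coco` experiment registered by the config module added in this commit. Below is a minimal sketch of pulling that config through the Model Garden experiment factory; it assumes the `official` package from this repository is on `PYTHONPATH`, and the import path is the one used in the test file further down.

```
from official.core import exp_factory
# Importing the config module registers the 'detr_coco' experiment.
from official.projects.detr.configs import detr  # pylint: disable=unused-import

config = exp_factory.get_exp_config('detr_coco')
print(config.task.num_queries)     # 100 object queries.
print(config.trainer.train_steps)  # 500 epochs expressed as steps.
```
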
Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DETR configurations."""

import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.detr import optimization
from official.projects.detr.dataloaders import coco


@dataclasses.dataclass
class DetectionConfig(cfg.TaskConfig):
  """The detection task config."""
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()
  lambda_cls: float = 1.0
  lambda_box: float = 5.0
  lambda_giou: float = 2.0

  init_ckpt: str = ''
  num_classes: int = 81  # 0: background
  background_cls_weight: float = 0.1
  num_encoder_layers: int = 6
  num_decoder_layers: int = 6

  # Make DETRConfig.
  num_queries: int = 100
  num_hidden: int = 256
  per_category_metrics: bool = False


@exp_factory.register_config_factory('detr_coco')
def detr_coco() -> cfg.ExperimentConfig:
  """Config to get results that match the paper."""
  train_batch_size = 64
  eval_batch_size = 64
  num_train_data = 118287
  num_steps_per_epoch = num_train_data // train_batch_size
  train_steps = 500 * num_steps_per_epoch  # 500 epochs
  decay_at = train_steps - 100 * num_steps_per_epoch  # 400 epochs
  config = cfg.ExperimentConfig(
      task=DetectionConfig(
          train_data=coco.COCODataConfig(
              tfds_name='coco/2017',
              tfds_split='train',
              is_training=True,
              global_batch_size=train_batch_size,
              shuffle_buffer_size=1000,
          ),
          validation_data=coco.COCODataConfig(
              tfds_name='coco/2017',
              tfds_split='validation',
              is_training=False,
              global_batch_size=eval_batch_size,
              drop_remainder=False
          )
      ),
      trainer=cfg.TrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
          steps_per_loop=10000,
          summary_interval=10000,
          checkpoint_interval=10000,
          validation_interval=10000,
          max_to_keep=1,
          best_checkpoint_export_subdir='best_ckpt',
          best_checkpoint_eval_metric='AP',
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'detr_adamw',
                  'detr_adamw': {
                      'weight_decay_rate': 1e-4,
                      'global_clipnorm': 0.1,
                      # Avoid AdamW legacy behavior.
                      'gradient_clip_norm': 0.0
                  }
              },
              'learning_rate': {
                  'type': 'stepwise',
                  'stepwise': {
                      'boundaries': [decay_at],
                      'values': [0.0001, 1.0e-05]
                  }
              },
          })
      ),
      restrictions=[
          'task.train_data.is_training != None',
      ])
  return config
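
The schedule above encodes epochs as optimizer steps. Here is a small standalone sketch of the same arithmetic `detr_coco()` performs, using only the values hard-coded in the config (118,287 COCO 2017 training images, global batch size 64):

```
# Same arithmetic as detr_coco() above, spelled out.
num_train_data = 118287
train_batch_size = 64

num_steps_per_epoch = num_train_data // train_batch_size  # 1848
train_steps = 500 * num_steps_per_epoch                   # 924000 steps = 500 epochs
decay_at = train_steps - 100 * num_steps_per_epoch        # 739200 steps = epoch 400

# The stepwise schedule drops the learning rate from 1e-4 to 1e-5 at `decay_at`.
print(num_steps_per_epoch, train_steps, decay_at)
```
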
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for detr."""

# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.detr.configs import detr as exp_cfg
from official.projects.detr.dataloaders import coco


class DetrTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(('detr_coco',))
  def test_detr_configs(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.DetectionConfig)
    self.assertIsInstance(config.task.train_data, coco.COCODataConfig)
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()


if __name__ == '__main__':
  tf.test.main()
Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""COCO data loader for DETR."""

import dataclasses
from typing import Optional, Tuple
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import input_reader
from official.vision.beta.ops import box_ops
from official.vision.beta.ops import preprocess_ops


@dataclasses.dataclass
class COCODataConfig(cfg.DataConfig):
  """Data config for COCO."""
  output_size: Tuple[int, int] = (1333, 1333)
  max_num_boxes: int = 100
  resize_scales: Tuple[int, ...] = (
      480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)


class COCODataLoader():
  """A class to load dataset for COCO detection task."""

  def __init__(self, params: COCODataConfig):
    self._params = params

  def preprocess(self, inputs):
    """Preprocess COCO for DETR."""
    image = inputs['image']
    boxes = inputs['objects']['bbox']
    classes = inputs['objects']['label'] + 1
    is_crowd = inputs['objects']['is_crowd']

    image = preprocess_ops.normalize_image(image)
    if self._params.is_training:
      image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)

      do_crop = tf.greater(tf.random.uniform([]), 0.5)
      if do_crop:
        # Rescale
        boxes = box_ops.denormalize_boxes(boxes, tf.shape(image)[:2])
        index = tf.random.categorical(tf.zeros([1, 3]), 1)[0]
        scales = tf.gather([400.0, 500.0, 600.0], index, axis=0)
        short_side = scales[0]
        image, image_info = preprocess_ops.resize_image(image, short_side)
        boxes = preprocess_ops.resize_and_crop_boxes(boxes,
                                                     image_info[2, :],
                                                     image_info[1, :],
                                                     image_info[3, :])
        boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

        # Do cropping
        shape = tf.cast(image_info[1], dtype=tf.int32)
        h = tf.random.uniform(
            [], 384, tf.math.minimum(shape[0], 600), dtype=tf.int32)
        w = tf.random.uniform(
            [], 384, tf.math.minimum(shape[1], 600), dtype=tf.int32)
        i = tf.random.uniform([], 0, shape[0] - h + 1, dtype=tf.int32)
        j = tf.random.uniform([], 0, shape[1] - w + 1, dtype=tf.int32)
        image = tf.image.crop_to_bounding_box(image, i, j, h, w)
        boxes = tf.clip_by_value(
            (boxes[..., :] * tf.cast(
                tf.stack([shape[0], shape[1], shape[0], shape[1]]),
                dtype=tf.float32) -
             tf.cast(tf.stack([i, j, i, j]), dtype=tf.float32)) /
            tf.cast(tf.stack([h, w, h, w]), dtype=tf.float32), 0.0, 1.0)
      scales = tf.constant(
          self._params.resize_scales,
          dtype=tf.float32)
      index = tf.random.categorical(tf.zeros([1, 11]), 1)[0]
      scales = tf.gather(scales, index, axis=0)
    else:
      scales = tf.constant([self._params.resize_scales[-1]], tf.float32)

    image_shape = tf.shape(image)[:2]
    boxes = box_ops.denormalize_boxes(boxes, image_shape)
    gt_boxes = boxes
    short_side = scales[0]
    image, image_info = preprocess_ops.resize_image(
        image,
        short_side,
        max(self._params.output_size))
    boxes = preprocess_ops.resize_and_crop_boxes(boxes,
                                                 image_info[2, :],
                                                 image_info[1, :],
                                                 image_info[3, :])
    boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

    # Filters out ground truth boxes that are all zeros.
    indices = box_ops.get_non_empty_box_indices(boxes)
    boxes = tf.gather(boxes, indices)
    classes = tf.gather(classes, indices)
    is_crowd = tf.gather(is_crowd, indices)
    boxes = box_ops.yxyx_to_cycxhw(boxes)

    image = tf.image.pad_to_bounding_box(
        image, 0, 0, self._params.output_size[0], self._params.output_size[1])
    labels = {
        'classes':
            preprocess_ops.clip_or_pad_to_fixed_size(
                classes, self._params.max_num_boxes),
        'boxes':
            preprocess_ops.clip_or_pad_to_fixed_size(
                boxes, self._params.max_num_boxes)
    }
    if not self._params.is_training:
      labels.update({
          'id':
              inputs['image/id'],
          'image_info':
              image_info,
          'is_crowd':
              preprocess_ops.clip_or_pad_to_fixed_size(
                  is_crowd, self._params.max_num_boxes),
          'gt_boxes':
              preprocess_ops.clip_or_pad_to_fixed_size(
                  gt_boxes, self._params.max_num_boxes),
      })

    return image, labels

  def _transform_and_batch_fn(
      self,
      dataset,
      input_context: Optional[tf.distribute.InputContext] = None):
    """Preprocess and batch."""
    dataset = dataset.map(
        self.preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    per_replica_batch_size = input_context.get_per_replica_batch_size(
        self._params.global_batch_size
    ) if input_context else self._params.global_batch_size
    dataset = dataset.batch(
        per_replica_batch_size, drop_remainder=self._params.is_training)
    return dataset

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    reader = input_reader.InputReader(
        params=self._params,
        decoder_fn=None,
        transform_and_batch_fn=self._transform_and_batch_fn)
    return reader.read(input_context)
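
A minimal usage sketch for the loader above. It assumes the `coco/2017` TFDS dataset has already been downloaded and that the `official` package is importable; the field values mirror the validation settings of the `detr_coco` config earlier in this commit, and the shapes in the comments are what that config implies rather than verified output.

```
from official.projects.detr.dataloaders import coco

params = coco.COCODataConfig(
    tfds_name='coco/2017',
    tfds_split='validation',
    is_training=False,
    global_batch_size=2,
    drop_remainder=False)

loader = coco.COCODataLoader(params)
dataset = loader.load()

images, labels = next(iter(dataset))
print(images.shape)           # Expected: (2, 1333, 1333, 3) after padding.
print(labels['boxes'].shape)  # Expected: (2, 100, 4), boxes in cycxhw format.
```
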
