Skip to content

Commit 79c7526

Browse files
saberkunallenwang28
authored andcommitted
Add resnet_config.py.
Remove callback_test.py as it uses private TF symbol callback_test PiperOrigin-RevId: 302990143
1 parent de31fd8 commit 79c7526

File tree

2 files changed

+61
-86
lines changed

2 files changed

+61
-86
lines changed

official/vision/image_classification/callbacks_test.py

Lines changed: 0 additions & 86 deletions
This file was deleted.
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Lint as: python3
2+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
"""Configuration definitions for ResNet losses, learning rates, and optimizers."""
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
from typing import Any, Mapping
22+
23+
import dataclasses
24+
25+
from official.vision.image_classification.configs import base_configs
26+
27+
28+
_RESNET_LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
29+
(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
30+
]
31+
_RESNET_LR_BOUNDARIES = list(p[1] for p in _RESNET_LR_SCHEDULE[1:])
32+
_RESNET_LR_MULTIPLIERS = list(p[0] for p in _RESNET_LR_SCHEDULE)
33+
_RESNET_LR_WARMUP_EPOCHS = _RESNET_LR_SCHEDULE[0][1]
34+
35+
36+
@dataclasses.dataclass
37+
class ResNetModelConfig(base_configs.ModelConfig):
38+
"""Configuration for the ResNet model."""
39+
name: str = 'ResNet'
40+
num_classes: int = 1000
41+
model_params: Mapping[str, Any] = dataclasses.field(default_factory=lambda: {
42+
'num_classes': 1000,
43+
'batch_size': None,
44+
'use_l2_regularizer': True,
45+
'rescale_inputs': False,
46+
})
47+
loss: base_configs.LossConfig = base_configs.LossConfig(
48+
name='sparse_categorical_crossentropy')
49+
optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
50+
name='momentum',
51+
decay=0.9,
52+
epsilon=0.001,
53+
momentum=0.9,
54+
moving_average_decay=None)
55+
learning_rate: base_configs.LearningRateConfig = (
56+
base_configs.LearningRateConfig(
57+
name='piecewise_constant_with_warmup',
58+
examples_per_epoch=1281167,
59+
warmup_epochs=_RESNET_LR_WARMUP_EPOCHS,
60+
boundaries=_RESNET_LR_BOUNDARIES,
61+
multipliers=_RESNET_LR_MULTIPLIERS))

0 commit comments

Comments
 (0)