
Commit 286bcf6

Merge branch 'tensorflow:master' into yolov4_tiny_pr
2 parents 5eba85b + 8bcb4a0

341 files changed (+24519 / -2519 lines)


.github/bot_config.yml

Lines changed: 0 additions & 2 deletions
@@ -21,6 +21,4 @@
 
 # A list of assignees
 assignees:
-  - sushreebarsa
   - laxmareddyp
-  - sineeli

docs/vision/instance_segmentation.ipynb

Lines changed: 4 additions & 4 deletions
@@ -745,7 +745,7 @@
 },
 "outputs": [],
 "source": [
-"def show_batch(raw_records, num_of_examples):\n",
+"def show_batch(raw_records):\n",
 " plt.figure(figsize=(20, 20))\n",
 " use_normalized_coordinates=True\n",
 " min_score_thresh = 0.30\n",
@@ -802,7 +802,7 @@
 "\n",
 "train_tfrecords = tf.io.gfile.glob(exp_config.task.train_data.input_path)\n",
 "raw_records = tf.data.TFRecordDataset(train_tfrecords).shuffle(buffer_size=buffer_size).take(num_of_examples)\n",
-"show_batch(raw_records, num_of_examples)"
+"show_batch(raw_records)"
 ]
 },
 {
@@ -962,7 +962,7 @@
 "\n",
 "test_tfrecords = tf.io.gfile.glob('./lvis_tfrecords/val*')\n",
 "test_ds = tf.data.TFRecordDataset(test_tfrecords).take(num_of_examples)\n",
-"show_batch(test_ds, num_of_examples)"
+"show_batch(test_ds)"
 ]
 },
 {
@@ -1095,7 +1095,7 @@
 " detection_masks = tf.convert_to_tensor(result['detection_masks'][0])\n",
 " detection_boxes = tf.convert_to_tensor(result['detection_boxes'][0])\n",
 " detection_masks_reframed = reframe_box_masks_to_image_masks(\n",
-" detection_masks, detection_boxes/255.0,\n",
+" detection_masks, detection_boxes/256.0,\n",
 " image_np.shape[0], image_np.shape[1])\n",
 " detection_masks_reframed = tf.cast(\n",
 " detection_masks_reframed \u003e min_score_thresh,\n",

official/core/base_task.py

Lines changed: 3 additions & 1 deletion
@@ -57,7 +57,9 @@ def __init__(self,
     """
     super().__init__(name=name)
     self._task_config = params
-    self._logging_dir = logging_dir
+    self._logging_dir = (
+        logging_dir or ""
+    )  # Empty directory hints current working dir.
 
   @property
   def task_config(self):
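The fallback maps a None logging_dir to the empty string, which path handling resolves against the current working directory. A tiny illustration of that convention:

import os

logging_dir = None                # hypothetical unset value
resolved = logging_dir or ""      # "" instead of None
print(os.path.abspath(resolved))  # the empty path resolves to the cwd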

official/core/base_trainer.py

Lines changed: 27 additions & 6 deletions
@@ -47,18 +47,29 @@ def init_async(self):
         tf.distribute.experimental.coordinator.ClusterCoordinator(
             self._strategy))
 
+  def coordinator_for_async(
+      self,
+  ) -> tf.distribute.experimental.coordinator.ClusterCoordinator:
+    if not self._coordinator:
+      raise ValueError(
+          "Coordinator uninitialized for async run. Call init_async() first."
+      )
+    return self._coordinator
+
   def join(self):
     """Join all async steps. Only useful in aysnc training."""
     if getattr(self, "_is_async", False):
-      self._coordinator.join()
+      self.coordinator_for_async().join()
 
   def create_train_loop_fn(self):
     """Creates a eval loop from the given step function and options."""
     train_loop_fn = super().create_train_loop_fn()
     if getattr(self, "_is_async", False):
 
       def _async_loop_fn(iterator, num_steps):
-        self._coordinator.schedule(train_loop_fn, args=(iterator, num_steps))
+        self.coordinator_for_async().schedule(
+            train_loop_fn, args=(iterator, num_steps)
+        )
 
       return _async_loop_fn
     else:
@@ -76,7 +87,9 @@ def create_eval_loop_fn(self, has_state: bool):
       def _async_loop_fn(iterator, num_steps, state=None, reduce_fn=None):
         assert state is None
         assert reduce_fn is None
-        self._coordinator.schedule(eval_loop_fn, args=(iterator, num_steps))
+        self.coordinator_for_async().schedule(
+            eval_loop_fn, args=(iterator, num_steps)
+        )
 
       return _async_loop_fn
     else:
@@ -102,7 +115,9 @@ def distribute_dataset(self, dataset_or_fn, *args, **kwargs):
           *args, **kwargs)
       per_worker_dataset_fn = tf.function(per_worker_dataset_fn)
 
-      return self._coordinator.create_per_worker_dataset(per_worker_dataset_fn)
+      return self.coordinator_for_async().create_per_worker_dataset(
+          per_worker_dataset_fn
+      )
     else:
       return orbit.utils.make_distributed_dataset(self._strategy, dataset_or_fn,
                                                   *args, **kwargs)
@@ -352,7 +367,10 @@ def next_train_inputs(self, iterator):
     This method provides a way to control how to fetch the next model input, and
     what data to send to the model.
 
-    This function runs in eager mode.
+    Note: This function runs on the host side when accelerators are used.
+
+    Note: Depending on the training setup this may or may not run in eager mode.
+    In most cases it will be run in graph mode.
 
     Args:
       iterator: Dataset iterator to generate the next inputs from.
@@ -399,7 +417,10 @@ def next_eval_inputs(self, iterator):
     processed later in `aggregate_logs`. This is useful for sending extra logs
     downstream that are not compatible with the accelerators.
 
-    This function runs in eager mode.
+    Note: This function runs on the host side when accelerators are used.
+
+    Note: Depending on the training setup this may or may not run in eager mode.
+    In most cases it will be run in graph mode.
 
     Args:
       iterator: Dataset iterator to generate the next inputs from.
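The recurring edit in this file replaces raw self._coordinator accesses with coordinator_for_async(), which fails fast with an actionable error when init_async() was never called. A minimal sketch of the guarded-accessor pattern, stripped of the TensorFlow specifics (the AsyncRunner class below is a toy stand-in, not the Model Garden trainer):

class AsyncRunner:
  """Toy stand-in for the trainer, to show the guarded accessor only."""

  def __init__(self):
    self._coordinator = None  # populated later by init_async()

  def init_async(self):
    self._coordinator = object()  # stand-in for a ClusterCoordinator

  def coordinator_for_async(self):
    if not self._coordinator:
      raise ValueError(
          "Coordinator uninitialized for async run. Call init_async() first.")
    return self._coordinator


runner = AsyncRunner()
try:
  runner.coordinator_for_async()  # raises: coordinator not yet initialized
except ValueError as e:
  print(e)
runner.init_async()
assert runner.coordinator_for_async() is not None  # now safe to schedule work

Compared with letting self._coordinator.join() fail on None, the accessor turns a cryptic AttributeError into a message that names the missing call.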

official/core/config_definitions.py

Lines changed: 10 additions & 7 deletions
@@ -214,6 +214,7 @@ class TrainerConfig(base_config.Config):
     train_tf_while_loop: whether or not to use tf while loop.
     train_tf_function: whether or not to use tf_function for training loop.
     eval_tf_function: whether or not to use tf_function for eval.
+    eval_tf_while_loop: whether or not to use tf while loop for eval.
     allow_tpu_summary: Whether to allow summary happen inside the XLA program
       runs on TPU through automatic outside compilation.
     steps_per_loop: number of steps per loop to report training metrics. This
@@ -244,7 +245,9 @@
     preemption_on_demand_checkpoint: whether or not to save on-demand
       checkpoints after a preemption.
   """
-  optimizer_config: OptimizationConfig = OptimizationConfig()
+  optimizer_config: OptimizationConfig = dataclasses.field(
+      default_factory=OptimizationConfig
+  )
   # Orbit settings.
   train_tf_while_loop: bool = True
   train_tf_function: bool = True
@@ -276,16 +279,16 @@
   recovery_max_trials: int = 0
   validation_summary_subdir: str = "validation"
   # Preemption on-demand checkpoint.
-  preemption_on_demand_checkpoint: bool = True
+  preemption_on_demand_checkpoint: bool = True  # copybara-replace
 
 
 @dataclasses.dataclass
 class TaskConfig(base_config.Config):
   """Config passed to task."""
   init_checkpoint: str = ""
   model: Optional[base_config.Config] = None
-  train_data: DataConfig = DataConfig()
-  validation_data: DataConfig = DataConfig()
+  train_data: DataConfig = dataclasses.field(default_factory=DataConfig)
+  validation_data: DataConfig = dataclasses.field(default_factory=DataConfig)
   name: Optional[str] = None
   # Configs for differential privacy
   # These configs are only effective if you use create_optimizer in
@@ -301,6 +304,6 @@ class TaskConfig(base_config.Config):
 @dataclasses.dataclass
 class ExperimentConfig(base_config.Config):
   """Top-level configuration."""
-  task: TaskConfig = TaskConfig()
-  trainer: TrainerConfig = TrainerConfig()
-  runtime: RuntimeConfig = RuntimeConfig()
+  task: TaskConfig = dataclasses.field(default_factory=TaskConfig)
+  trainer: TrainerConfig = dataclasses.field(default_factory=TrainerConfig)
+  runtime: RuntimeConfig = dataclasses.field(default_factory=RuntimeConfig)
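The default_factory migration matters for two reasons: a class-level default instance is shared by every config object, and Python 3.11 rejects unhashable instance defaults in dataclasses outright. A small self-contained illustration of both points (the DataConfig/TaskConfig names mirror the real ones, but the bodies are reduced to toys):

import dataclasses


@dataclasses.dataclass
class DataConfig:
  input_path: str = ""


@dataclasses.dataclass
class TaskConfig:
  # With `train_data: DataConfig = DataConfig()`, Python 3.11+ raises a
  # ValueError along the lines of "mutable default ... is not allowed: use
  # default_factory"; older versions silently share one DataConfig across
  # every TaskConfig instance.
  train_data: DataConfig = dataclasses.field(default_factory=DataConfig)


a, b = TaskConfig(), TaskConfig()
assert a.train_data is not b.train_data  # each instance gets a fresh default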

official/core/savedmodel_checkpoint_manager.py

Lines changed: 16 additions & 3 deletions
@@ -72,13 +72,19 @@ def save(self,
 
     # Save the models for the checkpoint that just got written.
     saved_modules_directory = make_saved_modules_directory_name(checkpoint_path)
+    # Atomic export of SavedModel. Write into a temporary directory and then
+    # rename it as the final directory after finishing the writing.
+    # This can avoid trying to read an unfinished savedmodel.
+    saved_modules_directory_tmp = saved_modules_directory + '_temp'
     for model_name, model in self._modules_to_export.items():
       signatures = getattr(model, 'saved_model_signatures', None)
       if signatures is not None:
         tf.saved_model.save(
             obj=model,
-            export_dir=os.path.join(saved_modules_directory, model_name),
+            export_dir=os.path.join(saved_modules_directory_tmp, model_name),
             signatures=signatures)
+    if tf.io.gfile.exists(saved_modules_directory_tmp):
+      tf.io.gfile.rename(saved_modules_directory_tmp, saved_modules_directory)
 
     saved_modules_directories_to_keep = [
         make_saved_modules_directory_name(ckpt) for ckpt in self.checkpoints
@@ -105,7 +111,14 @@ def get_existing_savedmodels(self) -> List[str]:
     """
     saved_modules_glob = make_saved_modules_directory_name(
         self._checkpoint_prefix + '-*')
-    return tf.io.gfile.glob(saved_modules_glob)
+    savedmodels = tf.io.gfile.glob(saved_modules_glob)
+    # Filter out temporary savedmodels.
+    savedmodels = [
+        savedmodel
+        for savedmodel in savedmodels
+        if savedmodel.endswith(SAVED_MODULES_PATH_SUFFIX)
+    ]
+    return savedmodels
 
   @property
   def latest_savedmodel(self) -> Union[str, None]:
@@ -214,7 +227,7 @@ def wait_for_new_savedmodel(
     logging.info('Waiting for new savedmodel at %s', self._directory)
     stop_time = time.time() + timeout if timeout is not None else None
 
-    last_savedmodel_number = 0
+    last_savedmodel_number = -1
     if last_savedmodel:
       last_savedmodel_number = self.get_savedmodel_number_from_path(
          last_savedmodel)
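The export is made atomic by writing into a _temp sibling and renaming it into place once the write completes, so a poller never observes a half-written SavedModel; get_existing_savedmodels additionally filters on the final suffix so in-flight _temp directories are ignored. A minimal sketch of the write-then-rename idea (atomic_export is a hypothetical helper, and the rename is only atomic when both paths sit on the same filesystem):

import os


def atomic_export(final_dir: str, write_fn) -> None:
  """Hypothetical helper: write into a temp dir, then publish via rename."""
  tmp_dir = final_dir + "_temp"
  os.makedirs(tmp_dir, exist_ok=True)
  write_fn(tmp_dir)  # slow export; readers never see this directory as final
  os.rename(tmp_dir, final_dir)  # one rename publishes the finished result

The last_savedmodel_number = -1 change is related bookkeeping: starting from -1 instead of 0 lets a first export numbered 0 be recognized as new when no previous savedmodel was supplied.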

official/core/train_lib.py

Lines changed: 5 additions & 2 deletions
@@ -137,7 +137,7 @@ def trainer(self) -> base_trainer.Trainer:
     return self._trainer
 
   @property
-  def checkpoint_manager(self) -> tf.train.CheckpointManager:
+  def checkpoint_manager(self) -> Optional[tf.train.CheckpointManager]:
     """The CheckpointManager that stores the checkpoints in a train job."""
     return self._checkpoint_manager
 
@@ -205,11 +205,14 @@ def _build_controller(
     """Builds a Orbit controler."""
     train_actions = [] if not train_actions else train_actions
     if trainer:
+      checkpoint_manager = self.checkpoint_manager
+      assert checkpoint_manager, 'Checkpoint manager required but undefined.'
       train_actions += actions.get_train_actions(
          self.params,
          trainer,
          self.model_dir,
-          checkpoint_manager=self.checkpoint_manager)
+          checkpoint_manager=checkpoint_manager,
+      )
 
     eval_actions = [] if not eval_actions else eval_actions
     if evaluator:
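Widening the property's return type to Optional is honest about jobs that never build a checkpoint manager, and the added assert both enforces the invariant when a trainer exists and narrows the type for static checkers. A minimal sketch of the narrowing (the names below are illustrative, not the Model Garden API):

from typing import Optional


class CheckpointManager:  # toy stand-in for tf.train.CheckpointManager
  def save(self) -> None:
    print("saved")


def build_controller(manager: Optional[CheckpointManager]) -> None:
  assert manager, "Checkpoint manager required but undefined."
  # After the assert, type checkers treat `manager` as CheckpointManager,
  # not Optional[CheckpointManager], so this call type-checks cleanly.
  manager.save()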

official/core/train_utils.py

Lines changed: 44 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """Training utils."""
+
 import dataclasses
 import inspect
 import json
@@ -22,10 +23,12 @@
 
 from absl import logging
 import gin
+import numpy as np
 import orbit
 import tensorflow as tf
 
 # pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.framework import ops
 from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
 # pylint: enable=g-direct-tensorflow-import
 from official.core import base_task
@@ -564,3 +567,44 @@ def try_count_flops(model: Union[tf.Module, tf.keras.Model],
           'reached before this run.', e)
       return None
   return None
+
+
+@ops.RegisterStatistics('Einsum', 'flops')
+def _einsum_flops(graph, node):
+  """Calculates the compute resources needed for Einsum."""
+  assert len(node.input) == 2
+  x_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
+      graph, node.input[0])
+  y_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
+      graph, node.input[1])
+  x_shape.assert_is_fully_defined()
+  y_shape.assert_is_fully_defined()
+  x_shape = x_shape.as_list()
+  y_shape = y_shape.as_list()
+  equation = str(node.attr['equation'])
+  equation = (
+      equation.replace('s:', '')
+      .replace('"', '')
+      .replace(' ', '')
+      .replace('\n', '')
+  )
+  x_str = equation.split(',')[0]
+  y_r_str = equation.split(',')[1]
+  y_str = y_r_str.split('->')[0]
+  r_str = y_r_str.split('->')[1]
+  shape_dic = {}
+  contracted = set()
+  for indice in x_str + y_str:
+    if indice in x_str:
+      indice_dim = x_shape[x_str.find(indice)]
+    elif indice in y_str:
+      indice_dim = y_shape[y_str.find(indice)]
+    else:
+      raise ValueError('indice {} not found in inputs'.format(indice))
+    shape_dic[indice] = indice_dim
+    if indice not in r_str:
+      contracted.add(indice)
+  madds = np.prod([shape_dic[indice] for indice in r_str]) * (
+      np.prod([shape_dic[indice] for indice in contracted]))
+  flops = 2 * madds
+  return ops.OpStats('flops', flops)
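The registered statistic counts, for a two-operand einsum, the product of the output dimensions times the product of the contracted dimensions as multiply-adds, then doubles that for FLOPs. For ij,jk->ik with shapes (m, n) and (n, p) this is 2*m*p*n. A hand check of the same arithmetic outside the graph machinery:

import numpy as np

x_str, y_str, r_str = 'ij', 'jk', 'ik'  # equation: ij,jk->ik
x_shape, y_shape = [8, 16], [16, 32]    # (m, n) and (n, p)

shape_dic, contracted = {}, set()
for indice in x_str + y_str:
  shape_dic[indice] = (
      x_shape[x_str.find(indice)] if indice in x_str
      else y_shape[y_str.find(indice)])
  if indice not in r_str:
    contracted.add(indice)             # 'j' is summed away

madds = np.prod([shape_dic[i] for i in r_str]) * np.prod(
    [shape_dic[i] for i in contracted])
assert 2 * madds == 2 * 8 * 32 * 16    # flops = 2 * m * p * n = 8192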

official/legacy/image_classification/configs/base_configs.py

Lines changed: 9 additions & 3 deletions
@@ -112,10 +112,16 @@ class TrainConfig(hyperparams.Config):
   resume_checkpoint: bool = None
   epochs: int = None
   steps: int = None
-  callbacks: CallbacksConfig = CallbacksConfig()
+  callbacks: CallbacksConfig = dataclasses.field(
+      default_factory=CallbacksConfig
+  )
   metrics: MetricsConfig = None
-  tensorboard: TensorBoardConfig = TensorBoardConfig()
-  time_history: TimeHistoryConfig = TimeHistoryConfig()
+  tensorboard: TensorBoardConfig = dataclasses.field(
+      default_factory=TensorBoardConfig
+  )
+  time_history: TimeHistoryConfig = dataclasses.field(
+      default_factory=TimeHistoryConfig
+  )
   set_epoch_loop: bool = False
 
 