Skip to content

Commit 2a8a625

Browse files
yeqingli and tensorflower-gardener
authored and committed
Internal change
PiperOrigin-RevId: 398095616
1 parent ce59424 commit 2a8a625

File tree

4 files changed

+36
-5
lines changed

4 files changed

+36
-5
lines changed

official/vision/beta/configs/video_classification.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
# Lint as: python3
1616
"""Video classification configuration definition."""
17-
from typing import Optional, Tuple
1817
import dataclasses
18+
from typing import Optional, Tuple
1919
from official.core import config_definitions as cfg
2020
from official.core import exp_factory
2121
from official.modeling import hyperparams
@@ -121,6 +121,7 @@ class VideoClassificationModel(hyperparams.Config):
121121
use_sync_bn=False)
122122
dropout_rate: float = 0.2
123123
aggregate_endpoints: bool = False
124+
require_endpoints: Optional[Tuple[str, ...]] = None
124125

125126

126127
@dataclasses.dataclass
@@ -146,6 +147,10 @@ class VideoClassificationTask(cfg.TaskConfig):
146147
metrics: Metrics = Metrics()
147148
init_checkpoint: Optional[str] = None
148149
init_checkpoint_modules: str = 'all' # all or backbone
150+
# Spatial Partitioning fields. See go/tf2-spatial-partition-api-examples
151+
# for explanation of the technique.
152+
train_input_partition_dims: Optional[Tuple[int, ...]] = None
153+
eval_input_partition_dims: Optional[Tuple[int, ...]] = None
149154

150155

151156
def add_trainer(experiment: cfg.ExperimentConfig,

official/vision/beta/modeling/factory_3d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,5 +98,6 @@ def build_video_classification_model(
9898
input_specs=input_specs_dict,
9999
dropout_rate=model_config.dropout_rate,
100100
aggregate_endpoints=model_config.aggregate_endpoints,
101-
kernel_regularizer=l2_regularizer)
101+
kernel_regularizer=l2_regularizer,
102+
require_endpoints=model_config.require_endpoints)
102103
return model

official/vision/beta/modeling/video_classification_model.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
# limitations under the License.
1414

1515
"""Build video classification models."""
16-
from typing import Any, Mapping, Optional, Union
16+
from typing import Any, Mapping, Optional, Union, List, Text
17+
1718
import tensorflow as tf
1819

1920
layers = tf.keras.layers
@@ -33,6 +34,7 @@ def __init__(
3334
kernel_initializer: str = 'random_uniform',
3435
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
3536
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
37+
require_endpoints: Optional[List[Text]] = None,
3638
**kwargs):
3739
"""Video Classification initialization function.
3840
@@ -48,6 +50,8 @@ def __init__(
4850
None.
4951
bias_regularizer: tf.keras.regularizers.Regularizer object. Default to
5052
None.
53+
require_endpoints: the required endpoints for prediction. If None or
54+
empty, then only uses the final endpoint.
5155
**kwargs: keyword arguments to be passed.
5256
"""
5357
if not input_specs:
@@ -64,6 +68,7 @@ def __init__(
6468
'kernel_initializer': kernel_initializer,
6569
'kernel_regularizer': kernel_regularizer,
6670
'bias_regularizer': bias_regularizer,
71+
'require_endpoints': require_endpoints,
6772
}
6873
self._input_specs = input_specs
6974
self._kernel_regularizer = kernel_regularizer
@@ -82,8 +87,18 @@ def __init__(
8287
pooled_feats.append(x_pool)
8388
x = tf.concat(pooled_feats, axis=1)
8489
else:
85-
x = endpoints[max(endpoints.keys())]
86-
x = tf.keras.layers.GlobalAveragePooling3D()(x)
90+
if not require_endpoints:
91+
# Uses the last endpoint for prediction.
92+
x = endpoints[max(endpoints.keys())]
93+
x = tf.keras.layers.GlobalAveragePooling3D()(x)
94+
else:
95+
# Concats all the required endpoints for prediction.
96+
outputs = []
97+
for name in require_endpoints:
98+
x = endpoints[name]
99+
x = tf.keras.layers.GlobalAveragePooling3D()(x)
100+
outputs.append(x)
101+
x = tf.concat(outputs, axis=1)
87102

88103
x = tf.keras.layers.Dropout(dropout_rate)(x)
89104
x = tf.keras.layers.Dense(

official/vision/beta/tasks/video_classification.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,11 @@ def train_step(self,
255255
A dictionary of logs.
256256
"""
257257
features, labels = inputs
258+
input_partition_dims = self.task_config.train_input_partition_dims
259+
if input_partition_dims:
260+
strategy = tf.distribute.get_strategy()
261+
features['image'] = strategy.experimental_split_to_logical_devices(
262+
features['image'], input_partition_dims)
258263

259264
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
260265
with tf.GradientTape() as tape:
@@ -314,6 +319,11 @@ def validation_step(self,
314319
A dictionary of logs.
315320
"""
316321
features, labels = inputs
322+
input_partition_dims = self.task_config.eval_input_partition_dims
323+
if input_partition_dims:
324+
strategy = tf.distribute.get_strategy()
325+
features['image'] = strategy.experimental_split_to_logical_devices(
326+
features['image'], input_partition_dims)
317327

318328
outputs = self.inference_step(features, model)
319329
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

0 commit comments

Comments (0)