Skip to content

Commit d14cf43

Browse files
LouYu2015tensorflower-gardener
authored andcommitted
No public description
PiperOrigin-RevId: 596723408
1 parent befedd2 commit d14cf43

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

official/projects/pix2seq/modeling/pix2seq_model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,9 +341,13 @@ def call(
341341
inputs: tf.Tensor,
342342
targets: Optional[tf.Tensor] = None,
343343
training: bool = None,
344-
use_teacher_forcing_for_eval: bool = False
344+
use_teacher_forcing_for_eval: bool = False,
345+
use_input_as_backbone_features=False,
345346
) -> List[Any]:
346-
features = self._backbone(inputs)[self._backbone_endpoint_name]
347+
if use_input_as_backbone_features:
348+
features = inputs
349+
else:
350+
features = self._backbone(inputs)[self._backbone_endpoint_name]
347351
mask = tf.ones_like(features)
348352
batch_size, h, w, num_channels = get_shape(features)
349353
features = tf.reshape(features, [batch_size, h * w, num_channels])

0 commit comments

Comments
 (0)