
Commit b8c62df

saberkun authored and tensorflower-gardener committed
Change Seq2SeqTransformer inputs to dictionary.
PiperOrigin-RevId: 338309524
1 parent 17155b4 · commit b8c62df

3 files changed: +14 -17 lines changed


official/nlp/modeling/models/seq2seq_transformer.py

Lines changed: 5 additions & 9 deletions
@@ -130,9 +130,9 @@ def call(self, inputs):
     """Calculate target logits or inferred target sequences.
 
     Args:
-      inputs: input tensor list of size 1 or 2.
-        First item, inputs: int tensor with shape [batch_size, input_length].
-        Second item (optional), targets: None or int tensor with shape
+      inputs: a dictionary of tensors.
+        Feature `inputs`: int tensor with shape [batch_size, input_length].
+        Feature `targets` (optional): None or int tensor with shape
         [batch_size, target_length].
 
     Returns:
@@ -147,12 +147,8 @@ def call(self, inputs):
     Raises:
       NotImplementedError: If try to use padded decode method on CPU/GPUs.
     """
-    inputs = inputs if isinstance(inputs, list) else [inputs]
-    if len(inputs) == 2:
-      sources, targets = inputs[0], inputs[1]
-    else:
-      # Decoding path.
-      sources, targets = inputs[0], None
+    sources = inputs["inputs"]
+    targets = inputs.get("targets", None)
     attention_bias = model_utils.get_padding_bias(sources)
     attention_bias = tf.cast(attention_bias, self._dtype)
     # Prepare inputs to the layer stack by adding positional encodings and
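
As a quick reference, here is a hedged sketch (not part of the commit) of the old and new calling conventions implied by the hunks above; `model` stands for an already-constructed Seq2SeqTransformer, and the shapes mirror the tests below.

# Sketch only: "model" is assumed to be a constructed Seq2SeqTransformer.
import numpy as np

batch_size, input_length, target_length = 4, 10, 8
sources = np.zeros((batch_size, input_length), dtype=np.int32)
targets = np.zeros((batch_size, target_length), dtype=np.int32)

# Before this commit: a positional list of one or two tensors.
#   model([sources])           # decoding path
#   model([sources, targets])  # training path, returns logits

# After this commit: a dictionary keyed by feature name.
decoded = model(dict(inputs=sources))                   # decoding path
logits = model(dict(inputs=sources, targets=targets))   # training path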

official/nlp/modeling/models/seq2seq_transformer_test.py

Lines changed: 6 additions & 6 deletions
@@ -82,15 +82,15 @@ def _step_fn(inputs):
       return tf.nest.map_structure(distribution.experimental_local_results,
                                    outputs)
 
-    fake_inputs = [np.zeros((batch_size, decode_max_length), dtype=np.int32)]
+    fake_inputs = dict(
+        inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32))
     local_outputs = step(fake_inputs)
     logging.info("local_outputs=%s", local_outputs)
     self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))
 
-    fake_inputs = [
-        np.zeros((batch_size, decode_max_length), dtype=np.int32),
-        np.zeros((batch_size, 8), dtype=np.int32)
-    ]
+    fake_inputs = dict(
+        inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32),
+        targets=np.zeros((batch_size, 8), dtype=np.int32))
     local_outputs = step(fake_inputs)
     logging.info("local_outputs=%s", local_outputs)
     self.assertEqual(local_outputs[0].shape, (4, 8, 100))
@@ -108,7 +108,7 @@ def __init__(self, model):
 
      @tf.function
      def serve(self, inputs):
-        return self.model.call([inputs])
+        return self.model.call(dict(inputs=inputs))
 
    save_module = SaveModule(model)
    if padded_decode:
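
A hedged sketch of exporting the dict-based serving path shown in the second hunk; the SaveModule class and serve method follow the test, while the TensorSpec signature and export directory below are illustrative assumptions rather than part of the change.

import tensorflow as tf

class SaveModule(tf.Module):
  """Wraps a Seq2SeqTransformer for export, passing a feature dictionary."""

  def __init__(self, model):
    super().__init__()
    self.model = model

  @tf.function
  def serve(self, inputs):
    # Only the "inputs" feature is provided, so the model takes the decoding path.
    return self.model.call(dict(inputs=inputs))

# Assuming `model` is a constructed Seq2SeqTransformer:
# save_module = SaveModule(model)
# signatures = dict(
#     serving_default=save_module.serve.get_concrete_function(
#         tf.TensorSpec(shape=[None, None], dtype=tf.int32, name="inputs")))
# tf.saved_model.save(save_module, "/tmp/seq2seq_export", signatures=signatures)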

official/nlp/transformer/transformer_forward_test.py

Lines changed: 3 additions & 2 deletions
@@ -70,7 +70,8 @@ def _create_model(params, is_train):
     inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
     targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
     internal_model = models.Seq2SeqTransformer(**model_kwargs)
-    logits = internal_model([inputs, targets], training=is_train)
+    logits = internal_model(
+        dict(inputs=inputs, targets=targets), training=is_train)
     vocab_size = params["vocab_size"]
     label_smoothing = params["label_smoothing"]
     if params["enable_metrics_in_training"]:
@@ -90,7 +91,7 @@ def _create_model(params, is_train):
                                    dtype="int64",
                                    name="inputs")
     internal_model = models.Seq2SeqTransformer(**model_kwargs)
-    ret = internal_model([inputs], training=is_train)
+    ret = internal_model(dict(inputs=inputs), training=is_train)
     outputs, scores = ret["outputs"], ret["scores"]
     return tf.keras.Model(inputs, [outputs, scores])
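
For completeness, a hedged usage sketch of the inference-mode wrapper built in the second hunk above; the shapes are illustrative, and `eval_model` is assumed to be the Keras model returned by `_create_model(params, is_train=False)` in the surrounding test.

# Sketch only: "eval_model" is assumed to be the Keras model returned by
# _create_model(params, is_train=False) shown above.
import numpy as np

sources = np.zeros((4, 10), dtype=np.int64)    # illustrative batch of token ids
outputs, scores = eval_model.predict(sources)
# outputs: decoded token ids; scores: the corresponding beam search scores.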
