Skip to content

Commit 516c667

Browse files
committed
Addressed reviewer's comments.
Change-Id: I4e0c012f75813a18891072839da79c73c266ecad
1 parent fa58430 commit 516c667

File tree

1 file changed

+6
-6
lines changed
  • tensorflow_model_optimization/python/examples/sparsity/keras/mnist

1 file changed

+6
-6
lines changed

tensorflow_model_optimization/python/examples/sparsity/keras/mnist/mnist_mha.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
# define model
4141
input = tf.keras.layers.Input(shape=(28, 28))
42-
x = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16, name="mha")(
42+
x = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16, name='mha')(
4343
query=input, value=input
4444
)
4545
x = tf.keras.layers.Flatten()(x)
@@ -48,9 +48,9 @@
4848

4949
# Train the digit classification model
5050
model.compile(
51-
optimizer="adam",
51+
optimizer='adam',
5252
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
53-
metrics=["accuracy"],
53+
metrics=['accuracy'],
5454
)
5555

5656
model.fit(
@@ -80,9 +80,9 @@
8080

8181
# `prune_low_magnitude` requires a recompile.
8282
model_for_pruning.compile(
83-
optimizer="adam",
83+
optimizer='adam',
8484
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
85-
metrics=["accuracy"],
85+
metrics=['accuracy'],
8686
)
8787

8888
model_for_pruning.fit(
@@ -96,4 +96,4 @@
9696

9797
score = model_for_pruning.evaluate(test_images, test_labels, verbose=0)
9898
print('Pruned model test loss:', score[0])
99-
print('Pruned model test accuracy:', score[1])
99+
print('Pruned model test accuracy:', score[1])

0 commit comments

Comments
 (0)