File tree Expand file tree Collapse file tree 1 file changed +6
-6
lines changed
tensorflow_model_optimization/python/examples/sparsity/keras/mnist Expand file tree Collapse file tree 1 file changed +6
-6
lines changed Original file line number Diff line number Diff line change 3939
4040# define model
4141input = tf .keras .layers .Input (shape = (28 , 28 ))
42- x = tf .keras .layers .MultiHeadAttention (num_heads = 2 , key_dim = 16 , name = " mha" )(
42+ x = tf .keras .layers .MultiHeadAttention (num_heads = 2 , key_dim = 16 , name = ' mha' )(
4343 query = input , value = input
4444)
4545x = tf .keras .layers .Flatten ()(x )
4848
4949# Train the digit classification model
5050model .compile (
51- optimizer = " adam" ,
51+ optimizer = ' adam' ,
5252 loss = tf .keras .losses .SparseCategoricalCrossentropy (from_logits = True ),
53- metrics = [" accuracy" ],
53+ metrics = [' accuracy' ],
5454)
5555
5656model .fit (
8080
8181# `prune_low_magnitude` requires a recompile.
8282model_for_pruning .compile (
83- optimizer = " adam" ,
83+ optimizer = ' adam' ,
8484 loss = tf .keras .losses .SparseCategoricalCrossentropy (from_logits = True ),
85- metrics = [" accuracy" ],
85+ metrics = [' accuracy' ],
8686)
8787
8888model_for_pruning .fit (
9696
9797score = model_for_pruning .evaluate (test_images , test_labels , verbose = 0 )
9898print ('Pruned model test loss:' , score [0 ])
99- print ('Pruned model test accuracy:' , score [1 ])
99+ print ('Pruned model test accuracy:' , score [1 ])
You can’t perform that action at this time.
0 commit comments