19
19
from __future__ import print_function
20
20
# import g3
21
21
import numpy as np
22
- from tensorflow .python .eager import context
23
- from tensorflow .python .framework import constant_op
24
- from tensorflow .python .framework import dtypes
25
- from tensorflow .python .framework import ops
22
+
23
+ import tensorflow .compat .v1 as tf
24
+ # TODO(tf-mot): when migrating to 2.0, K.get_session() no longer exists.
25
+ K = tf .keras .backend
26
+ dtypes = tf .dtypes
27
+ test = tf .test
28
+
26
29
from tensorflow .python .framework import test_util as tf_test_util
27
- from tensorflow .python .keras import backend as K
28
- from tensorflow .python .ops import math_ops
29
- from tensorflow .python .ops import partitioned_variables
30
- from tensorflow .python .ops import state_ops
31
- from tensorflow .python .ops import variable_scope
32
- from tensorflow .python .ops import variables
33
- from tensorflow .python .platform import test
34
30
from tensorflow_model_optimization .python .core .sparsity .keras import pruning_impl
35
31
from tensorflow_model_optimization .python .core .sparsity .keras import pruning_schedule
36
32
from tensorflow_model_optimization .python .core .sparsity .keras import pruning_utils
@@ -66,7 +62,7 @@ def testUpdateSingleMask(self):
66
62
mask_before_pruning = K .get_value (mask )
67
63
self .assertAllEqual (np .count_nonzero (mask_before_pruning ), 100 )
68
64
69
- if context .executing_eagerly ():
65
+ if tf .executing_eagerly ():
70
66
p .conditional_mask_update ()
71
67
else :
72
68
K .get_session ().run (p .conditional_mask_update ())
@@ -121,7 +117,7 @@ def testBlockMaskingAvg(self):
121
117
def testBlockMaskingMax (self ):
122
118
block_size = (2 , 2 )
123
119
block_pooling_type = "MAX"
124
- weight = constant_op .constant ([[0.1 , 0.0 , 0.2 , 0.0 ], [0.0 , - 0.1 , 0.0 , - 0.2 ],
120
+ weight = tf .constant ([[0.1 , 0.0 , 0.2 , 0.0 ], [0.0 , - 0.1 , 0.0 , - 0.2 ],
125
121
[0.3 , 0.0 , 0.4 , 0.0 ], [0.0 , - 0.3 , 0.0 ,
126
122
- 0.4 ]])
127
123
expected_mask = [[0.0 , 0.0 , 0.0 , 0.0 ], [0.0 , 0.0 , 0.0 , 0.0 ],
@@ -133,7 +129,7 @@ def testBlockMaskingWithHigherDimensionsRaisesError(self):
133
129
block_size = (2 , 2 )
134
130
block_pooling_type = "AVG"
135
131
# Weights as in testBlockMasking, but with one extra dimension.
136
- weight = constant_op .constant ([[[0.1 , 0.1 , 0.2 , 0.2 ], [0.1 , 0.1 , 0.2 , 0.2 ],
132
+ weight = tf .constant ([[[0.1 , 0.1 , 0.2 , 0.2 ], [0.1 , 0.1 , 0.2 , 0.2 ],
137
133
[0.3 , 0.3 , 0.4 , 0.4 ], [0.3 , 0.3 , 0.4 ,
138
134
0.4 ]]])
139
135
expected_mask = [[[0.0 , 0.0 , 0.0 , 0.0 ], [0.0 , 0.0 , 0.0 , 0.0 ],
@@ -149,9 +145,9 @@ def testConditionalMaskUpdate(self):
149
145
threshold = K .zeros ([])
150
146
151
147
def linear_sparsity(step):
  """Stub pruning schedule: always signals "prune now".

  Returns a (should_prune, sparsity) pair where sparsity is looked up
  per training step from a fixed ramp of target sparsities.
  """
  schedule = tf.convert_to_tensor(
      [0.0, 0.1, 0.1, 0.3, 0.3, 0.5, 0.5, 0.5, 0.5, 0.5])
  should_prune = tf.convert_to_tensor(True)
  return should_prune, schedule[step]
155
151
156
152
# Set up pruning
157
153
p = pruning_impl .Pruning (
@@ -163,14 +159,14 @@ def linear_sparsity(step):
163
159
164
160
non_zero_count = []
165
161
for _ in range (10 ):
166
- if context .executing_eagerly ():
162
+ if tf .executing_eagerly ():
167
163
p .conditional_mask_update ()
168
164
p .weight_mask_op ()
169
- state_ops .assign_add (self .global_step , 1 )
165
+ tf .assign_add (self .global_step , 1 )
170
166
else :
171
167
K .get_session ().run (p .conditional_mask_update ())
172
168
K .get_session ().run (p .weight_mask_op ())
173
- K .get_session ().run (state_ops .assign_add (self .global_step , 1 ))
169
+ K .get_session ().run (tf .assign_add (self .global_step , 1 ))
174
170
175
171
non_zero_count .append (np .count_nonzero (K .get_value (weight )))
176
172
@@ -180,4 +176,5 @@ def linear_sparsity(step):
180
176
181
177
182
178
if __name__ == "__main__":
  # The tests still depend on graph-mode/session APIs (e.g. K.get_session()
  # above — see the TODO near the imports), so TF2 behavior is disabled
  # before running the test suite.
  tf.disable_v2_behavior()
  test.main()
0 commit comments