Skip to content

Commit 2adce8a

Browse files
Xharktensorflower-gardener
authored andcommitted
Change pruning_step type to int64 from int32.
PiperOrigin-RevId: 278802043
1 parent 4a96efe commit 2adce8a

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def build(self, input_shape):
201201
'pruning_step',
202202
shape=[],
203203
initializer=initializers.Constant(-1),
204-
dtype=dtypes.int32,
204+
dtype=dtypes.int64,
205205
trainable=False)
206206

207207
def training_step_fn():
@@ -222,7 +222,9 @@ def call(self, inputs, training=None):
222222
def add_update():
223223
with ops.control_dependencies([
224224
check_ops.assert_greater_equal(
225-
self.pruning_step, 0, message=self._PRUNE_CALLBACK_ERROR_MSG)]):
225+
self.pruning_step,
226+
np.int64(0),
227+
message=self._PRUNE_CALLBACK_ERROR_MSG)]):
226228
with ops.control_dependencies(
227229
[self.pruning_obj.conditional_mask_update()]):
228230
return control_flow_ops.no_op('update')

0 commit comments

Comments
 (0)