Skip to content

Commit 34934d3

Browse files
jburnimtensorflower-gardener
authored andcommitted
Fix a bug with sparse updates in VariationalSGD optimizer.
PiperOrigin-RevId: 388353133
1 parent c26e0ed commit 34934d3

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

tensorflow_probability/python/optimizer/variational_sgd.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,9 @@ def _resource_apply_sparse(self, grad, var, indices):
251251
self._burnin_max_learning_rate, self._max_learning_rate)
252252

253253
learn_rate = tf.clip_by_value(
254-
self._get_coordinatewise_learning_rate(grad, var), 0.,
255-
tf.cast(max_learning_rate, var.dtype))
254+
self._get_coordinatewise_learning_rate(
255+
tf.IndexedSlices(grad, indices), var),
256+
0., tf.cast(max_learning_rate, var.dtype))
256257
delta = grad * learn_rate
257258

258259
return self._resource_scatter_add(var, indices, -delta)

tensorflow_probability/python/optimizer/variational_sgd_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,6 @@ def testWithGlobalStep(self):
309309
init_step_value + 1, self.evaluate(sgd_optimizer.iterations))
310310

311311
def testSparseBasic(self):
312-
self.skipTest('b/195306553')
313312
for dtype in [tf.half, tf.float32, tf.float64]:
314313
with self.cached_session():
315314
var0 = tf.Variable([[1.1], [2.1]], dtype=dtype)

0 commit comments

Comments
 (0)