Skip to content

Commit 3d92aa7

Browse files
EmbeddingDenseBackward: Remove padding_idx cast to double (#9406)
1 parent 9499e6f commit 3d92aa7

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

test/test_operations.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,29 @@ def test_rrelu_module(self):
752752
xla_output.sum().backward()
753753
self.assertEqual(a.grad, xla_a.grad.cpu())
754754

755+
def test_embedding_module(self):
    # nn.Embedding forward and backward must produce the same results on
    # the XLA device as on CPU (including gradient accumulation into the
    # embedding weight).
    vocab_size, dim = 16, 4
    device = torch_xla.device()

    # Same integer token indices on both devices.
    tokens = torch.randint(0, vocab_size, (2, 3), dtype=torch.long)
    xla_tokens = tokens.to(device)

    cpu_embed = nn.Embedding(vocab_size, dim)
    xla_embed = nn.Embedding(vocab_size, dim).to(device)
    # Copy the CPU weights so both modules start from identical parameters.
    xla_embed.weight.data.copy_(cpu_embed.weight.data)

    # Forward pass must match.
    cpu_out = cpu_embed(tokens)
    xla_out = xla_embed(xla_tokens)
    self.assertEqual(cpu_out, xla_out.cpu())

    # Backward pass: weight gradients must match as well.
    cpu_out.sum().backward()
    xla_out.sum().backward()
    self.assertEqual(cpu_embed.weight.grad, xla_embed.weight.grad.cpu())
777+
755778
def test_max_broadcast(self):
756779
xla_device = torch_xla.device()
757780
a = torch.rand(3, 1, 2)

torch_xla/csrc/tensor_ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output,
229229
// Don't accumulate gradients for indices which are equal with the given
230230
// padding_idx.
231231
XLATensorPtr skip_padding = tensor_methods::unsqueeze(
232-
tensor_methods::ne(indices_rank1, static_cast<double>(padding_idx)), 1);
232+
tensor_methods::ne(indices_rank1, padding_idx), 1);
233233
skip_padding = tensor_methods::expand(
234234
skip_padding,
235235
torch::lazy::ToVector<int64_t>(grad->shape().get().dimensions()));

0 commit comments

Comments
 (0)