@@ -148,9 +148,9 @@ XLATensorPtr SmoothL1LossBackward(const XLATensorPtr& grad_output,
148
148
XLATensorPtr grad_scale = tensor_methods::get_dimensions_size (
149
149
broadcasted_input,
150
150
XlaHelpers::GetAllDimensions (broadcasted_input->shape ()));
151
- return tensor_methods::mul (
152
- tensor_methods::div (elementwise_loss_backward, grad_scale),
153
- grad_output);
151
+ XLATensorPtr div_result = GetValueOrThrow (
152
+ tensor_methods::div (elementwise_loss_backward, grad_scale));
153
+ return tensor_methods::mul (div_result, grad_output);
154
154
}
155
155
default :
156
156
XLA_ERROR () << " Invalid reduction type: "
@@ -174,12 +174,12 @@ XLATensorPtr SoftplusBackward(const XLATensorPtr& grad_output,
174
174
XLATensorPtr z = tensor_methods::exp (scaled_input);
175
175
XLATensorPtr one_vec =
176
176
tensor_methods::full_like (z, 1 , z->GetDevice (), z->dtype ());
177
+ XLATensorPtr div = GetValueOrThrow (
178
+ tensor_methods::div (z, tensor_methods::add (z, one_vec, 1 )));
177
179
178
- return tensor_methods::where (
179
- tensor_methods::gt (scaled_input, threshold), grad_output,
180
- tensor_methods::mul (
181
- grad_output,
182
- tensor_methods::div (z, tensor_methods::add (z, one_vec, 1 ))));
180
+ return tensor_methods::where (tensor_methods::gt (scaled_input, threshold),
181
+ grad_output,
182
+ tensor_methods::mul (grad_output, div));
183
183
}
184
184
185
185
XLATensorPtr Select (const XLATensorPtr& input, int64_t dim, int64_t index) {
@@ -223,8 +223,8 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output,
223
223
XLATensorPtr grad_weights_scale =
224
224
tensor_methods::index (counts, {indices_rank1}, 0 );
225
225
// Scale the value of the gradient by the histogram.
226
- grad = tensor_methods::div (
227
- grad, tensor_methods::unsqueeze (grad_weights_scale, 1 ));
226
+ grad = GetValueOrThrow ( tensor_methods::div (
227
+ grad, tensor_methods::unsqueeze (grad_weights_scale, 1 ))) ;
228
228
}
229
229
// Don't accumulate gradients for indices which are equal with the given
230
230
// padding_idx.
0 commit comments