@@ -68,15 +68,15 @@ Examples:
6868class AddGradientFunction : public GradientFunction {
6969 public:
7070 Status Compute(AbstractContext* ctx,
71- absl::Span<AbstractTensorHandle* const> grad_inputs ,
72- absl::Span<AbstractTensorHandle*> grad_outputs ) override {
71+ absl::Span<AbstractTensorHandle* const> grad_outputs ,
72+ absl::Span<AbstractTensorHandle*> grad_inputs ) override {
7373 // Tape never calls a gradient function if there are no incoming grads.
74- DCHECK(grad_inputs [0]);
75- grad_outputs [0] = grad_inputs [0];
76- grad_outputs [1] = grad_inputs [0];
74+ DCHECK(grad_outputs [0]);
75+ grad_inputs [0] = grad_outputs [0];
76+ grad_inputs [1] = grad_outputs [0];
7777
78- grad_outputs [0]->Ref();
79- grad_outputs [1]->Ref();
78+ grad_inputs [0]->Ref();
79+ grad_inputs [1]->Ref();
8080 return Status::OK();
8181 }
8282 ~AddGradientFunction() override {}
@@ -88,15 +88,15 @@ class ExpGradientFunction : public GradientFunction {
8888 exp->Ref();
8989 }
9090 Status ExpGradientFunction::Compute(
91- AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs ,
92- absl::Span<AbstractTensorHandle*> grad_outputs ) {
91+ AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> grad_outputs ,
92+ absl::Span<AbstractTensorHandle*> grad_inputs ) {
9393 vector<AbstractTensorHandle*> conj_outputs(1);
9494 TF_RETURN_IF_ERROR(
9595 Conj(ctx, {exp_.get()}, absl::MakeSpan(conj_outputs), "Conj_Exp_Grad"));
9696 AbstractTensorHandlePtr conj_output_releaser(conj_outputs[0]);
9797
9898 TF_RETURN_IF_ERROR(
99- Mul(ctx, {conj_outputs[0], grad_inputs [0]}, grad_outputs , "Mul_Exp_Grad"));
99+ Mul(ctx, {conj_outputs[0], grad_outputs [0]}, grad_inputs , "Mul_Exp_Grad"));
100100 return Status::OK();
101101 }
102102
@@ -281,9 +281,9 @@ Example:
281281class CustomGradientFunction: public GradientFunction {
282282 public:
283283 Status ExpGradientFunction::Compute(
284- AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs ,
285- absl::Span<AbstractTensorHandle*> grad_outputs ) {
286- // Populate grad_outputs .
284+ AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> grad_outputs ,
285+ absl::Span<AbstractTensorHandle*> grad_inputs ) {
286+ // Populate grad_inputs .
287287 return Status::OK();
288288 }
289289
0 commit comments