Skip to content
This repository was archived by the owner on Jul 10, 2025. It is now read-only.

Commit 52c71a7

Browse files
committed
NFC: Make arg names of GradientFunctions consistent across examples.
1 parent 2d11264 commit 52c71a7

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

rfcs/20201201-cpp-gradients.md

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -68,15 +68,15 @@ Examples:
6868
class AddGradientFunction : public GradientFunction {
6969
public:
7070
Status Compute(AbstractContext* ctx,
71-
absl::Span<AbstractTensorHandle* const> grad_inputs,
72-
absl::Span<AbstractTensorHandle*> grad_outputs) override {
71+
absl::Span<AbstractTensorHandle* const> grad_outputs,
72+
absl::Span<AbstractTensorHandle*> grad_inputs) override {
7373
// Tape never calls a gradient function if there are no incoming grads.
74-
DCHECK(grad_inputs[0]);
75-
grad_outputs[0] = grad_inputs[0];
76-
grad_outputs[1] = grad_inputs[0];
74+
DCHECK(grad_outputs[0]);
75+
grad_inputs[0] = grad_outputs[0];
76+
grad_inputs[1] = grad_outputs[0];
7777
78-
grad_outputs[0]->Ref();
79-
grad_outputs[1]->Ref();
78+
grad_inputs[0]->Ref();
79+
grad_inputs[1]->Ref();
8080
return Status::OK();
8181
}
8282
~AddGradientFunction() override {}
@@ -88,15 +88,15 @@ class ExpGradientFunction : public GradientFunction {
8888
exp->Ref();
8989
}
9090
Status ExpGradientFunction::Compute(
91-
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs,
92-
absl::Span<AbstractTensorHandle*> grad_outputs) {
91+
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> grad_outputs,
92+
absl::Span<AbstractTensorHandle*> grad_inputs) {
9393
vector<AbstractTensorHandle*> conj_outputs(1);
9494
TF_RETURN_IF_ERROR(
9595
Conj(ctx, {exp_.get()}, absl::MakeSpan(conj_outputs), "Conj_Exp_Grad"));
9696
AbstractTensorHandlePtr conj_output_releaser(conj_outputs[0]);
9797
9898
TF_RETURN_IF_ERROR(
99-
Mul(ctx, {conj_outputs[0], grad_inputs[0]}, grad_outputs, "Mul_Exp_Grad"));
99+
Mul(ctx, {conj_outputs[0], grad_outputs[0]}, grad_inputs, "Mul_Exp_Grad"));
100100
return Status::OK();
101101
}
102102
@@ -281,9 +281,9 @@ Example:
281281
class CustomGradientFunction: public GradientFunction {
282282
public:
283283
Status ExpGradientFunction::Compute(
284-
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs,
285-
absl::Span<AbstractTensorHandle*> grad_outputs) {
286-
// Populate grad_outputs.
284+
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> grad_outputs,
285+
absl::Span<AbstractTensorHandle*> grad_inputs) {
286+
// Populate grad_inputs.
287287
return Status::OK();
288288
}
289289

0 commit comments

Comments (0)