`GradientFunction::Compute` receives gradients with respect to the op’s outputs in `grad_outputs` and needs to populate gradients with respect to the op’s inputs in `grad_inputs`. This is the same signature we use for authoring Python gradients, with the addition of an `AbstractContext`, which provides an API for creating operations (eagerly or traced). In Python this context is stored in a global variable and is implicitly captured; for the C++ API we chose to pass the context explicitly.
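
For concreteness, here is a minimal sketch of the interface described above, assuming `absl::Span` for the gradient lists (the exact declaration may differ):

```cpp
// Sketch of the GradientFunction interface as described in this section.
class GradientFunction {
 public:
  // Receives gradients w.r.t. the op's outputs in `grad_outputs` and must
  // populate gradients w.r.t. the op's inputs in `grad_inputs`. `ctx` is
  // used to create the ops of the gradient computation, eagerly or traced.
  virtual Status Compute(AbstractContext* ctx,
                         absl::Span<AbstractTensorHandle* const> grad_outputs,
                         absl::Span<AbstractTensorHandle*> grad_inputs) = 0;
  virtual ~GradientFunction() {}
};
```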
The reason `GradientFunction` is a class and not a callable is so that each op’s gradient function can hold whatever state from the forward pass it needs for the gradient computation (see `ExpGradientFunction` below for an example).
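
A sketch of such a stateful gradient function: since d(exp(x))/dx = exp(x), the gradient of `Exp` only needs to keep the forward output around. The `Mul` helper below is illustrative and stands in for whatever elementary-op wrapper the ops layer provides.

```cpp
// Sketch: the gradient of Exp keeps the forward output y = exp(x) alive,
// because grad_input = grad_output * exp(x) = grad_output * y.
class ExpGradientFunction : public GradientFunction {
 public:
  explicit ExpGradientFunction(AbstractTensorHandle* exp_output)
      : exp_output_(exp_output) {
    exp_output_->Ref();  // Keep the forward output alive for the backward pass.
  }
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    // grad_input = grad_output * y; `Mul` is an illustrative ops-layer helper.
    return Mul(ctx, grad_outputs[0], exp_output_, &grad_inputs[0]);
  }
  ~ExpGradientFunction() override { exp_output_->Unref(); }

 private:
  AbstractTensorHandle* exp_output_;
};
```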
We provide a registry that stores the mapping from op type to a factory function returning the `GradientFunction` for an instance of that op. The factory function takes as input the `ForwardOperation`, which contains metadata from the forward operation, and returns a `GradientFunction`. This lets gradient function authors control which inputs/outputs of the forward op to keep around, by incrementing the ref-count on the corresponding `AbstractTensorHandle`s.
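
Registration might then look like the following sketch; `GradientRegistry` and `Register` are illustrative names for the registry described above:

```cpp
// Factory: receives forward-pass metadata and decides what to keep alive.
// For Exp, the gradient only needs the forward output (see the sketch above).
GradientFunction* ExpRegisterer(const ForwardOperation& op) {
  return new ExpGradientFunction(op.outputs[0]);
}

// Map the op type to its factory, e.g. once at library initialization.
Status RegisterGradients(GradientRegistry* registry) {
  return registry->Register("Exp", ExpRegisterer);
}
```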
Additionally, we provide a utility function `RegisterNotDifferentiable` to mark an op as non-differentiable. This can be used to implement `tf.no_gradient`. We also provide a `NotDifferentiableGradientFunction` which returns `nullptr` output gradients. This can be used to implement `tf.stop_gradient`.
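
For example, a sketch assuming the utility takes the registry and the op name (the exact signature may differ, and `Shape` is chosen purely for illustration):

```cpp
// Shape produces a non-differentiable output; marking it makes the tape
// return nullptr gradients for its inputs instead of computing them.
Status RegisterNonDifferentiableOps(GradientRegistry* registry) {
  return RegisterNotDifferentiable(registry, "Shape");
}
```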
#### tf.recompute\_grad
`tf.recompute_grad` is an application of `tf.custom_gradient` where we do not record the forward pass on the tape so that we are not holding on to forward pass tensors in memory. (In `tf.custom_gradient` we allow recording the forward pass on the tape in order for higher-order derivatives to work for cases where the custom gradient function uses intermediate tensors from the forward pass.) This is implemented by executing the forward pass outside the tape (managed by a higher layer) and registering a gradient function that re-runs the forward pass and computes gradients. The same behavior can be achieved using this tape.
#### Nested tapes and higher-order derivatives
#### Skipping gradients for certain op inputs (skip\_input\_indices)
A [small set](https://cs.opensource.google/search?q=f:py$%20skip_input_indices&sq=&ss=tensorflow%2Ftensorflow) of Python gradient functions has been optimized to not return gradients for inputs that are not tracked by the tape. This is beneficial in eager mode, where unneeded gradients cannot be pruned during execution. In the C++ tape, we support this by providing a `skip_input_indices` field on the `ForwardOperation` which stores the list of input indices that are either not watched or have an untrainable dtype.
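
As a sketch, here is how a gradient function for `Mul` (`z = x * y`) might honor that field; the `absl::flat_hash_set` representation and the `Mul` helper are illustrative:

```cpp
// Sketch: gradient of Mul (z = x * y) that skips unneeded input gradients.
// d(x*y)/dx = y and d(x*y)/dy = x, so each skipped index saves one Mul op.
class MulGradientFunction : public GradientFunction {
 public:
  MulGradientFunction(AbstractTensorHandle* x, AbstractTensorHandle* y,
                      absl::flat_hash_set<int64_t> skip_input_indices)
      : x_(x), y_(y), skip_(std::move(skip_input_indices)) {
    x_->Ref();
    y_->Ref();
  }
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    grad_inputs[0] = nullptr;  // Default: no gradient returned.
    grad_inputs[1] = nullptr;
    if (!skip_.contains(0)) {  // Gradient w.r.t. x is needed.
      TF_RETURN_IF_ERROR(Mul(ctx, grad_outputs[0], y_, &grad_inputs[0]));
    }
    if (!skip_.contains(1)) {  // Gradient w.r.t. y is needed.
      TF_RETURN_IF_ERROR(Mul(ctx, grad_outputs[0], x_, &grad_inputs[1]));
    }
    return Status::OK();
  }
  ~MulGradientFunction() override {
    x_->Unref();
    y_->Unref();
  }

 private:
  AbstractTensorHandle* x_;
  AbstractTensorHandle* y_;
  absl::flat_hash_set<int64_t> skip_;
};
```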
#### Automatic variable tracking
In Python, if a variable is accessed inside a `tf.GradientTape`’s scope it is automatically tracked, i.e., `Tape::Watch` is called on the `DT_RESOURCE` tensor backing the variable on behalf of the user. For now we will leave this out as a higher-layer feature and require that variable handles be explicitly tracked by a higher layer. We can revisit this later if needed.
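
That is, a higher layer would do something like the following sketch; `Variable` and `handle()` are hypothetical, only `Tape::Watch` is named by this design:

```cpp
// Sketch: explicitly track the DT_RESOURCE tensor backing a variable.
// The tape does not do this automatically; a higher layer must call Watch.
void TrackVariable(Tape* tape, Variable* variable) {
  tape->Watch(variable->handle());  // `handle()` is a hypothetical accessor.
}
```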
#### tf.function and functional control flow gradients [out of scope for now]