`rfcs/20201201-cpp-gradients.md`
## Design Overview
The gradients infrastructure will be built on top of the abstract interfaces for op execution, which provide a backend-agnostic way of tracing and executing ops. We provide APIs for authoring `GradientFunction`s and registering them into a `GradientRegistry` for name-based lookup. We also provide a gradient `Tape` API that is close to Python’s `tf.GradientTape` and shares most of its implementation with the existing tape.
## Detailed Design
#### GradientFunction
An op’s gradient is defined by subclassing `GradientFunction`.
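The class definition itself is elided from this diff; the following is a minimal sketch consistent with the description below, where the `Status` return type, the use of `absl::Span`, and the virtual destructor are assumptions about the exact header:

```
// Sketch only: signatures are inferred from the surrounding prose and may
// differ from the actual header.
class GradientFunction {
 public:
  // Receives gradients wrt the op's outputs in `grad_outputs` and must
  // populate gradients wrt the op's inputs in `grad_inputs`. `ctx` is used to
  // create any ops needed for the computation (eagerly or traced).
  virtual Status Compute(AbstractContext* ctx,
                         absl::Span<AbstractTensorHandle* const> grad_outputs,
                         absl::Span<AbstractTensorHandle*> grad_inputs) = 0;
  virtual ~GradientFunction() {}
};
```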
`GradientFunction::Compute` receives gradients wrt the op’s outputs in `grad_outputs` and needs to populate gradients wrt the op’s inputs in `grad_inputs`. This is the same signature we use for authoring Python gradients, with the addition of an `AbstractContext`, which provides an API for creating operations (eagerly or traced). In Python this context is stored in a global variable and is implicitly captured. For the C++ API we chose to pass this context explicitly.
The reason `GradientFunction` is a class and not a callable is so that each op’s gradient function can keep any state it needs from the forward pass for the gradient computation (see `ExpGradientFunction` below for an example).
Examples:
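The full examples are elided from this diff. As a hedged stand-in, here is what an `ExpGradientFunction` could look like under the interface above; the `ops::Mul` helper, the constructor argument, and the exact ref-counting calls are illustrative assumptions:

```
// Sketch only: a hypothetical gradient for Exp. Since d/dx exp(x) = exp(x),
// the gradient function keeps the forward op's output and multiplies it with
// the incoming gradient.
class ExpGradientFunction : public GradientFunction {
 public:
  explicit ExpGradientFunction(AbstractTensorHandle* exp_output)
      : exp_output_(exp_output) {
    exp_output_->Ref();  // Keep the forward output alive for the backward pass.
  }
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    // grad_input = grad_output * exp(x); `ops::Mul` stands in for a C++ gen op.
    return ops::Mul(ctx, grad_outputs[0], exp_output_, &grad_inputs[0], "Mul");
  }
  ~ExpGradientFunction() override { exp_output_->Unref(); }

 private:
  AbstractTensorHandle* exp_output_;
};
```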
#### GradientRegistry
We provide a registry to store the mapping from op type to factory functions that return the `GradientFunction` for an op’s instance. The factory function takes as input the `ForwardOperation`, which contains metadata from the forward operation, and returns a `GradientFunction`. This allows gradient function authors to control which inputs/outputs of the forward op to keep around by increasing the ref-count on `AbstractTensorHandle`.
Additionally, we provide a utility function `RegisterNotDifferentiable` to mark an op as non-differentiable. This can be used to implement `tf.no_gradient`. We also provide a `NotDifferentiableGradientFunction` which returns `nullptr` output gradients. This can be used to implement `tf.stop_gradient`.
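A hedged sketch of how registration might look, assuming a factory signature of `GradientFunction*(const ForwardOperation&)`, a `GradientRegistry::Register(op_name, factory)` method, and an `op.outputs` field; all of these are illustrative rather than the exact API:

```
// Sketch only: a factory that captures what it needs from the forward op and
// returns the gradient function, plus name-based registration.
GradientFunction* ExpRegisterer(const ForwardOperation& op) {
  // Keep the forward output (exp(x)) for the backward pass.
  return new ExpGradientFunction(op.outputs[0]);
}

Status RegisterGradients(GradientRegistry* registry) {
  TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
  // Ops without meaningful gradients can be marked explicitly.
  TF_RETURN_IF_ERROR(RegisterNotDifferentiable(registry, "Shape"));
  return Status::OK();
}
```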
#### Tape

The API for the C++ `Tape` is very similar to Python’s `tf.GradientTape`. The implementation for this interface is almost entirely shared with the existing C++ tape in `c/eager/tape.h`.
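The code block with the proposed `Tape` declaration is elided from this diff. The sketch below is a rough approximation based on the `Watch` and `RecordOperation` methods and the gradient computation referenced elsewhere in this document; every signature here is an assumption:

```
// Sketch only: approximate shape of the proposed Tape API.
class Tape {
 public:
  explicit Tape(bool persistent);
  // Start tracking `t` so that gradients can later be requested wrt it.
  void Watch(const AbstractTensorHandle* t);
  // Record an operation (or a custom gradient, see below) on the tape.
  void RecordOperation(absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle* const> outputs,
                       GradientFunction* gradient_function,
                       const std::string& op_name = "");
  // Compute gradients of `targets` wrt `sources`, writing them into `result`.
  Status ComputeGradient(AbstractContext* ctx,
                         absl::Span<AbstractTensorHandle* const> targets,
                         absl::Span<AbstractTensorHandle* const> sources,
                         absl::Span<AbstractTensorHandle* const> output_gradients,
                         absl::Span<AbstractTensorHandle*> result);
};
```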
**Some details on memory management**
`AbstractTensorHandle` provides `Ref` and `Unref` methods which can be used to manage its lifecycle. Gradient functions and the tape follow a set of ref-counting guidelines for memory safety.
#### tf.custom\_gradient
A custom `GradientFunction` for a set of inputs/outputs can be registered using `Tape::RecordOperation`, just like a gradient function looked up from the gradient registry.
Example:
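The original `ExpWithCustomGrad` example is elided from this diff. The following hedged sketch shows the mechanism: run the forward op, then attach a custom `GradientFunction` to its inputs/outputs via `Tape::RecordOperation` (the argument order and the tape taking ownership of the gradient function are assumptions):

```
// Sketch only: attach a custom gradient to y = exp(x) instead of the
// registered Exp gradient.
class PassThroughGradientFunction : public GradientFunction {
 public:
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    // Custom gradient: simply forward the incoming gradient unchanged.
    grad_inputs[0] = grad_outputs[0];
    grad_inputs[0]->Ref();  // Caller owns the returned gradient.
    return Status::OK();
  }
};

Status ExpWithCustomGrad(AbstractContext* ctx, Tape* tape,
                         AbstractTensorHandle* x, AbstractTensorHandle** y) {
  TF_RETURN_IF_ERROR(ops::Exp(ctx, x, y, "Exp"));
  // Associate the custom gradient with this (inputs, outputs) pair.
  tape->RecordOperation(/*inputs=*/{x}, /*outputs=*/{*y},
                        new PassThroughGradientFunction(),
                        "ExpWithCustomGrad");
  return Status::OK();
}
```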
#### tf.recompute\_grad
`tf.recompute_grad` is an application of `tf.custom_gradient` where we do not record the forward pass on the tape, so that we are not holding on to forward-pass tensors in memory. (In `tf.custom_gradient` we allow recording the forward pass on the tape so that higher-order derivatives work for cases where the custom gradient function uses intermediate tensors from the forward pass.) This is implemented by executing the forward pass outside the tape (managed by a higher layer) and registering a gradient function that re-runs the forward pass and computes gradients. The same behavior can be achieved using this tape.
#### Nested tapes and higher-order derivatives
Higher order derivatives are computed by either using a persistent tape or by tracking the gradient computation on another (nested) tape.
#### Skipping gradients for certain op inputs (skip\_input\_indices)
A [small set](https://cs.opensource.google/search?q=f:py$%20skip_input_indices&sq=&ss=tensorflow%2Ftensorflow) of Python gradient functions have been optimized to not return gradients for inputs which are not tracked under the tape. This is beneficial in eager mode, where unneeded gradients cannot be pruned during execution. In the C++ tape, we support this by providing a `skip_input_indices` field on the `ForwardOperation`, which stores the list of input indices which are either not watched or have an untrainable dtype.
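As a hedged illustration, a hypothetical gradient function for an n-ary op such as AddN might consume this field as follows; the constructor plumbing and set-based bookkeeping are illustrative, not the exact API:

```
// Sketch only: skip computing gradients for inputs the tape does not need.
// A factory would typically read the indices from
// ForwardOperation::skip_input_indices before constructing this object.
class AddNGradientFunction : public GradientFunction {
 public:
  explicit AddNGradientFunction(std::vector<int64_t> skip_input_indices)
      : skip_(skip_input_indices.begin(), skip_input_indices.end()) {}
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    for (size_t i = 0; i < grad_inputs.size(); ++i) {
      if (skip_.count(i)) {
        grad_inputs[i] = nullptr;  // Not watched / untrainable: no gradient.
        continue;
      }
      // The gradient of a sum wrt each input is the incoming gradient itself.
      grad_inputs[i] = grad_outputs[0];
      grad_inputs[i]->Ref();
    }
    return Status::OK();
  }

 private:
  std::unordered_set<int64_t> skip_;
};
```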
#### Automatic variable tracking
In Python, if a variable is accessed inside a `tf.GradientTape`’s scope it is automatically tracked, i.e. `Tape::Watch` is called for the `DT_RESOURCE` tensor backing the variable on behalf of the user. For now we will leave this out as a higher-layer feature and require that variable handles are explicitly tracked by a higher layer. We can revisit this later if needed.
#### tf.function and functional control flow gradients [out of scope for now]
Eventually we plan to implement tf.function and functional control flow gradients.
#### IndexedSlices
The gradient function of a [gather](https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/python/ops/array_grad.py;l=582;drc=d724cdbce69862cbb80617dd6573baa83bd3e819) returns `IndexedSlices` for efficiency. We need to support `IndexedSlices` as part of the input and output gradients of a gradient function. Currently there is no good C++ representation for these. One possible representation would be to wrap the component tensors in an `IndexedSlicesTensorHandle` that subclasses `AbstractTensorHandle`. This way `IndexedSlices` would be transparent to the tape. The C++ gen ops can choose to handle `IndexedSlices` appropriately or simply densify them by calling a C++ equivalent of `convert_to_tensor`.
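The accompanying code block is elided from this diff; the sketch below illustrates the wrapper idea, with member names and the overridden method chosen for illustration (other `AbstractTensorHandle` virtuals are omitted):

```
// Sketch only: wrap the component tensors of an IndexedSlices value in a
// handle type the tape can pass around like any other tensor handle.
class IndexedSlicesTensorHandle : public AbstractTensorHandle {
 public:
  IndexedSlicesTensorHandle(AbstractTensorHandle* values,
                            AbstractTensorHandle* indices,
                            AbstractTensorHandle* dense_shape)
      : values_(values), indices_(indices), dense_shape_(dense_shape) {}

  // Report the dtype of the (logical) dense tensor.
  tensorflow::DataType DataType() const override { return values_->DataType(); }
  // (Other AbstractTensorHandle virtual methods would also need overriding;
  // they are omitted in this sketch.)

  AbstractTensorHandle* values() const { return values_; }
  AbstractTensorHandle* indices() const { return indices_; }
  AbstractTensorHandle* dense_shape() const { return dense_shape_; }

 private:
  AbstractTensorHandle* values_;
  AbstractTensorHandle* indices_;
  AbstractTensorHandle* dense_shape_;
};
```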
#### Framework
The framework is a fairly lightweight implementation of the existing `Tape` interface in `c/eager/tape.h`, which was already templated to support different C++ types for gradient functions and tensors. We have been making the necessary improvements to the base framework to support this project, e.g., moving [default zeros creation logic](https://cs.opensource.google/tensorflow/tensorflow/+/ee95d88c4eb92311a8c57a8f78378235e1909d08) from the tape to the respective gradient functions.
#### Gradient functions
We plan to implement gradient functions under `tensorflow/c/gradients`. As a proof of concept, we implemented an MLP for MNIST using an experimental Python binding (see `python/framework/experimental/tape.py`). For that we implemented gradient functions for MatMul, Add, ReLU and Softmax. We are currently working on implementing the gradient functions needed for ResNet50.
We further plan to publish a guide to invite contributions and set up a spreadsheet (or similar) for tracking completeness.
#### Python rollout
We plan to roll out C++ gradient functions incrementally. We will port the existing pybind C++ tape to use the new tape implementation. The `GradientFunction` for ops with registered C++ gradients will be called directly. For the others, we will simply register a `GradientFunction` that calls the Python gradient function.
## Acknowledgements
alextp@ motivated the design and provided an initial prototype for this project. amturati@ implemented various gradient functions to get an MLP training on MNIST. vnvo2409@ has been working on framework improvements and on implementing further C++ gradient functions.