Skip to content
This repository was archived by the owner on Jul 10, 2025. It is now read-only.

Commit 72cd8f5

Browse files
committed
Add details on TapeContext
1 parent 52c71a7 commit 72cd8f5

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

rfcs/20201201-cpp-gradients.md

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,64 @@ Status ExpGradModel(AbstractContext* ctx,
251251
return Status::OK();
252252
}
253253
```
254+
**TapeContext**
255+
256+
In the final C++ API, we don’t expect users to have to directly call `RecordOperation` for each op. We would provide an `AbstractContext` implementation for the tape which would trace ops on the tape and delegate execution to a backing context. A prototype for this is available [here](https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/c/experimental/gradients/tape/).
257+
258+
259+
```
260+
class TapeContext : public AbstractContext {
261+
public:
262+
explicit TapeContext(AbstractContext* parent_ctx, Tape*, const GradientRegistry&);
263+
// Skipping overridden methods.
264+
private:
265+
AbstractContext* parent_ctx_;
266+
Tape* tape_;
267+
const GradientRegistry& registry_;
268+
};
269+
270+
class TapeOperation : public AbstractOperation {
271+
public:
272+
explicit TapeOperation(AbstractOperation* parent_op, Tape*, const GradientRegistry&);
273+
// Skipping overridden methods.
274+
private:
275+
AbstractOperation* parent_op_;
276+
ForwardOperation forward_op_;
277+
Tape* tape_;
278+
const GradientRegistry& registry_;
279+
};
280+
```
281+
282+
283+
`TapeOperation` would populate the `ForwardOperation` object and record the operation on the tape in the call to `AbstractOperation::Execute`:
284+
285+
286+
```
287+
Status TapeOperation::AddInput(AbstractTensorHandle* input) {
288+
TF_RETURN_IF_ERROR(parent_op_->AddInput(input));
289+
forward_op_.inputs.push_back(input);
290+
return Status::OK();
291+
}
292+
293+
Status TapeOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
294+
int* num_retvals) {
295+
TF_RETURN_IF_ERROR(parent_op_->Execute(retvals, num_retvals));
296+
for (int i = 0; i < *num_retvals; i++) {
297+
forward_op_.outputs.push_back(retvals[i]);
298+
}
299+
// Populate forward_op_.skip_input_indices here.
300+
std::unique_ptr<GradientFunction> backward_fn;
301+
TF_RETURN_IF_ERROR(registry_.Lookup(forward_op_, &backward_fn));
302+
tape_->RecordOperation(forward_op_.inputs, forward_op_.outputs,
303+
backward_fn.release(), parent_op_->Name());
304+
return Status::OK();
305+
}
306+
```
307+
308+
This way the same C++ gen_ops code can be used to execute ops with/without a tape.
309+
310+
311+
Note: This interface is subject to change.
254312

255313

256314
**Some details on memory management**

0 commit comments

Comments
 (0)