From 9b5e58192935bf697ae3c901657278e7f776b30a Mon Sep 17 00:00:00 2001
From: AD1024
Date: Mon, 7 Feb 2022 01:35:08 -0800
Subject: [PATCH] fix bert

---
 aten/src/ATen/native/Checkpoint.cpp        | 58 ++++++++++++++++++++++
 aten/src/ATen/native/native_functions.yaml | 25 ++++++++++
 2 files changed, 83 insertions(+)

diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp
index d6e01b9ffc1..0c8b74ec5e9 100644
--- a/aten/src/ATen/native/Checkpoint.cpp
+++ b/aten/src/ATen/native/Checkpoint.cpp
@@ -137,6 +137,30 @@ Tensor checkpoint_to(at::Tensor const& a, c10::TensorOptions const& b, bool c, b
   return CheckpointTensorImpl::make("to", rt, {a})[0];
 }
 
+Tensor checkpoint_to(at::Tensor const& a, at::Tensor const& b, bool c, bool d, c10::optional<c10::MemoryFormat> e) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {vec.at(0).to(vec.at(1), c, d, e)};
+    };
+  return CheckpointTensorImpl::make("to", rt, {a, b})[0];
+}
+
+Tensor checkpoint_to(at::Tensor const& a, c10::ScalarType b, bool c, bool d, c10::optional<c10::MemoryFormat> e) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {vec.at(0).to(b, c, d, e)};
+    };
+  return CheckpointTensorImpl::make("to", rt, {a})[0];
+}
+
+Tensor checkpoint_to(at::Tensor const& a, c10::Device b, c10::ScalarType c, bool d, bool e, c10::optional<c10::MemoryFormat> f) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {vec.at(0).to(b, c, d, e, f)};
+    };
+  return CheckpointTensorImpl::make("to", rt, {a})[0];
+}
+
 Tensor checkpoint_div(const Tensor& a, const Tensor& b) {
   rematerialize_function_t rt =
     [=](const Tensors& vec) -> Tensors {
@@ -703,6 +727,40 @@ Tensor checkpoint_sum_dim_IntList(const Tensor& a, c10::ArrayRef<long> b, bool c
   return CheckpointTensorImpl::make("sum_dim_IntList", rt, {a})[0];
 }
 
+Tensor& checkpoint_transpose_(at::Tensor& a, long b, long c) {
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor a_ = vec.at(0);
+      a_.transpose_(b, c);
+    };
+  CheckpointTensorImpl::mutate("transpose_", mt, {a}, {0});
+  return a;
+}
+
+Tensor checkpoint_transpose(at::Tensor const& a, long b, long c) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::transpose(vec.at(0), b, c)};
+    };
+  return CheckpointTensorImpl::make("transpose", rt, {a})[0];
+}
+
+Tensor checkpoint_gelu(at::Tensor const& a) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::gelu(vec.at(0))};
+    };
+  return CheckpointTensorImpl::make("gelu", rt, {a})[0];
+}
+
+Tensor checkpoint_matmul(at::Tensor const& a, at::Tensor const& b) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::matmul(vec.at(0), vec.at(1))};
+    };
+  return CheckpointTensorImpl::make("matmul", rt, {a, b})[0];
+}
+
 Tensor checkpoint_threshold(const Tensor& a, c10::Scalar b, c10::Scalar c) {
   rematerialize_function_t rt =
     [=](const Tensors& vec) -> Tensors {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 6d54e1cc9cb..5f657a6ada0 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2157,8 +2157,17 @@
 - func: matmul(Tensor self, Tensor other) -> Tensor
   use_c10_dispatcher: full
   variants: function, method
+  dispatch:
+    CPU, CUDA: matmul
+    Checkpoint: checkpoint_matmul
 
 - func: matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  #use_c10_dispatcher: full
+  #variants: function, method
+  #dispatch:
+  #  CPU, CUDA: matmul_out
+  #  Checkpoint: checkpoint_matmul_out
+  # there is some pt bug which disallow the above code.
 
 - func: matrix_rank.tol(Tensor self, float tol, bool symmetric=False) -> Tensor
   use_c10_dispatcher: full
@@ -2906,6 +2915,7 @@
   dispatch:
     CPU: gelu_cpu
     CUDA: gelu_cuda
+    Checkpoint: checkpoint_gelu
 
 - func: gelu_backward(Tensor grad, Tensor self) -> Tensor
   use_c10_dispatcher: full
@@ -3426,6 +3436,9 @@
   use_c10_dispatcher: full
   variants: function, method
   device_guard: False
+  dispatch:
+    CPU, CUDA: transpose
+    Checkpoint: checkpoint_transpose
 
 - func: transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a)
   variants: function, method
@@ -3441,6 +3454,9 @@
   use_c10_dispatcher: full
   variants: method
   device_guard: False
+  dispatch:
+    CPU, CUDA: transpose_
+    Checkpoint: checkpoint_transpose_
 
 - func: _mkldnn_transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)
   use_c10_dispatcher: full
@@ -4461,16 +4477,25 @@
   use_c10_dispatcher: full
   variants: method
   device_guard: False
+  dispatch:
+    CPU, CUDA: to
+    Checkpoint: checkpoint_to
 
 - func: to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor
   use_c10_dispatcher: full
   variants: method
   device_guard: False
+  dispatch:
+    CPU, CUDA: to
+    Checkpoint: checkpoint_to
 
 - func: to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor
   use_c10_dispatcher: full
   variants: method
   device_guard: False
+  dispatch:
+    CPU, CUDA: to
+    Checkpoint: checkpoint_to
 
 - func: meshgrid(Tensor[] tensors) -> Tensor[]
   use_c10_dispatcher: full