diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 46ddb3f2fc5d8..7fab2eee7850d 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -30,6 +30,33 @@ Kernel::Kernel(Program &program, this->init(program, [&] { return func(this); }, primal_name, autodiff_mode); } +Kernel::Kernel(Program &program, + Block *block, + const std::string &primal_name, + AutodiffMode autodiff_mode) { + this->arch = program.compile_config().arch; + this->autodiff_mode = autodiff_mode; + this->ir = std::unique_ptr(block); + this->program = &program; + is_accessor = false; + ir_is_ast_ = false; // CHI IR + + TI_ASSERT(this->ir->is()); + this->ir->as()->set_parent_callable(this); + + if (autodiff_mode == AutodiffMode::kNone) { + name = primal_name; + } else if (autodiff_mode == AutodiffMode::kForward) { + name = primal_name + "_forward_grad"; + } else if (autodiff_mode == AutodiffMode::kReverse) { + name = primal_name + "_reverse_grad"; + } else if (autodiff_mode == AutodiffMode::kCheckAutodiffValid) { + name = primal_name + "_validate_grad"; + } else { + TI_ERROR("Unsupported autodiff mode"); + } +} + Kernel::Kernel(Program &program, std::unique_ptr &&ir, const std::string &primal_name, diff --git a/taichi/program/kernel.h b/taichi/program/kernel.h index 68b9f9b797cd1..92e58a6fc170c 100644 --- a/taichi/program/kernel.h +++ b/taichi/program/kernel.h @@ -35,6 +35,11 @@ class TI_DLL_EXPORT Kernel : public Callable { const std::string &name = "", AutodiffMode autodiff_mode = AutodiffMode::kNone); + Kernel(Program &program, + Block *block, + const std::string &name = "", + AutodiffMode autodiff_mode = AutodiffMode::kNone); + bool ir_is_ast() const { return ir_is_ast_; }