diff --git a/.gitmodules b/.gitmodules
index 4239b5cda2b..cab5d000826 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -133,3 +133,6 @@
 [submodule "third_party/kineto"]
     path = third_party/kineto
     url = https://github.com/pytorch/kineto
+[submodule "third_party/json"]
+    path = third_party/json
+    url = https://github.com/nlohmann/json.git
diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp
new file mode 100644
index 00000000000..dc04f8ba3a1
--- /dev/null
+++ b/aten/src/ATen/CheckpointTensorImpl.cpp
@@ -0,0 +1,762 @@
+#include
+#include
+#include
+#include
+#include
+
+namespace at {
+
+using Clock = std::chrono::high_resolution_clock;
+using Time = Clock::time_point;
+using Duration = Clock::duration;
+
+DispatchKeySet convert_key_set(const DispatchKeySet& t) {  // returns t plus DispatchKey::Checkpoint
+  TORCH_CHECK(!t.has(DispatchKey::Checkpoint));  // the input must not already carry the Checkpoint key
+  auto ret = t.add(DispatchKey::Checkpoint);
+  return ret;
+}
+
+void External::release_resources() {  // an external reference died
+  value->pool->release_external();  // let the alias pool count it (it may evict itself when none remain)
+  value.reset();
+}
+
+CheckpointTensorImpl* get_cpti(const Tensor& t) {  // nullptr when t is not backed by a CheckpointTensorImpl
+  return dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl());
+}
+
+CheckpointTensorImpl* must_get_cpti(const Tensor& t) {  // like get_cpti, but fails the check instead of returning nullptr
+  auto ret = get_cpti(t);
+  TORCH_CHECK(ret);
+  return ret;
+}
+
+size_t memory_sum = 0;    // total bytes over all tensors measured by memory()
+size_t memory_max = 0;    // largest single measurement
+size_t memory_count = 0;  // number of measurements
+
+void reset_memory_stat() {  // zeroes the three running statistics above
+  memory_sum = 0;
+  memory_max = 0;
+  memory_count = 0;
+}
+
+// todo: use defensive programming so that this only accepts dense tensors.
+// todo: right now this tracks memory on every device, but when checkpointing on gpu we do not care about cpu.
+size_t memory(const Tensor& t) {  // bytes of t's storage (0 without storage); also updates the memory_* statistics
+  if (! t.has_storage()) {
+    return 0;
+  }
+  auto& storage = t.storage();
+  size_t res = storage.nbytes();
+  memory_sum += res;
+  memory_max = std::max(memory_max, res);
+  memory_count += 1;
+  return res;
+}
+
+// todo: generalize this to other devices? e.g. we might want checkpointing on pure cpu.
+long current_memory() {  // current allocated_bytes[0] statistic of the CUDA caching allocator for device 0
+  auto device_stat = c10::cuda::CUDACachingAllocator::getDeviceStats(0);
+  return device_stat.allocated_bytes[0].current;
+}
+
+bool use_log_ = false;  // gates every DTRLog* call in this file
+bool use_profile_ = false;
+long base_compute_time_ = 0;   // accumulated duration .count() of first-time operator calls (make_raw)
+long remat_compute_time_ = 0;  // accumulated duration .count() of rematerializations
+long search_time_ = 0;         // accumulated duration .count() spent inside CheckpointPool::evict
+long cost_time_ = 0;           // accumulated duration .count() spent inside AliasPool::cost
+
+CheckpointPool pool;
+void CheckpointPool::add(const intrusive_ptr<AliasPool>& p) {  // registers p as an eviction candidate
+  if (p->memory > 0 && (memory_count == 0 || !ignore_small_tensors || p->memory >= 0.01 * double(memory_sum/memory_count))) {
+    aps.push_back(weak_intrusive_ptr<AliasPool>(p));
+  }
+}
+
+void CheckpointPool::auto_evict() {  // evicts until current_memory() fits the budget (no-op without a budget)
+  if (has_memory_budget) {
+    while (current_memory() > memory_budget) {
+      evict();
+    }
+  }
+}
+
+void CheckpointPool::evict() {  // one scan over aps: prunes dead/evicted entries and evicts the cheapest pool seen
+  time_t pre = std::chrono::system_clock::now();
+  TORCH_CHECK(aps.size() > 0);
+  // shrunk: either something has been evicted or the pools have gotten smaller
+  bool shrunk = false;
+  int evict_idx = -1;
+  double evict_cost = INFINITY;
+  time_t current_time = std::chrono::system_clock::now();
+  auto remove_from_aps = [&](size_t i) {  // unordered removal: overwrite slot i with the last entry
+    aps[i] = aps[aps.size() - 1];
+    aps.pop_back();
+  };
+  std::uniform_int_distribution<> distrib(1, 1 * std::max(1, static_cast<int>(std::sqrt(aps.size()))));
+  // sample a random subset of the candidates (random stride per step) to find the cheapest tensor to evict.
+  for (size_t i = 0; i < aps.size();) {
+    auto cannot_evict = [&]() {  // slot i is refilled by the swap, so i is deliberately not advanced
+      shrunk = true;
+      remove_from_aps(i);
+    };
+    auto ap_strong = aps[i].lock();
+    if (!ap_strong.defined()) {
+      cannot_evict();
+    }
+    else if (ap_strong->ecn) {
+      cannot_evict();
+    }
+    else {
+      if (ap_strong->evictable()) {
+        double cost = ap_strong->cost(current_time);
+        if (cost < evict_cost) {
+          evict_cost = cost;
+          evict_idx = i;
+        }
+      }
+
+      if (sample_tensors) {
+        i += distrib(gen);
+      } else {
+        i += 1;
+      }
+    }
+  }
+  if (evict_idx == -1) {
+    TORCH_CHECK(shrunk);  // no candidate and no pruning means auto_evict() could never make progress
+  } else {
+    auto evict_from_idx = [&](size_t idx) {
+      auto ap_strong = aps[idx].lock();
+      TORCH_CHECK(ap_strong.defined());
+      ap_strong->evict();
+      remove_from_aps(idx);
+    };
+    evict_from_idx(evict_idx);
+  }
+  time_t post = std::chrono::system_clock::now();
+  search_time_ += (post - pre).count();
+}
+
+// pins every live external tensor and forgets all eviction candidates.
+// should we traverse all externals in chronological order or reverse chronological order?
+// my intuition is that it should be reversed, because the reversed order prioritizes the newer externals,
+// which have more unevicted tensors near them (because of staleness).
+// if we go in chronological order, those tensors might be evicted.
+void CheckpointPool::clear_checkpointpool() {
+  while (!exts.empty()) {
+    if (auto e = exts.back().lock()) {
+      e->value->pin();
+    }
+    exts.pop_back();
+  }
+  aps.clear();
+}
+
+Tensor uncheckpoint(const strong& input) {  // the raw tensor held by the cell (see CheckpointTensorCell::get)
+  return input->get();
+}
+
+Tensors uncheckpoint(const strongs& inputs) {  // element-wise uncheckpoint, order preserved
+  Tensors ret;
+  ret.reserve(inputs.size());
+  for (const strong& input : inputs) {
+    ret.push_back(uncheckpoint(input));
+  }
+  return ret;
+}
+
+Tensors try_checkpoint(const Tensors& inputs) {  // element-wise at::native::try_checkpoint, order preserved
+  Tensors ret;
+  ret.reserve(inputs.size());
+  for (const Tensor& input : inputs) {
+    ret.push_back(at::native::try_checkpoint(input));
+  }
+  return ret;
+}
+
+void Rematerializer::remat() {  // re-runs func on the inputs and refills every output cell that is still alive
+  // TODO: refactor using RAII for exception safety (the locks below are not released if func throws).
+  for (const strong& s : inputs) {
+    s->pool->lock();
+  }
+  Tensors ts = uncheckpoint(inputs);
+  time_t pre = std::chrono::system_clock::now();
+  auto ret = func(ts);
+  time_t post = std::chrono::system_clock::now();
+  pool.auto_evict();
+  remat_compute_time_ += (post - pre).count();
+  TORCH_CHECK(ret.size() == outputs.size());
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    if (auto output_cell = outputs[i].lock()) {
+      output_cell->fill(ret[i]);
+    }
+  }
+  ecn.reset();
+  for (const strong& s : inputs) {
+    s->pool->unlock();
+  }
+}
+
+ecn_ptr Rematerializer::get_ecn() {  // lazily creates the node, seeded with this rematerializer's compute cost
+  if (!ecn) {
+    ecn = ecn_ptr::make(CheckpointInfo(compute_cost));
+  }
+  return ecn;
+}
+
+CheckpointInfo merge_cpi(CheckpointInfo l, CheckpointInfo r) {  // merged info = sum of the two compute costs
+  return CheckpointInfo(l.compute_cost + r.compute_cost);
+}
+
+std::set<ecn_ptr> AliasPool::neighbor_ecn() {  // ecns of all live neighbors; also prunes expired neighbors
+  std::set<ecn_ptr> ptr_set;
+  size_t size = neighbors.size();
+  for (size_t i = 0; i < size;) {
+    if (auto cptc = neighbors[i].lock()) {
+      if (cptc->pool->ecn) {
+        ptr_set.insert(cptc->pool->ecn);
+      }
+      ++i;
+    } else {
+      neighbors[i] = neighbors[size - 1];  // expired: overwrite with the last kept entry and re-examine slot i
+      --size;
+    }
+  }
+  if (size < neighbors.size()) {
+    neighbors.erase(neighbors.begin() + size, neighbors.end());  // drop the whole stale tail, not just one entry
+  }
+  return ptr_set;
+}
+
+
+double AliasPool::cost(time_t current_time) {  // eviction score; must only be called on an evictable pool
+  TORCH_CHECK(evictable());
+  time_t pre = std::chrono::system_clock::now();
+  auto cpi = CheckpointInfo(head_remat->compute_cost);
+  auto ecns = neighbor_ecn();
+  for (const auto& necn : ecns) {
+    cpi = merge_cpi(cpi, get_t(necn));
+  }
+  auto ret = cpi.cost(memory, (current_time - last_used_time).count());
+  time_t post = std::chrono::system_clock::now();
+  cost_time_ += (post - pre).count();
+  return ret;
+}
+
+void AliasPool::evict() {  // joins the evicted group of its neighbors and drops every tensor in this pool
+  TORCH_CHECK(!ecn);
+  ecn = head_remat->get_ecn();
+  auto ecns = neighbor_ecn();
+  for (const auto& necn : ecns) {
+    merge<CheckpointInfo>(merge_cpi, ecn, necn);
+  }
+  TORCH_CHECK(lock_count == 0);
+  for (const weak& w : tensors) {
+    if (auto cell = w.lock()) {
+      cell->evict();
+    }
+  }
+}
+
+void AliasPool::set_not_evicted(const intrusive_ptr<AliasPool>& self) {  // no-op unless the pool is currently evicted
+  if (ecn) {
+    TORCH_CHECK(head_remat);
+    auto cpi = get_t(ecn);
+    update_t(ecn, CheckpointInfo(cpi.compute_cost - head_remat->compute_cost));  // take our cost out of the group
+    ecn.reset();
+    pool.add(self);  // evict() removed us from the candidate list, so register again
+  }
+}
+
+void CheckpointTensorCell::fill(const Tensor& t) {  // stores a detached t if the cell is empty; caches metadata on first fill
+  if (!(this->t)) {
+    TORCH_CHECK(!at::native::is_checkpoint(t));
+    TORCH_CHECK(!t.key_set().has(DispatchKey::Checkpoint));
+    this->t = std::make_unique<Tensor>(t.detach());
+    pool->set_not_evicted(pool);
+    if (!defined) {
+      defined = true;
+      is_undefined_tensor = !t.defined();
+      key_set_ = t.key_set();
+      if (t.requires_grad()) {
+        key_set_ = key_set_.add(DispatchKey::Autograd);
+      }
+      dtype_ = t.dtype();
+      optional_device_ = t.optional_device();
+    }
+  }
+}
+
+Tensor CheckpointTensorImpl::get() const {  // Ref -> External -> cell -> raw tensor
+  return ref->value->value->get();
+}
+
+CheckpointTensorImpl::CheckpointTensorImpl(const Tensor& t) : CheckpointTensorImpl(intrusive_ptr<External>::make(t)) { }
+
+CheckpointTensorImpl::CheckpointTensorImpl(const Ref<intrusive_ptr<External>>& ref) :
+  TensorImpl(convert_key_set(ref->value->value->key_set()),
+             ref->value->value->dtype(),
+             ref->value->value->optional_device()),
+  ref(ref) {
+  if (key_set().has(DispatchKey::Autograd)) {
+    set_requires_grad(true);
+  }
+}
+
+intrusive_ptr<TensorImpl> CheckpointTensorImpl::shallow_copy_and_detach(const VariableVersion& version_counter,
+                                                                        bool allow_tensor_metadata_change) const {
+  // A new CheckpointTensorImpl is created here instead of returning this one.
+  // The checkpointed value is immutable,
+  // but the autograd metadata attached to the impl is mutable,
+  // so the impl object itself must not be shared. The Ref is shared with the copy.
+  auto ret = intrusive_ptr<CheckpointTensorImpl>::make(ref);
+  if (use_log_) {
+    DTRLogCopy(ret->counter_name(), counter_name());
+  }
+  return ret;
+}
+
+intrusive_ptr<TensorImpl> CheckpointTensorImpl::shallow_copy_and_detach(VariableVersion&& version_counter,
+                                                                        bool allow_tensor_metadata_change) const {
+  return shallow_copy_and_detach(version_counter, allow_tensor_metadata_change);  // named rvalue binds to the const& overload
+}
+
+void CheckpointTensorImpl::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {  // impl must be a CheckpointTensorImpl
+  TORCH_CHECK(key_set() == impl->key_set());
+  auto* cpti = dynamic_cast<CheckpointTensorImpl*>(impl.get());
+  TORCH_CHECK(cpti != nullptr);
+  ref->value = cpti->ref->value;
+  if (use_log_) {
+    DTRLogCopyFrom(counter_name(), cpti->counter_name());
+  }
+}
+
+int CheckpointTensorImpl::counter = 0;
+
+bool is_alias(const Tensor& l, const Tensor& r) {  // false when either side is undefined
+  return l.defined() && r.defined() && l.is_alias_of(r);
+}
+
+// return the index of an element of ts that t aliases.
+// we dont care which one because they all lead to the same alias pool.
+// return -1 for no alias.
+// (the -1 sentinel is why this returns int rather than size_t.)
+int get_alias(const Tensors& ts, const Tensor& t) {
+  if (t.defined()) {
+    for (size_t i = 0; i < ts.size(); ++i) {
+      if (ts[i].defined() && t.is_alias_of(ts[i])) {
+        return i;
+      }
+    }
+  }
+  return -1;
+}
+
+void add_neighbor(const strong& l, const strong& r) {  // symmetric: each pool records the other's cell weakly
+  l->pool->neighbors.push_back(weak(r));
+  r->pool->neighbors.push_back(weak(l));
+}
+
+struct MakeRawResult {
+  std::vector<intrusive_ptr<External>> outputs;
+  std::vector<int> aliases;  // per output: index of the aliased input, or -1 (see get_alias)
+  duration_t time;
+  intrusive_ptr<Rematerializer> rematerializer;
+};
+
+MakeRawResult make_raw(const rematerialize_function_t& remat_f,  // runs remat_f once and wraps its outputs
+                       const strongs& inputs) {
+  for (const strong& s : inputs) {
+    s->pool->lock();
+  }
+  Tensors raw_inputs = uncheckpoint(inputs);
+  time_t pre = std::chrono::system_clock::now();
+  auto raw_outputs = remat_f(raw_inputs);
+  time_t post = std::chrono::system_clock::now();
+  pool.auto_evict();
+  base_compute_time_ += (post - pre).count();
+  std::vector<intrusive_ptr<External>> outputs;
+  std::vector<int> aliases;
+  weaks weak_outputs;
+  auto remat = intrusive_ptr<Rematerializer>::make(Unsafe(), remat_f, inputs, post - pre);
+
+  for (const Tensor& t : raw_outputs) {
+    intrusive_ptr<AliasPool> alias_pool;
+    int alias = get_alias(raw_inputs, t);
+    if (alias == -1) {
+      auto m = memory(t);
+      alias_pool = intrusive_ptr<AliasPool>::make(Unsafe(), remat, m);
+      pool.add(alias_pool);
+    }
+    else {
+      alias_pool = inputs[alias]->pool;  // an aliasing output joins the pool of the input it aliases
+      if (alias_pool->head_remat) {
+        alias_pool->head_remat->compute_cost += (post - pre);
+      }
+    }
+    auto e = intrusive_ptr<External>::make(t, alias_pool, remat);
+    pool.exts.push_back(weak_intrusive_ptr<External>(e));
+    alias_pool->tensors.push_back(weak(e->value));
+    outputs.push_back(e);
+    aliases.push_back(alias);
+    weak_outputs.push_back(weak(outputs.back()->value));
+  }
+  remat->outputs = weak_outputs;
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    for (size_t j = 0; j < outputs.size(); ++j) {
+      if (!is_alias(raw_inputs[i], raw_outputs[j])) {
+        add_neighbor(inputs[i], outputs[j]->value);
+      }
+    }
+  }
+  for (const strong& s : inputs) {
+    s->pool->unlock();
+  }
+  return {outputs, aliases, post - pre, remat};
+}
+
+std::string from_time(duration_t t) {  // the duration as a decimal nanosecond count
+  return std::to_string(std::chrono::nanoseconds(t).count());
+}
+
+Tensors CheckpointTensorImpl::make(const std::string& name,  // name is only used for logging
+                                   const rematerialize_function_t& remat,
+                                   const Tensors& inputs) {
+  Tensors checkpointed_inputs = try_checkpoint(inputs);
+  auto input_size = checkpointed_inputs.size();
+
+  strongs input_values;
+  input_values.reserve(input_size);
+
+  std::vector<std::string> args;
+  args.reserve(input_size);
+
+  for (const Tensor& t: checkpointed_inputs) {
+    auto* cpti = must_get_cpti(t);
+    input_values.push_back(cpti->ref->value->value);
+    if (use_log_) {
+      args.push_back(cpti->counter_name());
+    }
+  }
+
+  auto ret = make_raw(remat, input_values);
+
+  Tensors tensors;
+  tensors.reserve(ret.outputs.size());
+
+  for (const auto& t: ret.outputs) {
+    auto cp = Tensor(intrusive_ptr<CheckpointTensorImpl>::make(t));
+    tensors.push_back(cp);
+  }
+
+  if (use_log_) {
+    std::vector<std::string> res;
+    res.reserve(ret.outputs.size());
+
+    for (const auto& tensor : tensors) {
+      res.push_back(get_cpti(tensor)->counter_name());
+    }
+
+    DTRLogCall(res, name, args, from_time(ret.time));
+    for (size_t i = 0; i < tensors.size(); ++i) {
+      Tensor t = tensors[i];
+      auto cpti = get_cpti(t);
+      DTRLogMemory(cpti->counter_name(), cpti->ref->value->value->memory());
+      DTRLogAlias(cpti->counter_name(), ret.aliases[i]);
+    }
+  }
+
+  return tensors;
+}
+
+void CheckpointTensorImpl::release_resources() {
+  if (use_log_) {
+    DTRLogRelease(counter_name());
+  }
+  ref.reset();
+}
+
+struct CheckpointFunctionsImpl: CheckpointFunctions {  // forwards each virtual call to the file-level pool and counters
+  void new_log(std::string str) override {
+    DTRLogger::logger().out = std::ofstream(DTRLogger::logger().get_filename(str));
+  }
+  void annotate_log(std::string str) override {
+    if (use_log_) {
+      json j;
+      j[INSTRUCTION] = "ANNOTATE";
+      j[ANNOTATION] = str;
+      DTRLogger::logger().log(j.dump());
+    }
+  }
+  void toggle_log(bool b) override {
+    use_log_ = b;
+  }
+  void clear_checkpointpool() override {
+    pool.clear_checkpointpool();
+  }
+  void unset_memory_budget() override {
+    pool.has_memory_budget = false;
+  }
+  void set_memory_budget(long budget) override {
+    pool.memory_budget = budget;
+    pool.has_memory_budget = true;
+  }
+  void toggle_sampling(bool sample) override {
+    pool.sample_tensors = sample;
+  }
+  void toggle_ignore_small_tensors(bool ignore) override {
+    pool.ignore_small_tensors = ignore;
+  }
+  void toggle_profile(bool profile) override {
+    use_profile_ = profile;
+  }
+  void reset_profile() override {
+    base_compute_time_ = 0;
+    remat_compute_time_ = 0;
+    search_time_ = 0;
+    cost_time_ = 0;
+  }
+  long base_compute_time() override {
+    return base_compute_time_;
+  }
+  long remat_compute_time() override {
+    return remat_compute_time_;
+  }
+  long compute_time() override {
+    return base_compute_time() + remat_compute_time();
+  }
+  long cost_time() override {
+    return cost_time_;
+  }
+  long search_time() override {
+    return search_time_;
+  }
+  long loop_time() override {  // search time excluding the part spent computing costs
+    return search_time() - cost_time();
+  }
+};
+
+CheckpointFunctions* GetCheckpointFunctions() {  // process-wide instance, created on first call
+  static CheckpointFunctionsImpl cpfi;
+  return &cpfi;
+}
+
+namespace native {
+
+Tensor checkpoint(const Tensor& t) {  // wraps a plain tensor; t must not already be a checkpoint tensor
+  TORCH_CHECK(!is_checkpoint(t));
+  auto cpti = intrusive_ptr<CheckpointTensorImpl>::make(t);
+  if (use_log_) {
+    DTRLogConstant(cpti->counter_name());
+    DTRLogMemory(cpti->counter_name(), cpti->ref->value->value->memory());
+  }
+  return Tensor(cpti);
+}
+
+Tensor uncheckpoint(const Tensor& t) {  // t must be a checkpoint tensor
+  auto cpti = must_get_cpti(t);
+  return cpti->get();
+}
+
+Tensor try_uncheckpoint(const Tensor& t) {  // passes non-checkpoint tensors through unchanged
+  return is_checkpoint(t) ? uncheckpoint(t) : t;
+}
+
+Tensor decheckpoint(const Tensor& t) {  // alias of try_uncheckpoint
+  return try_uncheckpoint(t);
+}
+
+void pin(Tensor& t) {  // t must be a checkpoint tensor; see CheckpointTensorCell::pin
+  must_get_cpti(t)->ref->value->value->pin();
+}
+
+bool is_checkpoint(const Tensor& t) {
+  return get_cpti(t) != nullptr;
+}
+
+Tensor try_checkpoint(const Tensor& t) {  // passes checkpoint tensors through unchanged
+  return is_checkpoint(t) ? t : checkpoint(t);
+}
+
+}
+
+// map f over every tensor held by the ivalue; non-tensor ivalues are returned unchanged.
+// only Tensor and TensorList are traversed - there seems to be no generic way to map over every list type.
+template<typename F>
+IValue map_ivalue(const F& f, const IValue& iv) {
+  if (iv.isTensor()) {
+    return f(iv.toTensor());
+  } else if (iv.isScalar() || iv.isBool() || iv.isDevice() || iv.isNone() || iv.isIntList() || iv.isBoolList() || iv.isDoubleList()) {
+    return iv;
+  } else if (iv.isTensorList()) {
+    std::vector<Tensor> ts;
+    for (const auto& t: iv.toTensorList()) {
+      ts.push_back(f(t));
+    }
+    return ts;
+  }
+  else {
+    TORCH_CHECK(false, "unknown ivalue type: ", *(iv.type()));
+    throw;
+  }
+}
+
+Ref<intrusive_ptr<External>> cell_from_tensor(const Tensor& t) {  // t must be a checkpoint tensor
+  return must_get_cpti(t)->ref;
+}
+
+// note: the operator must be deterministic (same input gives same output/same mutation on input no matter how many times it is called).
+// if it is not deterministic, at least it must not change the shape of the output (if the input shape did not change).
+// otherwise the code will break.
+// Right now an uncheckpointed tensor is converted into a checkpoint tensor before going into CheckpointTensorImpl::make.
+// It seems like skipping the conversion would save time, but it breaks our logging code.
+// the code is a bit cleaner this way, and this extra information may be helpful.
+// So there are two interfaces: CheckpointTensor's pure Tensors -> Tensors interface, and a stack mutating interface.
+// We have to convert twice.
+// In particular, we implement a stack mutating interface for checkpointed tensors,
+// by implementing a Tensors -> Tensors interface for ordinary tensors (the conversion is handled by tensors::make).
+// we implement that by converting it back to stack mutation.
+// Additionally, since the stack contains IValue instead of Tensors,
+// we have to inject/extract the Tensors to/from the saved IValue
+// every time we convert Tensors to/from the stack.
+// Reminder: you can convert IValue to/from Tensor, but you should not do that in here,
+// as an IValue may hold zero or more Tensors.
+// the only way to construct/destruct an IValue should be map_ivalue.
+void CheckpointFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  size_t before_size = stack->size();
+  auto s = op.schema();
+  // std::cout << s << std::endl;
+  size_t num_arg = s.arguments().size();
+  // todo: use s.hasAnyAliasInfo() to figure out alias info instead of doing a runtime loop.
+  std::vector<IValue> checkpoint_reversed_ivalue_in; // popping them from the jit stack and pushing them back will reverse stuff.
+  std::vector<bool> checkpoint_reversed_ivalue_in_mutable;
+  // but should we really reverse stuff? there is a peek() function which doesnt.
+  // ezyang seems to want to replace stack impl from std::vector to some sort of list,
+  // so slower peek() though.
+  for (size_t i = 0; i < num_arg; ++i) {
+    checkpoint_reversed_ivalue_in.push_back(torch::jit::pop(stack));
+    const auto& aliasInfo = s.arguments()[s.arguments().size() - 1 - i].alias_info();
+    checkpoint_reversed_ivalue_in_mutable.push_back(aliasInfo && aliasInfo.value().isWrite());
+  }
+  Tensors original_tensors_in;
+  strongs checkpoint_tensors_in;
+  std::vector<bool> checkpoint_tensors_in_mutable;
+  auto it = checkpoint_reversed_ivalue_in.rbegin();
+  auto mit = checkpoint_reversed_ivalue_in_mutable.rbegin();
+  while (it != checkpoint_reversed_ivalue_in.rend()) {  // reverse iteration restores schema argument order
+    TORCH_CHECK(mit != checkpoint_reversed_ivalue_in_mutable.rend());
+    map_ivalue([&](const Tensor& t) {
+      original_tensors_in.push_back(t);
+      auto tcp = native::try_checkpoint(t);
+      auto* cpti = must_get_cpti(tcp);
+      checkpoint_tensors_in.push_back(cpti->ref->value->value);
+      checkpoint_tensors_in_mutable.push_back(*mit);
+      return t; // dont care value
+    }, *it);
+    ++it;
+    ++mit;
+  }
+  // todo: modify on heap instead of pushing and popping?
+  struct JitRemat {  // copyable handle around the shared Boxed state, so it can be stored as a rematerialize_function_t
+    struct Boxed {
+      c10::OperatorHandle op;
+      std::vector<IValue> checkpoint_reversed_ivalue_in;
+      std::vector<IValue> checkpoint_reversed_ivalue_out;
+      std::vector<bool> checkpoint_tensors_in_mutable;
+      bool initial_call = true;
+      Boxed(const c10::OperatorHandle& op,
+            const std::vector<IValue>& checkpoint_reversed_ivalue_in,
+            const std::vector<bool>& checkpoint_tensors_in_mutable) :
+        op(op),
+        checkpoint_reversed_ivalue_in(checkpoint_reversed_ivalue_in),
+        checkpoint_tensors_in_mutable(checkpoint_tensors_in_mutable) { }
+      Tensors operator()(const Tensors& remat_in) {  // returns the op's tensor outputs followed by the clones of mutated inputs
+        torch::jit::Stack stack;
+        size_t count = 0;
+        Tensors copied_values;
+        for (auto it = checkpoint_reversed_ivalue_in.rbegin(); it != checkpoint_reversed_ivalue_in.rend(); ++it) {
+          torch::jit::push(&stack,
+            map_ivalue([&](const Tensor&) {
+              auto rem_at = remat_in.at(count);
+              auto ret = [&]() {
+                if (checkpoint_tensors_in_mutable.at(count)) {
+                  auto cloned = rem_at.clone();  // the op writes to this argument: let it mutate a copy instead
+                  copied_values.push_back(cloned);
+                  return cloned;
+                } else {
+                  return rem_at;
+                }
+              }();
+              ++count;
+              return ret;
+            }, *it));
+        }
+        TORCH_CHECK(count == remat_in.size());
+        op.callBoxed(&stack);
+        Tensors remat_out;
+        auto s = op.schema();
+        size_t num_ret = s.returns().size();
+        for (size_t i = 0; i < num_ret; ++i) {
+          checkpoint_reversed_ivalue_out.push_back(torch::jit::pop(&stack));
+        }
+        for (auto it = checkpoint_reversed_ivalue_out.rbegin(); it != checkpoint_reversed_ivalue_out.rend(); ++it) {
+          map_ivalue([&](const Tensor& t) {
+            remat_out.push_back(t);
+            return t; // dont care value
+          }, *it);
+        }
+        for (const Tensor& t: copied_values) {
+          remat_out.push_back(t);
+        }
+        // The first call keeps its output ivalues so that CheckpointFallback can rebuild the stack from them;
+        // every later call must not retain them.
+        if (!initial_call) {
+          checkpoint_reversed_ivalue_out.clear();
+        }
+        initial_call = false;
+        TORCH_CHECK(stack.empty());
+        return remat_out;
+      }
+    };
+    std::shared_ptr<Boxed> boxed;
+    JitRemat(const c10::OperatorHandle& op,
+             const std::vector<IValue>& checkpoint_reversed_ivalue_in,
+             const std::vector<bool>& checkpoint_tensors_in_mutable) :
+      boxed(std::make_shared<Boxed>(op, checkpoint_reversed_ivalue_in, checkpoint_tensors_in_mutable)) { }
+    Tensors operator()(const Tensors& remat_in) { return (*boxed)(remat_in); }
+  } remat(op, checkpoint_reversed_ivalue_in, checkpoint_tensors_in_mutable);
+  auto make_raw_result = make_raw(remat, checkpoint_tensors_in);
+  size_t count = 0;
+  for (auto it = remat.boxed->checkpoint_reversed_ivalue_out.rbegin(); it != remat.boxed->checkpoint_reversed_ivalue_out.rend(); ++it) {
+    torch::jit::push(stack,
+      map_ivalue([&](const Tensor&) {
+        auto out = make_raw_result.outputs.at(count);
+        ++count;
+        return Tensor(intrusive_ptr<CheckpointTensorImpl>::make(out));
+      }, *it));
+  }
+  for (size_t i = 0; i < checkpoint_tensors_in.size(); ++i) {
+    if (checkpoint_tensors_in_mutable.at(i)) {
+      cell_from_tensor(original_tensors_in.at(i))->value = make_raw_result.outputs.at(count);  // redirect the mutated input's Ref
+      ++count;
+    }
+  }
+  // clear the stored ivalue output, so the tensor returned can actually be freed from memory (if evicted).
+  remat.boxed->checkpoint_reversed_ivalue_out.clear();
+  TORCH_CHECK(before_size - s.arguments().size() + s.returns().size() == stack->size());
+  TORCH_CHECK(count == make_raw_result.outputs.size());
+}
+
+// todo: i can also use a torch library impl instead of calling fallback explicitly. should i do that?
+struct Register {  // its constructor installs CheckpointFallback as the boxed fallback for DispatchKey::Checkpoint
+  Register() {
+    static auto registration = c10::Dispatcher::singleton().registerFallback(DispatchKey::Checkpoint,
+                                                                             KernelFunction::makeFromBoxedFunction<&CheckpointFallback>(),
+                                                                             "checkpoint");
+  }
+} register_checkpoint;  // namespace-scope object, so the registration runs when this object is constructed
+
+}
diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h
new file mode 100644
index 00000000000..ffeabe1205c
--- /dev/null
+++ b/aten/src/ATen/CheckpointTensorImpl.h
@@ -0,0 +1,478 @@
+#pragma once
+
+#include
+
+#include
+#include
+#include
+#include
+
+// [System Description]:
+// Every Tensor is managed by a CheckpointTensor,
+// that describes how it is computed (the function and the inputs),
+// and might optionally hold the tensor value.
+// The tensor value might be dropped, and when requested later, recomputed and cached again.
+
+// [DTR and autodiff]:
+// unlike other gradient checkpointing schemes that are coupled to automatic differentiation,
+// DTR is decoupled from automatic differentiation, and thus is more like a tensor level cache
+// than a memory optimization on automatic differentiation.
+// this means we can use DTR without using autodiff.
+// Implementation wise, this means DTR will work below autodiff:
+// let autodiff do its thing, and DTR will override the forward and backward propagation tensors.
+
+// [Corner Cases]:
+// A CheckpointedTensor might require_grad.
+//   In this case the underlying data must not require_grad,
+//   as we want backpropagation on the outer, uncheckpointed level.
+//   see Note [DTR and autodiff].
+// A CheckpointedTensor might be constant.
+//   In this case it is unevictable.
+// An operator might return multiple outputs.
+//   In this case the computation info (rematerializer) is shared between all of them,
+//   and when the function gets computed again all values get cached.
+// An operator might mutate an input value.
+//   To combat this, we COW the operator, and wrap CheckpointTensor with a Ref.
+//   By doing this the inner CheckpointTensor is kept purely functional.
+// An operator might try to mutate an uncheckpointed tensor.
+//   We do not support this and will error.
+// An operator might create aliases.
+//   We track aliases in AliasPool.
+//   Each AliasPool holds a set of tensors that alias each other.
+// An operator might try to create an alias to an unevictable tensor.
+//   In such a case the output tensor is unevictable.
+// An operator might try to mutate a Tensor with an alias.
+//   We do not support this case and will error if a Tensor has any alive alias.
+//   However it could be done without a major redesign of the system -
+//   each AliasPool will hold weak pointers to the External Reference.
+//   When alias mutation occurs,
+//   we make a rematerialize_function that takes in the base tensor (the one the other tensors alias from)
+//   and outputs all the new values of the aliases, then update the Ref.
+//   Of course, the cleaner way is to not support this.
+//   Relying on this feature is discouraged.
+
+// Memory Safety:
+// The objects here will have lots of backedges.
+// In order to collect memory when computation is completed,
+// we require that all strong pointers are of the form value -> input.
+// This ensures that everything will be released if there is no external ref whatsoever.
+
+// Optimization:
+// We treat a tensor that has no external reference differently -
+// when all external references to a Checkpoint Tensor are lost,
+// we will try to immediately evict it.
+
+// Note: to code fast I do not use RAII and just assume the code will not try to recover from an exception.
+// It should be easy to fix though.
+
+namespace at {
+
+using Clock = std::chrono::high_resolution_clock;
+using Time = Clock::time_point;
+using Duration = Clock::duration;
+
+// TODO: using a pool allocator might make more sense - no need to allocate and delete each pointer individually.
+// TODO: egg simply stores all nodes in a vector and uses the vector index as reference. maybe do that?
+template<typename T>
+struct EquivalentClassNode : intrusive_ptr_target {  // union-find node; t_unsafe is only read/written through the root
+  explicit EquivalentClassNode(const T& t) : t_unsafe(t) { }
+  mutable intrusive_ptr<EquivalentClassNode<T>> parent;  // null on a root; rewritten by find_root's path compression
+  bool is_root() const {
+    return !parent;
+  }
+  void release_resources() override {
+    parent.reset();
+  }
+  T t_unsafe;
+};
+
+template<typename T>
+T& get_t(const intrusive_ptr<EquivalentClassNode<T>>& n) {  // payload of the class n belongs to
+  return find_root(n)->t_unsafe;
+}
+
+template<typename T>
+void update_t(const intrusive_ptr<EquivalentClassNode<T>>& n, const T& t) {  // overwrites the payload of n's class
+  find_root(n)->t_unsafe = t;
+}
+
+template<typename T>
+intrusive_ptr<EquivalentClassNode<T>> find_root(const intrusive_ptr<EquivalentClassNode<T>>& n) {
+  if (n->is_root()) {
+    return n;
+  } else {
+    n->parent = find_root(n->parent);  // path compression: point n directly at the root
+    return n->parent;
+  }
+}
+
+template<typename T>
+intrusive_ptr<EquivalentClassNode<T>> merge(const std::function<T(const T&, const T&)>& merge_t,
+                                            const intrusive_ptr<EquivalentClassNode<T>>& lhs,
+                                            const intrusive_ptr<EquivalentClassNode<T>>& rhs) {
+  auto l = find_root(lhs);
+  auto r = find_root(rhs);
+  if (l == r) {
+    return l;
+  }
+  l->parent = r;  // rhs's root becomes the root of the union and receives the merged payload
+  r->t_unsafe = merge_t(l->t_unsafe, r->t_unsafe);
+  return r;
+}
+
+size_t memory(const Tensor& t);
+
+template<typename T>
+struct RefCell final : intrusive_ptr_target {  // reference-counted mutable slot holding a T
+  mutable T value;
+  void release_resources() final {
+    static_release_resources(value);
+  }
+  RefCell(const T& t) : value(t) { }
+};
+
+template<typename T>
+using Ref = intrusive_ptr<RefCell<T>>;
+
+template<typename T>
+void static_release_resources(intrusive_ptr<T>& ptr) {  // the only overload here: releasing an intrusive_ptr resets it
+  ptr.reset();
+}
+
+struct CheckpointTensorCell;
+using strong = intrusive_ptr<CheckpointTensorCell>;
+using strongs = std::vector<strong>;
+using weak = weak_intrusive_ptr<CheckpointTensorCell>;
+using weaks = std::vector<weak>;
+using Tensors = std::vector<Tensor>;
+// do we really need this? can we just use operatorhandle?
+// after some thought i think we shouldnt be that coupled.
+using rematerialize_function_t = std::function<Tensors(const Tensors&)>;
+// this doesnt look right. a mutate function can also return input.
+// i guess mutate_function simply is not as fundamental as rematerialize_function.
+// the former desugars to the latter.
+// using mutate_function_t = std::function; + +using time_t = std::chrono::time_point; +using duration_t = std::chrono::system_clock::duration; +struct CheckpointInfo { + duration_t compute_cost; + double cost(size_t memory, size_t staleness) const { + TORCH_CHECK(memory > 0); + TORCH_CHECK(staleness > 0); + return compute_cost.count() / static_cast(memory * staleness); + } + CheckpointInfo(duration_t compute_cost) : + compute_cost(compute_cost) { + } +}; + +// ecn represent a evicted tensor group. +// it is a set of tensor that are evicted, and if two evicted tensor are input -> output to each other, +// they must be in an ecn. +// note: we try to support removal from ecn by subtracting compute_cost and memory. +// this will create suprious connection but that should be fine empircally. +// below is an example of a suprious connection: +// a -> b, a -> c +// a, b, c got evicted so belong to a single ecn. +// a got rematerialized. +// b, c still belong to a single ecn although there is no connection. +using ecn_ptr = intrusive_ptr>; + +struct Unsafe { }; + +// The rematerializer could be called to reinvoke an operator. +// Tensor point to remat which point to Tensor. +// To build the cycle remat support a default constructor, +// And allow you to fill in the member later. +struct Rematerializer : intrusive_ptr_target { + rematerialize_function_t func; + strongs inputs; + weaks outputs; + duration_t compute_cost; + // when some output in here get evicted, they should belong to this ecn. + // a rematerializer have to track this, + // because when multiple output of a rematerializer get evicted, + // we only want to count the compute cost once. 
+ ecn_ptr ecn; + Rematerializer(const Unsafe&, + const rematerialize_function_t& func, + const strongs& inputs, + duration_t compute_cost) : + func(func), + inputs(inputs), + compute_cost(compute_cost) { + } + void release_resources() final { + func = rematerialize_function_t(); + inputs.clear(); + outputs.clear(); + } + void remat(); + ecn_ptr get_ecn(); +}; + +// Track all Tensor that share the same Storage. +// This is the atomic level of eviction - when evicting, everything here will get evicted. +// When an AliasPool is evicted, the Storage of the underlying tensor must be freed. +// Additionally, the AliasPool contain weak pointer to all children of tensors, +// in order to compute the score of evicting a Storage. +struct AliasPool : intrusive_ptr_target { + weaks tensors; + weaks neighbors; + std::set neighbor_ecn(); + // get() might hold some raw Tensor, rendering them unevictable. + // it is likely that get() will run out of memory, and when it does so, it will try to evict. + // so, it is crucial that we dont try to evict those tensors - doing so will not evict anything. + // lock_count count how many time a tensor is referenced by get. + size_t lock_count = 0; + size_t external_count = 0; + void lock() { + ++lock_count; + } + void unlock() { + TORCH_CHECK(lock_count > 0); + --lock_count; + } + intrusive_ptr head_remat; + size_t memory; + time_t last_used_time; + // An aliaspool cant register itself to the checkpointpool - you have to do it yourself. + AliasPool(const Unsafe&, intrusive_ptr head_remat, size_t memory) : + head_remat(head_remat), + memory(memory), + last_used_time(std::chrono::system_clock::now()) { + } + // hold the evicted tensor group if it is evicted. 
+ // is empty if it is not evicted + ecn_ptr ecn; + double cost(time_t current_time); + bool evictable() const { + return lock_count == 0 && head_remat && !ecn; + } + void evict(); + void register_external() { + ++external_count; + } + void release_external() { + TORCH_CHECK(external_count > 0); + --external_count; + if (external_count == 0 && evictable()) { + evict(); + } + } + // if it was evicted, refresh it. otherwise do nothing. + // have to check so, because when we rematerialize a single tensor in an aliaspool, + // we will set it to non-evicted, and when we rematerialize it's tensor they will also reset this. + void set_not_evicted(const intrusive_ptr& self); + void release_resources() final { + tensors.clear(); + neighbors.clear(); + head_remat.reset(); + } +}; + +struct CheckpointTensorCell : intrusive_ptr_target { + std::unique_ptr t; + bool defined = false; + bool is_undefined_tensor; + DispatchKeySet key_set_; + DispatchKeySet key_set() const { + TORCH_CHECK(defined); + return key_set_; + } + caffe2::TypeMeta dtype_; + caffe2::TypeMeta dtype() const { + TORCH_CHECK(defined); + return dtype_; + } + c10::optional optional_device_; + c10::optional optional_device() const { + TORCH_CHECK(defined); + return optional_device_; + } + // A Tensor is evictable iff it's AliasPool is evictable. + // A evictable tensor must have Rematerializer. + intrusive_ptr pool; + intrusive_ptr remat; + void evict() { + TORCH_CHECK(remat); + t.reset(); + } + void fill(const Tensor& t); + explicit CheckpointTensorCell(const Tensor& t, const intrusive_ptr& pool) : pool(pool) { + fill(t); + } + explicit CheckpointTensorCell(const Tensor& t, + const intrusive_ptr& pool, + const intrusive_ptr& remat) : + pool(pool), remat(remat) { + fill(t); + } + size_t memory() { + TORCH_CHECK(defined); + return pool->memory; + } + Tensor get() { + if (! t) { + TORCH_CHECK(remat); + remat->remat(); + } + TORCH_CHECK(t); + TORCH_CHECK(! t->key_set().has(DispatchKey::Checkpoint)); + TORCH_CHECK(! 
t->key_set().has(DispatchKey::Autograd)); + pool->last_used_time = std::chrono::system_clock::now(); + return *t; + } + // pin() makes a cell unevictable. + // This allows us to deallocate the rematerializer, + // and potentially tensors that are only used by the rematerializer. + // todo: pin() in the paper means lock(). find better name? + void pin() { + get(); + pool->head_remat.reset(); + remat.reset(); + } + void release_resources() final { + t.reset(); + pool.reset(); + remat.reset(); + } +}; + +// An external reference. +// Each strong will have at most one external reference. +// By keeping such an invariant, whenever an external reference dies, +// we know that the underlying strong is only used internally. +// Thus, when it dies we can apply optimizations like banishing/infinite staleness. +// We keep this invariant by only allowing CheckpointTensorImpl to make a new External, +// when a new CheckpointTensorImpl is constructed. +struct External : intrusive_ptr_target { + External(const strong& value) : value(value) { + value->pool->register_external(); + } + External(const Tensor& value) : + External(strong::make(value, + intrusive_ptr::make(Unsafe(), + intrusive_ptr(), + memory(value)))) { } + External(const Tensor& value, + const intrusive_ptr& pool, + const intrusive_ptr& remat) : + External(strong::make(value, pool, remat)) { } + strong value; + void release_resources() override; +}; + +struct TORCH_API CheckpointTensorImpl : TensorImpl { + int id = gen_counter(); + static int counter; + static int gen_counter() { + return counter++; + } + std::string counter_name() const { + return std::string("x") + std::to_string(id); + } + + Ref> ref; + + void release_resources() final; + + explicit CheckpointTensorImpl(const Ref>& ref); + + explicit CheckpointTensorImpl(const intrusive_ptr& e) : + CheckpointTensorImpl(Ref>::make(e)) { } + + explicit CheckpointTensorImpl(const Tensor& t); + + Tensor get() const; + + static Tensors make(const std::string& name, + const
rematerialize_function_t& remat, + const Tensors& inputs); + + intrusive_ptr shallow_copy_and_detach(const VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override; + + intrusive_ptr shallow_copy_and_detach(VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override; + + void shallow_copy_from(const c10::intrusive_ptr& impl) override; + + int64_t dim() const override { + return get().dim(); + } + int64_t numel() const override { + return get().numel(); + } + IntArrayRef sizes() const override { + return get().sizes(); + } + int64_t size(int64_t d) const override { + return get().size(d); + } + IntArrayRef strides() const override { + return get().strides(); + } + int64_t stride(int64_t d) const override { + return get().stride(d); + } + bool is_contiguous(at::MemoryFormat memory_format) const override { + return get().is_contiguous(memory_format); + } + bool has_storage() const override { + return false; + } +}; + +// CheckpointPool manages all the CheckpointTensors. +// It allows one to: +// 0: Search over all AliasPools to evict tensors. +// 1: Pin all the tensors. +struct CheckpointPool { + std::vector> aps; + std::vector> exts; + std::random_device rd; + std::mt19937 gen = std::mt19937(rd()); + // whether to take a square-root sample of the pool during an eviction loop + bool sample_tensors = true; + // ignore tensors < 1% of the average tensor size + bool ignore_small_tensors = true; + bool has_memory_budget = false; + long memory_budget; + void evict(); + void auto_evict(); + void clear_checkpointpool(); + void add(const intrusive_ptr&); +}; + +struct CheckpointFunctions; +TORCH_API CheckpointFunctions* GetCheckpointFunctions(); + +// using function pointer to pass through linking boundary +struct CheckpointFunctions { + virtual ~CheckpointFunctions() { } +#define DefineCheckpointFunction(RETURN, NAME, ...)
\ + virtual RETURN NAME(__VA_ARGS__) = 0; \ + static RETURN static_ ## NAME(__VA_ARGS__) { \ + return GetCheckpointFunctions()->NAME(__VA_ARGS__); \ + } + DefineCheckpointFunction(void, new_log, std::string(str)); + DefineCheckpointFunction(void, annotate_log, std::string(str)); + DefineCheckpointFunction(void, toggle_log, bool(log)); + DefineCheckpointFunction(void, clear_checkpointpool); + DefineCheckpointFunction(void, unset_memory_budget); + DefineCheckpointFunction(void, set_memory_budget, long(budget)); + DefineCheckpointFunction(void, toggle_sampling, bool(sample)); + DefineCheckpointFunction(void, toggle_ignore_small_tensors, bool(ignore)); + DefineCheckpointFunction(void, reset_profile); + DefineCheckpointFunction(void, toggle_profile, bool(profile)); + DefineCheckpointFunction(long, base_compute_time); + DefineCheckpointFunction(long, remat_compute_time); + DefineCheckpointFunction(long, compute_time); + DefineCheckpointFunction(long, cost_time); + DefineCheckpointFunction(long, search_time); + DefineCheckpointFunction(long, loop_time); +}; + +} diff --git a/aten/src/ATen/Logger.h b/aten/src/ATen/Logger.h new file mode 100644 index 00000000000..f4ff4a489c9 --- /dev/null +++ b/aten/src/ATen/Logger.h @@ -0,0 +1,128 @@ +#pragma once + +#include +#include +#include <../../../third_party/json/single_include/nlohmann/json.hpp> + +namespace at { + +struct DTRLogger { + std::string time_prefix; + std::ofstream out; + static std::string get_time_prefix() { + std::time_t t = std::time(nullptr); + std::tm* tm = std::localtime(&t); + return + std::to_string(1900+tm->tm_year) + "-" + + std::to_string(1+tm->tm_mon) + "-" + + std::to_string(tm->tm_mday) + "-" + + std::to_string(tm->tm_hour) + "-" + + std::to_string(tm->tm_min) + "-" + + std::to_string(tm->tm_sec); + } + std::string get_filename(const std::string& name) { + return time_prefix + "-" + name + ".log"; + } + DTRLogger() : time_prefix(get_time_prefix()), out(get_filename("default")) { } + void log(const 
std::string& str) { + out << str << std::endl; + } + static DTRLogger& logger() { + static DTRLogger ret; + return ret; + } + +}; + +using json = nlohmann::json; +const std::string INSTRUCTION = "INSTRUCTION"; +const std::string ANNOTATION = "ANNOTATION"; +const std::string RELEASE = "RELEASE"; +const std::string PIN = "PIN"; +const std::string TIME = "TIME"; +const std::string ARGS = "ARGS"; +const std::string MEMORY = "MEMORY"; +const std::string ALIAS = "ALIAS"; +const std::string NAME = "NAME"; +const std::string CONSTANT = "CONSTANT"; + +void DTRLogConstant(const std::string& name) { + json j; + j[INSTRUCTION] = CONSTANT; + j[NAME] = name; + DTRLogger::logger().log(j.dump()); +} + +void DTRLogMemory(const std::string& name, size_t memory) { + json j; + j[INSTRUCTION] = MEMORY; + j[NAME] = name; + j[MEMORY] = std::to_string(memory); + DTRLogger::logger().log(j.dump()); +} + +void DTRLogAlias(const std::string& name, int index) { + json j; + j[INSTRUCTION] = ALIAS; + j[NAME] = name; + j[ALIAS] = std::to_string(index); + DTRLogger::logger().log(j.dump()); +} + +void DTRLogCopyFrom(const std::string& to, const std::string& from) { + json j; + j[INSTRUCTION] = "COPY_FROM"; + j["DST"] = to; + j["SRC"] = from; + DTRLogger::logger().log(j.dump()); +} + +void DTRLogCopy(const std::string& new_name, const std::string& old_name) { + json j; + j[INSTRUCTION] = "COPY"; + j["DST"] = new_name; + j["SRC"] = old_name; + DTRLogger::logger().log(j.dump()); +} + +void DTRLogMutate(const std::string& name, + const std::vector& args, + const std::vector& mutate, + const std::string& time) { + json j; + j[INSTRUCTION] = "MUTATE"; + j[NAME] = name; + j[ARGS] = args; + j["MUTATE"] = mutate; + j[TIME] = time; + DTRLogger::logger().log(j.dump()); +} + +void DTRLogRelease(const std::string& name) { + json j; + j[INSTRUCTION] = RELEASE; + j[NAME] = name; + DTRLogger::logger().log(j.dump()); +} + +void DTRLogPin(const std::string& name) { + json j; + j[INSTRUCTION] = PIN; + j[NAME] = name; 
+ DTRLogger::logger().log(j.dump()); +} + +void DTRLogCall(const std::vector& res, + const std::string& name, + const std::vector& args, + const std::string& time) { + json j; + j[INSTRUCTION] = "CALL"; + j[NAME] = name; + j["RESULT"] = res; + j[ARGS] = args; + j[TIME] = time; + DTRLogger::logger().log(j.dump()); +} + +} diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 602293eec88..c8a4e7ff87b 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1,5 +1,34 @@ # See README.md in this directory for more guidance +# convert a tensor to a checkpoint tensor. +# if the input is already a checkpoint tensor, +# checkpoint will fail while try_checkpoint will +# simply return the input. +- func: checkpoint(Tensor self) -> Tensor + variants: method + +- func: try_checkpoint(Tensor self) -> Tensor + variants: method + +- func: is_checkpoint(Tensor self) -> bool + variants: method + +# convert checkpointed tensor into normal tensor. +# uncheckpoint assume the input is checkpointed and will fail otherwise. +# try_uncheckpoint return the input if it is not checkpointed. +- func: uncheckpoint(Tensor self) -> Tensor + variants: method + +- func: try_uncheckpoint(Tensor self) -> Tensor + variants: method + +# deprecated. call try_uncheckpoint instead. +- func: decheckpoint(Tensor self) -> Tensor + variants: method + +- func: pin(Tensor(a!) self) -> () + variants: method + # *********NB: _cast_* operators are DEPRECATED and will be removed # eventually. These were previously used before TorchScript IR supported # representing ScalarType's. 
They are now superseded by usage of @@ -3939,6 +3968,8 @@ - func: _has_compatible_shallow_copy_type(Tensor self, Tensor from) -> bool variants: function +# dispatch: +# DefaultBackend: _has_compatible_shallow_copy_type - func: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor) variants: function @@ -4782,18 +4813,26 @@ use_c10_dispatcher: hacky_wrapper_for_legacy_signatures variants: method device_guard: False + dispatch: + DefaultBackend: to - func: to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor variants: method device_guard: False + dispatch: + DefaultBackend: to - func: to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor variants: method device_guard: False + dispatch: + DefaultBackend: to - func: to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor variants: method device_guard: False + dispatch: + DefaultBackend: to - func: meshgrid(Tensor[] tensors) -> Tensor[] @@ -5117,6 +5156,8 @@ - func: bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor variants: method, function + dispatch: + CPU, CUDA: bitwise_and - func: bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) variants: method diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index 96867b0be86..3ff1f1fcf71 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -343,6 +343,11 @@ class TORCH_API Tensor { return impl_->device(); } + /// Returns a `Tensor`'s device as an optional, forwarding `TensorImpl::optional_device()`. + inline c10::optional optional_device() const { + return impl_->optional_device(); + } + /// Returns a `Tensor`'s device index.
int64_t get_device() const; diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index 2ae5a87333b..23ca6d95093 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -74,6 +74,9 @@ const char* toString(DispatchKey t) { case DispatchKey::Meta: return "Meta"; + case DispatchKey::Checkpoint: + return "Checkpoint"; + case DispatchKey::Autograd: return "Autograd"; case DispatchKey::AutogradCPU: diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index 40d952c67a4..6c864512362 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -196,6 +196,9 @@ enum class DispatchKey : uint8_t { // constituent parts. Named, + // Checkpoint must go after Autograd. This way, Autograd will hook ad outside of CheckpointTensor. + Checkpoint, + // Note [Alias Dispatch Key : Autograd] // All backends are oblivious to autograd; autograd is handled as a // layer which happens on top of all backends. It inspects the autograd diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 06b3f4fc527..77480b8119b 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -571,6 +571,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return *device_opt_; } + c10::optional optional_device() const { + return device_opt_; + } + Layout layout() const { // NB: This method is not virtual and avoid dispatches for perf. if (is_sparse()) { @@ -1017,6 +1021,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * compatible with SparseCUDA. 
*/ inline bool has_compatible_shallow_copy_type(DispatchKeySet from) { + if (key_set_ == from) { + return true; + } + if (key_set_.has(DispatchKey::Checkpoint) || from.has(DispatchKey::Checkpoint)) { + return false; + } auto is_dense = [](DispatchKeySet ts) { return ts.has(DispatchKey::CPU) || ts.has(DispatchKey::CUDA) || ts.has(DispatchKey::HIP) || ts.has(DispatchKey::XPU); @@ -1026,7 +1036,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { ts.has(DispatchKey::SparseCUDA) || ts.has(DispatchKey::SparseHIP) || ts.has(DispatchKey::SparseXPU); }; - return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) || (is_sparse(key_set_) && is_sparse(from)); + return (is_dense(key_set_) && is_dense(from)) || (is_sparse(key_set_) && is_sparse(from)); } /** diff --git a/third_party/json b/third_party/json new file mode 160000 index 00000000000..19843b038ca --- /dev/null +++ b/third_party/json @@ -0,0 +1 @@ +Subproject commit 19843b038caa463164d6f89ea1b2765fae7552e9 diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index 3dbadae56be..8049d42b105 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -919,6 +919,7 @@ def make_file_manager(install_dir: str) -> FileManager: DispatchKey.QuantizedCUDA, DispatchKey.Math, DispatchKey.DefaultBackend, + DispatchKey.Checkpoint, # Meta is a magic key: it is automatically generated for structured # kernels DispatchKey.Meta, diff --git a/tools/codegen/model.py b/tools/codegen/model.py index 4b03c6899ff..7b49e7c167c 100644 --- a/tools/codegen/model.py +++ b/tools/codegen/model.py @@ -79,6 +79,7 @@ class DispatchKey(Enum): SparseHIP = auto() SparseXPU = auto() NestedTensor = auto() + Checkpoint = auto() PrivateUse1 = auto() PrivateUse2 = auto() PrivateUse3 = auto() diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index 1aef783ee66..320c2ff0a88 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -222,6 +222,9 @@ def _tensor_str(self, indent): if self.numel() == 0: return '[]' + if 
self.is_checkpoint(): + self = self.uncheckpoint() + if self.has_names(): # There are two main codepaths (possibly more) that tensor printing goes through: # - tensor data can fit comfortably on screen diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 6235664707d..d3e23c1d362 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -104,6 +105,28 @@ static PyObject * THPModule_initNames(PyObject *self, PyObject *arg) } Py_RETURN_NONE; } + +void InitCheckpointFunctions(PyObject* module) { + auto py_module = py::reinterpret_borrow(module); +#define PY_FFI(name) py_module.def(#name, at::CheckpointFunctions::static_ ## name) + PY_FFI(new_log); + PY_FFI(annotate_log); + PY_FFI(toggle_log); + PY_FFI(clear_checkpointpool); + PY_FFI(unset_memory_budget); + PY_FFI(set_memory_budget); + PY_FFI(toggle_sampling); + PY_FFI(toggle_ignore_small_tensors); + PY_FFI(reset_profile); + PY_FFI(toggle_profile); + PY_FFI(base_compute_time); + PY_FFI(remat_compute_time); + PY_FFI(compute_time); + PY_FFI(cost_time); + PY_FFI(search_time); + PY_FFI(loop_time); +} + // // Callback for python part. Used for additional initialization of python classes static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manager_path) @@ -141,6 +164,7 @@ static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manag THPComplexDoubleStorage_postInit(module); THPComplexFloatStorage_postInit(module); THPAutograd_initFunctions(); + InitCheckpointFunctions(module); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 78da6841037..0fe4ae60ef1 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -4761,6 +4761,8 @@ def multi_head_attention_forward( q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) if k is not None: + # must be reshape instead of view to work for the above two call. 
+ # todo: this seems to work? k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) if v is not None: v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)