From dc8878f64adabe633e114e4b778c23e252a0c719 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?=
Date: Fri, 6 Mar 2020 11:57:11 -0800
Subject: [PATCH] Backend Cleanroom port (#13)

Implements CheckpointTensor backend

Hook up the overload with a passing test. (#15)
* save
* save
* review
* review
* add back files

Adding some overload (#16)
* save
* save
* review
* remove outdated file
* save
* can compile again
* save
* save
* up up array
* finsih the overloads
* replace vec[n] with vec.at(n)

make resnet pass again, adding more logs for logging simulator (#17)
* commit
* add more overloads
* fix log
* save

Logging everything for resnet (#18)
* save
* save
* save
* save
* use release instead of free

Fix logger (special handling for small constants) (#19)
* save
* fix comment

TreeLSTM overloads (#21)

Overloads for mean and mean.dim (needed for densenet)

Add overloads for U-Net

Add ability to annotate log and make new log file (#25)

Restore overloads needed for LSTM and GRU

Implemented unrolled_gan (#23)
* save
* save
* save

Overload bitwise and for ParityACT

Add various overloads for ACT model

Log aliases of operator outputs

Overloads and changes needed for transformer

Overloads for topk functions

Add overloads for deepspeech

More overloads for ACT

DTR implementation first pass (#37)

save before change
save before change
add aliaspool design, and add weak/strong pointer discussion
add more code
rebase
add allocator hook
save metadata to prepare for eviction
save
refactor: move log to a seprate file
add file raii
save
save comment
save
save
save
save
find error, bisecting code
save
save
address review comment
address comment
address comment
fix segfault
save
save
pin
save
save
save
save
save
refactor - remove stitch
save
save
[ impl ] overload diag & mv
refactor - remove stitch
save
save
restore equivalentclassnode
save
save
save
50% resnet here we go
50% resnet here we go
save
save
save
save
save
save
More opts
save
Try removing pop-ing from loop
save
save
save
save
mike try this
mike try this
save
undo system optimization for now
save
save
save
fix compile error
save
save
---
 .gitmodules                                |    3 +
 aten/src/ATen/CheckpointTensorImpl.cpp     |  661 ++++++++
 aten/src/ATen/CheckpointTensorImpl.h       |  465 ++++++
 aten/src/ATen/Logger.h                     |  197 +++
 aten/src/ATen/native/Activation.cpp        |   90 ++
 aten/src/ATen/native/Checkpoint.cpp        | 1620 ++++++++++++++++++++
 aten/src/ATen/native/TensorShape.cpp       |   46 -
 aten/src/ATen/native/native_functions.yaml |  433 +++++-
 aten/src/ATen/templates/TensorBody.h       |    3 +
 aten/src/ATen/templates/TensorMethods.cpp  |    4 +
 c10/core/Backend.h                         |   87 +-
 c10/core/DispatchKey.cpp                   |    2 +
 c10/core/DispatchKey.h                     |    3 +
 c10/core/TensorImpl.h                      |   12 +-
 test.py                                    |    4 +
 third_party/json                           |    1 +
 tools/autograd/templates/Functions.cpp     |   44 -
 torch/csrc/utils/tensor_numpy.cpp          |    3 +-
 torch/nn/functional.py                     |    8 +-
 19 files changed, 3544 insertions(+), 142 deletions(-)
 create mode 100644 aten/src/ATen/CheckpointTensorImpl.cpp
 create mode 100644 aten/src/ATen/CheckpointTensorImpl.h
 create mode 100644 aten/src/ATen/Logger.h
 create mode 100644 aten/src/ATen/native/Checkpoint.cpp
 create mode 100644 test.py
 create mode 160000 third_party/json

diff --git a/.gitmodules b/.gitmodules
index 509ab94f1cf..ceb096acc9c 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -130,3 +130,6 @@
 	ignore = dirty
 	path = third_party/tensorpipe
 	url = https://github.com/pytorch/tensorpipe.git
+[submodule "third_party/json"]
+	path = third_party/json
+	url = git@github.com:nlohmann/json.git
diff --git
a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp new file mode 100644 index 00000000000..d8a669be4c1 --- /dev/null +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -0,0 +1,661 @@ +#include +#include +#include + +#include +#include +#include +#include + +namespace at { + +using Clock = std::chrono::high_resolution_clock; +using Time = Clock::time_point; +using Duration = Clock::duration; +using FinalTime = std::chrono::nanoseconds; + +struct PerfStats; + +struct Timer { + std::string name; + Time start; + Timer(std::string name, Time start) : name(name), start(start) {} + Timer() {} + ~Timer(); +}; + +constexpr bool stats = true; + +struct PerfStats { + using TimerStats = std::tuple; + Time start; + std::unordered_map calls; + std::vector timers; + + PerfStats() : start(Clock::now()), calls(0), timers() {} + + /*Timer track(std::string name) { + if (stats) { + auto it = this->calls.find(name); + if (it != this->calls.end()) { + it->second += 1; + } else { + this->calls.insert({name, 0}); + } + + return Timer(name, Clock::now()); + } + return Timer(); + }*/ + void track(const char*) { } + + ~PerfStats() { + if (!stats) { return; } + auto start = std::get<1>(this->timers[0]); + auto now = Clock::now(); + std::cout << "All done. Here are some perf stats fresh off the preses." << std::endl; + std::unordered_map durations; + + Duration total = now - this->start; + + // For now simple strategy, count up all the time taken + // by each "tagged call site". + for (auto timer : timers) { + auto name = std::get<0>(timer); + Duration duration = std::get<3>(timer); + auto it = durations.find(name); + if (it != durations.end()) { + it->second += duration; + } else { + durations.insert({name, duration}); + } + } + + std::vector> data; + + // Convert the durations + for (auto d : durations) { + // auto duration = std::chrono::duration_cast(d.second); + data.push_back(d); + } + + std::sort(data.begin(), data.end(), + [](const std::pair & a, const std::pair & b) -> bool { + return a.second > b.second; + }); + + for (auto d : data) { + auto duration = std::chrono::duration_cast(d.second); + auto total_duration = std::chrono::duration_cast(total); + double percentage = ((double)duration.count())/((double)total_duration.count()) * 100; + auto call_count = this->calls.find(d.first); + TORCH_CHECK(call_count != this->calls.end()); + std::cout << "CallSite: " << d.first << " CallCount: " << call_count->second << " Cost: " << duration.count() << "ns" << " (%" << percentage << ")" << std::endl; + } + } +}; + +static PerfStats STATS = PerfStats(); + +size_t memory_sum = 0; +size_t memory_max = 0; +size_t memory_count = 0; + +void reset_memory_stat() { + memory_sum = 0; + memory_max = 0; + memory_count = 0; +} + +inline size_t memory(const Tensor& t) { + if (! 
t.has_storage()) { + return 0; + } + auto& storage = t.storage(); + size_t res = storage.nbytes(); + memory_sum += res; + memory_max = std::max(memory_max, res); + memory_count += 1; + return res; +} + +Timer::~Timer() { + Time now = Clock::now(); + Duration elapsed = now - start; + PerfStats::TimerStats stats = { name , start, now, elapsed }; + STATS.timers.push_back(stats); +} + +CheckpointPool pool; +void CheckpointPool::add(const intrusive_ptr& p) { + if (p->memory > 0 && (memory_count == 0 || p->memory >= 0.01 * double(memory_sum/memory_count))) { + aps.push_back(weak_intrusive_ptr(p)); + } +} + +long current_memory() { + STATS.track("current_memory"); + auto device_stat = c10::cuda::CUDACachingAllocator::getDeviceStats(0); + return device_stat.allocated_bytes[0].current; +} + +void CheckpointPool::auto_evict() { + STATS.track("CheckpointPool::auto_evict"); + if (has_memory_budget) { + while (current_memory() > memory_budget) { + evict(); + } + } +} + +void CheckpointPool::evict() { + STATS.track("CheckpointPool::evict"); + TORCH_CHECK(aps.size() > 0); + bool shrinked = false; + int evict_idx = -1; + double evict_cost = INFINITY; + time_t current_time = std::chrono::system_clock::now(); + auto remove_from_aps = [&](size_t i) { + aps[i] = aps[aps.size() - 1]; + aps.pop_back(); + }; + std::uniform_int_distribution<> distrib(1, 1 * std::max(1, static_cast(std::sqrt(aps.size())))); + // sampling a random independent subset of all evictable tensors to find the cheapest tensor to evict. + for (size_t i = 0; i < aps.size();) { + auto cannot_evict = [&]() { + shrinked = true; + remove_from_aps(i); + }; + auto ap_strong = aps[i].lock(); + if (!ap_strong.defined()) { + cannot_evict(); + } + else if (ap_strong->ecn) { + cannot_evict(); + } + else { + if (ap_strong->evictable()) { + double cost = ap_strong->cost(current_time); + if (cost < evict_cost) { + evict_cost = cost; + evict_idx = i; + } + } + i += distrib(gen); + } + } + if (evict_idx == -1) { + TORCH_CHECK(shrinked); + } else { + auto evict_from_idx = [&](size_t idx) { + auto ap_strong = aps[idx].lock(); + TORCH_CHECK(ap_strong.defined()); + ap_strong->evict(); + remove_from_aps(evict_idx); + }; + evict_from_idx(evict_idx); + } +} + +CheckpointPool::CheckpointPool() { } + +bool use_log = false; +long compute_time_ = 0; + +namespace native { + +Tensor checkpoint(const Tensor& t) { + STATS.track("checkpoint"); + auto cpti = intrusive_ptr::make(t); + if (use_log) { + DTRLogConstant(cpti->counter_name()); + DTRLogMemory(cpti->counter_name(), cpti->ref->value->value->memory()); + } + return Tensor(cpti); +} + +Tensor uncheckpoint(const Tensor& t) { + STATS.track("uncheckpoint"); + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + TORCH_CHECK(cpti != nullptr); + return cpti->ref->value->value->get(); +} + +void pin(const Tensor& t) { + STATS.track("pin"); + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + TORCH_CHECK(cpti != nullptr); + cpti->ref->value->value->pin(); +} + +Tensor decheckpoint(const Tensor& t) { + STATS.track("decheckpoint"); + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + return cpti ? cpti->ref->value->value->get() : t; +} + +bool is_checkpoint(const Tensor& t) { + STATS.track("is_checkpoint"); + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + return cpti != nullptr; +} + +Tensor try_checkpoint(const Tensor& t) { + STATS.track("try_checkpiont"); + return is_checkpoint(t) ? 
t : checkpoint(t); +} + +void new_log(std::string str) { + DTRLogger::logger().out = std::ofstream(DTRLogger::logger().get_filename(str)); +} + +void annotate_log(std::string str) { + if (!use_log) { return; } + if (log_json) { + json j; + j[INSTRUCTION] = "ANNOTATE"; + j[ANNOTATION] = str; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log("# " + str); + } +} + +void toggle_log(bool b) { + use_log = b; +} + +void clear_checkpointpool() { + while (likely(!pool.exts.empty())) { + if (auto e = pool.exts.back().lock()) { + e->value->pin(); + } + pool.exts.pop_back(); + } +} + +void unset_memory_budget() { + pool.has_memory_budget = false; +} + +void set_memory_budget(long budget) { + pool.memory_budget = budget; + pool.has_memory_budget = true; +} + +void reset_compute_time() { + compute_time_ = 0; +} + +long compute_time() { + return compute_time_; +} + +} + +[[inline]] +Tensor uncheckpoint(const strong& input) { + return input->get(); +} + +Tensors uncheckpoint(const strongs& inputs) { + STATS.track("uncheckpoint"); + Tensors ret; + ret.reserve(inputs.size()); + for (const strong& input : inputs) { + // inlined manually + ret.push_back(input->get()); + } + return ret; +}; + +Tensors try_checkpoint(const Tensors& inputs) { + STATS.track("try_checkpoint"); + Tensors ret; + ret.reserve(inputs.size()); + for (const Tensor& input : inputs) { + ret.push_back(at::native::try_checkpoint(input)); + } + return ret; +} + +CheckpointInfo merge_cpi(CheckpointInfo l, CheckpointInfo r) { + STATS.track("merge_cpi"); + return CheckpointInfo(l.compute_cost + r.compute_cost); +} + +void AliasPool::evict() { + STATS.track("AliasPool::evict"); + TORCH_CHECK(!ecn); + ecn = head_remat->get_ecn(); + auto ecns = neighbor_ecn(); + for (const auto& necn : ecns) { + merge(merge_cpi, ecn, necn); + } + // cudacaching allocator might be dead when program finished and is deallocating resources. + // auto b4 = current_memory(); + TORCH_CHECK(memory > 0); + TORCH_CHECK(lock_count == 0); + TORCH_CHECK(!is_evicted); + is_evicted = true; + for (const weak& w : tensors) { + if (auto cell = w.lock()) { + cell->evict(); + } + } + // TORCH_CHECK(current_memory() < b4); +} + +double AliasPool::cost(time_t current_time) { + auto cpi = head_remat->get_cpi(); + auto ecns = neighbor_ecn(); + for (const auto& necn : ecns) { + cpi = merge_cpi(cpi, get_t(necn)); + } + return cpi.cost(memory, (current_time - last_used_time).count()); +} + +void External::release_resources() { + value->pool->release_external(); + value.reset(); +} + +void Rematerializer::remat() { + // TODO: refactor using RAII for exception safety. + for (const strong& s : inputs) { + s->pool->lock(); + } + Tensors ts = uncheckpoint(inputs); + time_t pre = std::chrono::system_clock::now(); + auto ret = func(ts); + time_t post = std::chrono::system_clock::now(); + pool.auto_evict(); + compute_time_ += (post - pre).count(); + TORCH_CHECK(ret.size() == outputs.size()); + for (size_t i = 0; i < outputs.size(); ++i) { + if (auto output_cell = outputs[i].lock()) { + output_cell->fill(ret[i]); + } + } + ecn.reset(); + for (const strong& s : inputs) { + s->pool->unlock(); + } +} + +ecn_ptr Rematerializer::get_ecn() { + if (!ecn) { + ecn = ecn_ptr::make(CheckpointInfo(compute_cost)); + } + return ecn; +} + +CheckpointInfo Rematerializer::get_cpi() { + return CheckpointInfo(ecn ? 
duration_t(0) : compute_cost); +} + +std::set AliasPool::neighbor_ecn() { + STATS.track("AliasPool::neighbor_ecn"); + std::set ptr_set; + int size = neighbors.size(); + for (size_t i = 0; i < size;) { + if (auto cptc = neighbors[i].lock()) { + if (cptc->pool->ecn) { + ptr_set.insert(cptc->pool->ecn); + } + ++i; + } else { + neighbors[i] = neighbors[size - 1]; + size = size - 1; + } + } + if (size < neighbors.size()) { + neighbors.erase(neighbors.begin() + size); + } + return ptr_set; +} + +void AliasPool::set_not_evicted(const intrusive_ptr& self) { + if (unlikely(is_evicted)) { + STATS.track("AliasPool::set_not_evicted(inside)"); + is_evicted = false; + if (ecn) { + TORCH_CHECK(head_remat); + auto cpi = get_t(ecn); + update_t(ecn, CheckpointInfo(cpi.compute_cost - head_remat->compute_cost)); + ecn.reset(); + } + pool.add(self); + } +} + +void CheckpointTensorCell::fill(const Tensor& t) { + STATS.track("CheckpointTensorCell::fill"); + if (!(this->t)) { + this->t = std::make_unique(t.detach()); + pool->set_not_evicted(pool); + if (!defined) { + defined = true; + is_undefined_tensor = !t.defined(); + key_set_ = t.key_set(); + if (t.requires_grad()) { + key_set_ = key_set_.add(DispatchKey::Autograd); + } + dtype_ = t.dtype(); + optional_device_ = t.optional_device(); + } + } +} + +intrusive_ptr CheckpointTensorImpl::shallow_copy_and_detach(const VariableVersion& version_counter, + bool allow_tensor_metadata_change) const { + auto ret = intrusive_ptr::make(ref); + if (use_log) { + DTRLogCopy(ret->counter_name(), counter_name()); + } + return ret; +} + +void CheckpointTensorImpl::shallow_copy_from(const c10::intrusive_ptr& impl) { + STATS.track("CheckpointTensorCell::shallow_copy_from"); + TORCH_CHECK(impl->key_set().has(DispatchKey::CheckpointTensorId)); + auto* cpti = dynamic_cast(impl.get()); + TORCH_CHECK(cpti != nullptr); + ref->value = cpti->ref->value; + if (use_log) { + DTRLogCopyFrom(counter_name(), cpti->counter_name()); + } +} + +int CheckpointTensorImpl::counter = 0; + +bool is_alias(const Tensor& l, const Tensor& r) { + return l.defined() && r.defined() && l.is_alias_of(r); +} + +// return an index for alias. +// we dont care which one because they all lead to the same alias pool. +// return -1 for no alias. +// may god forgive my sin. +int get_alias(const Tensors& ts, const Tensor& t) { + if (t.defined()) { + for (size_t i = 0; i < ts.size(); ++i) { + if (ts[i].defined() && t.is_alias_of(ts[i])) { + return i; + } + } + } + return -1; +} + +struct MakeRawResult { + std::vector> outputs; + std::vector aliases; + duration_t time; + intrusive_ptr rematerializer; +}; + +void add_neighbor(const strong& l, const strong& r) { + l->pool->neighbors.push_back(weak(r)); + r->pool->neighbors.push_back(weak(l)); +} + +// remat take a single vector of tensors, +// while there are two vector, one storing nonconstants and one storing constants. +// the constants are small and they will not be considered for eviction. +// however, we have to stitch the two vectors together to pass it in remat. +// the size_t in constants decide the location to stitch them in, while input_values fill in the rest. 
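
The ranking used by CheckpointPool::evict and AliasPool::cost above is the DTR heuristic: a storage is cheap to evict when it is fast to rematerialize (including the cost of already-evicted neighbors merged through the equivalence-class nodes), large, and stale. A minimal standalone sketch of that score, using a hypothetical EvictionCandidate record rather than the patch's AliasPool:

#include <chrono>
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-in for an evictable storage; not a type from this patch.
struct EvictionCandidate {
  double compute_cost_ns;   // time to rematerialize it (plus evicted neighbors)
  double memory_bytes;      // bytes freed by evicting it
  std::chrono::steady_clock::time_point last_used;
};

// DTR score: compute_cost / (memory * staleness). Lower means "evict me first".
double eviction_score(const EvictionCandidate& c,
                      std::chrono::steady_clock::time_point now) {
  double staleness_ns =
      std::chrono::duration<double, std::nano>(now - c.last_used).count();
  return c.compute_cost_ns / (c.memory_bytes * staleness_ns);
}

int main() {
  auto now = std::chrono::steady_clock::now();
  std::vector<EvictionCandidate> candidates = {
      {1e6, 4 << 20, now - std::chrono::milliseconds(10)},  // cheap, big, stale
      {5e6, 1 << 20, now - std::chrono::milliseconds(1)},   // costly, small, fresh
  };
  std::size_t best = 0;
  for (std::size_t i = 1; i < candidates.size(); ++i) {
    if (eviction_score(candidates[i], now) < eviction_score(candidates[best], now)) {
      best = i;
    }
  }
  std::cout << "evict candidate #" << best << "\n";  // prints 0 here
}
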
+MakeRawResult make_raw(const rematerialize_function_t& remat_f, + const strongs& inputs) { + STATS.track("make_raw"); + for (const strong& s : inputs) { + s->pool->lock(); + } + Tensors raw_inputs = uncheckpoint(inputs); + time_t pre = std::chrono::system_clock::now(); + auto raw_outputs = remat_f(raw_inputs); + time_t post = std::chrono::system_clock::now(); + pool.auto_evict(); + compute_time_ += (post - pre).count(); + std::vector> outputs; + std::vector aliases; + weaks weak_outputs; + auto remat = intrusive_ptr::make(Unsafe(), remat_f, inputs, post - pre); + + for (const Tensor& t : raw_outputs) { + intrusive_ptr alias_pool; + int alias = get_alias(raw_inputs, t); + if (alias == -1) { + auto m = memory(t); + alias_pool = intrusive_ptr::make(Unsafe(), remat, m); + pool.add(alias_pool); + } + else { + alias_pool = inputs[alias]->pool; + if (alias_pool->head_remat) { + alias_pool->head_remat->compute_cost += (post - pre); + } + } + auto e = intrusive_ptr::make(t, alias_pool, remat); + pool.exts.push_back(weak_intrusive_ptr(e)); + alias_pool->tensors.push_back(weak(e->value)); + outputs.push_back(e); + aliases.push_back(alias); + weak_outputs.push_back(weak(outputs.back()->value)); + } + remat->outputs = weak_outputs; + for (size_t i = 0; i < inputs.size(); ++i) { + for (size_t j = 0; j < outputs.size(); ++j) { + if (!is_alias(raw_inputs[i], raw_outputs[j])) { + add_neighbor(inputs[i], outputs[j]->value); + } + } + } + for (const strong& s : inputs) { + s->pool->unlock(); + } + return {outputs, aliases, post - pre, remat}; +} + +std::string from_time(duration_t t) { + return std::to_string(std::chrono::nanoseconds(t).count()); +} + +Tensors CheckpointTensorImpl::make(const std::string& name, + const rematerialize_function_t& remat, + const Tensors& inputs) { + STATS.track("CheckPointTensorImpl::make"); + Tensors checkpointed_inputs = try_checkpoint(inputs); + auto input_size = checkpointed_inputs.size(); + + strongs input_values; + input_values.reserve(input_size); + + std::vector args; + args.reserve(input_size); + + for (const Tensor& t: checkpointed_inputs) { + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + TORCH_CHECK(cpti); + input_values.push_back(cpti->ref->value->value); + if (use_log) { + args.push_back(cpti->counter_name()); + } + } + + auto ret = make_raw(remat, input_values); + + Tensors tensors; + tensors.reserve(ret.outputs.size()); + + for (const auto& t: ret.outputs) { + auto cp = Tensor(intrusive_ptr::make(t)); + tensors.push_back(cp); + } + + if (use_log) { + std::vector res; + res.reserve(ret.outputs.size()); + + for (const auto& tensor : tensors) { + res.push_back(get_cpti(tensor)->counter_name()); + } + + DTRLogCall(res, name, args, from_time(ret.time)); + for (size_t i = 0; i < tensors.size(); ++i) { + Tensor t = tensors[i]; + auto cpti = get_cpti(t); + DTRLogMemory(cpti->counter_name(), cpti->ref->value->value->memory()); + DTRLogAlias(cpti->counter_name(), ret.aliases[i]); + } + } + + return tensors; +} + +// TODO: check that mutated value does not have alias. 
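
make_raw above is the core of the system: it re-runs the operator on uncheckpointed inputs, wraps each output in a cell, and attaches a Rematerializer so the same closure can be replayed after the output is evicted. A minimal standalone sketch of that cache/evict/recompute cycle, assuming a hypothetical RematCell type that is not part of this patch:

#include <cstddef>
#include <functional>
#include <iostream>
#include <optional>
#include <vector>

// Hypothetical illustration of a checkpointed cell; not a type from this patch.
struct RematCell {
  std::function<std::vector<float>()> recompute;  // captured operator + inputs
  std::optional<std::vector<float>> cached;       // the evictable payload

  const std::vector<float>& get() {
    if (!cached) cached = recompute();  // rematerialize on demand
    return *cached;
  }
  void evict() { cached.reset(); }      // free the memory, keep the closure
};

int main() {
  std::vector<float> input = {1, 2, 3};
  RematCell cell{[input] {
                   std::vector<float> out(input.size());
                   for (std::size_t i = 0; i < input.size(); ++i) out[i] = input[i] * 2;
                   return out;
                 },
                 std::nullopt};
  std::cout << cell.get()[2] << "\n";  // computes and caches: prints 6
  cell.evict();                        // drop the cached value
  std::cout << cell.get()[2] << "\n";  // transparently recomputed: prints 6
}

The real implementation differs in that the closure's inputs are themselves checkpointed cells, so a get() during rematerialization can recurse into further rematerializations, and the elapsed time is charged back into the eviction cost model.
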
+void CheckpointTensorImpl::mutate(const std::string& name, + const mutate_function_t& mutate, + const Tensors& inputs, + const std::vector& mutate_idx) { + auto remat = [=](const Tensors& t) -> Tensors { + Tensors new_input_values = t; + for (size_t idx: mutate_idx) { + new_input_values[idx] = t[idx].clone(); + } + mutate(new_input_values); + return new_input_values; + }; + Tensors checkpointed_inputs = try_checkpoint(inputs); + strongs input_values; + std::vector args; + for (const Tensor& t: checkpointed_inputs) { + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + TORCH_CHECK(cpti); + input_values.push_back(cpti->ref->value->value); + if (use_log) { + args.push_back(cpti->counter_name()); + } + } + auto ret = make_raw(remat, input_values); + const auto& modified = ret.outputs; + for (size_t idx: mutate_idx) { + cell_from_tensor(inputs[idx])->value = modified[idx]; + } + if (use_log) { + DTRLogMutate(name, args, mutate_idx, from_time(ret.time)); + } +} + +void CheckpointTensorImpl::release_resources() { + if (use_log) { + DTRLogRelease(counter_name()); + } + ref.reset(); +} + +CheckpointTensorImpl::CheckpointTensorImpl(const Tensor& t) : CheckpointTensorImpl(intrusive_ptr::make(t)) { + pool.exts.push_back(weak_intrusive_ptr(ref->value)); +} + +} diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h new file mode 100644 index 00000000000..7eec97e8e75 --- /dev/null +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -0,0 +1,465 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#define likely(x) __builtin_expect(!!(x), 1) +#define unlikely(x) __builtin_expect(!!(x), 0) +#define TORCH_CHECK(a, ...) // profile mode + +// System Description: +// Every Tensor is managed by a CheckpointTensor, +// that describe how it is computed, (the function and the inputs) +// And might optionally hold the tensor value. +// The tensor value might be dropped, and when requested later, recomputed and cached again. + +// Corner Cases: +// A CheckpointedTensor might require_grad. +// In this case the underlying data must not require_grad, +// as we want backpropagation on the outer, uncheckpoined level. +// To be more specific, suppose a tensor is recomputed multiple times. +// We want to only compute the gradient exactly once. +// To do this, the wrapper must be require_grad, and the wrapped value must not. +// A CheckpointedTensor might be constant. +// In this case it is unevictable. +// An operator might return multiple output. +// In this case the computation info (rematerializer) is shared between all of them, +// And when the function get computed again all value get cached. +// An operator might not return value, but only mutate input value. +// To combat this, we COW the operator, and wrap CheckpopintTensor with a Ref. +// By doing this the inner CheckpointTensor is kept purely functional. +// An operator might try to mutate uncheckpointed tensor. +// We do not support this and will error. +// An operator might create aliases. +// We track alias in AliasPool. +// Each AliasPool hold a set of tensor that is alias to eachother. +// An operator might try to create Alias to an unevictable tensor. +// In such a case the output tensor is unevictable. +// An operator might try to mutate Tensor with Alias. +// We do not support this case an will error if a Tensor has any alive Alias. 
+// However it could be done without a major redesign of the system - +// Each AliasPool will hold weak pointers to the External Reference. +// When alias mutation occur, +// we make a rematerialize_function that take in the base tensor (other tensor alias from) +// and output all the new value of the aliases, then update the Ref. +// Of course, the cleaner way is to not support this. +// Shame on those who use this feature. + +// Memory Safety: +// The objects here will have lots of backedges. +// In order to collect memory when computation is completed, +// We require that all strong pointer is of the form of value -> input. +// This ensure that everything will be released if there is no external ref whatsoever. + +// Optimization: +// We treat tensor that has no external reference differently - +// They will never be externally used again so we assume their next use time is infinite +// so, if it doesnt has any evicted neighbor it will get evicted immediately. + +// Note: to code fast I do not use RAII and just assume the code will not try to recover from exception. +// It should be easy to fix though. + +namespace at { + +// TODO: using a pool allocator might make more sense - no need to allocate and delete each pointer individually. +template +struct EquivalentClassNode : intrusive_ptr_target { + explicit EquivalentClassNode(const T& t) : t_unsafe(t) { } + mutable intrusive_ptr parent; + bool is_root() { + return !parent; + } + void release_resources() override { + parent.reset(); + } + T t_unsafe; +}; + +template +T& get_t(const intrusive_ptr>& n) { + return find_root(n)->t_unsafe; +} + +template +static void update_t(const intrusive_ptr>& n, const T& t) { + find_root(n)->t_unsafe = t; +} + +template +intrusive_ptr> find_root(const intrusive_ptr>& n) { + if (n->is_root()) { + return n; + } else { + n->parent = find_root(n->parent); + return n->parent; + } +} + +template +intrusive_ptr> merge(const std::function& merge_t, + const intrusive_ptr>& lhs, + const intrusive_ptr>& rhs) { + auto l = find_root(lhs); + auto r = find_root(rhs); + if (l == r) { + return l; + } + l->parent = r; + r->t_unsafe = merge_t(l->t_unsafe, r->t_unsafe); + return r; +} + +size_t memory(const Tensor& t); + +template +struct RefCell final : intrusive_ptr_target { + mutable T value; + void release_resources() final { + static_release_resources(value); + } + RefCell(const T& t) : value(t) { } +}; + +template +using Ref = intrusive_ptr>; + +template +void static_release_resources(intrusive_ptr& ptr) { + ptr.reset(); +} + +class CheckpointTensorCell; +using strong = intrusive_ptr; +using strongs = std::vector; +using weak = weak_intrusive_ptr; +using weaks = std::vector; +using Tensors = std::vector; +using rematerialize_function_t = std::function; +using mutate_function_t = std::function; + +using time_t = std::chrono::time_point; +using duration_t = std::chrono::system_clock::duration; +struct CheckpointInfo { + duration_t compute_cost; + // @ZACH: Floating Point instability? + double cost(size_t memory, size_t staleness) const { + TORCH_CHECK(memory > 0); + TORCH_CHECK(staleness > 0); + return compute_cost.count() / static_cast(memory * staleness); + } + CheckpointInfo(duration_t compute_cost) : + compute_cost(compute_cost) { + } +}; + +// ecn represent a evicted tensor group. +// it is a set of tensor that are evicted, and if two evicted tensor are input -> output to each other, +// they must be in an ecn. +// note: we try to support removal from ecn by subtracting compute_cost and memory. 
+// this will create suprious connection but that should be fine empircally. +// below is an example of a suprious connection: +// a -> b, a -> c +// a, b, c got evicted so belong to a single ecn. +// a got rematerialized. +// b, c still belong to a single ecn although there is no connection. +using ecn_ptr = intrusive_ptr>; + +struct Unsafe { }; + +// The rematerializer could be called to reinvoke an operator. +// Tensor point to remat which point to Tensor. +// To build the cycle remat support a default constructor, +// And allow you to fill in the member later. +struct Rematerializer : intrusive_ptr_target { + rematerialize_function_t func; + strongs inputs; + weaks outputs; + duration_t compute_cost; + // when some output in here get evicted, they should belong to this ecn. + // a rematerializer have to track this, + // because when multiple output of a rematerializer get evicted, + // we only want to count the compute cost once. + ecn_ptr ecn; + Rematerializer(const Unsafe&, + const rematerialize_function_t& func, + const strongs& inputs, + duration_t compute_cost) : + func(func), + inputs(inputs), + compute_cost(compute_cost) { + } + void release_resources() final { + func = rematerialize_function_t(); + inputs.clear(); + outputs.clear(); + } + void remat(); + ecn_ptr get_ecn(); + CheckpointInfo get_cpi(); +}; + +// Track all Tensor that share the same Storage. +// This is the atomic level of eviction - when evicting, everything here will get evicted. +// When an AliasPool is evicted, the Storage of the underlying tensor must be freed. +// Additionally, the AliasPool contain weak pointer to all children of tensors, +// in order to compute the score of evicting a Storage. +struct AliasPool : intrusive_ptr_target { + weaks tensors; + weaks neighbors; + std::set neighbor_ecn(); + // get() might hold some raw Tensor, rendering them unevictable. + // it is likely that get() will run out of memory, and when it does so, it will try to evict. + // so, it is crucial that we dont try to evict those tensors - doing so will not evict anything. + // lock_count count how many time a tensor is referenced by get. + size_t lock_count = 0; + size_t external_count = 0; + void lock() { + ++lock_count; + } + void unlock() { + --lock_count; + } + intrusive_ptr head_remat; + bool evictable() const { + return lock_count == 0 && head_remat; + } + // if it is not evictable it must not be evicted. + bool is_evicted = false; + size_t memory; + time_t last_used_time; + // An aliaspool cant register itself to the checkpointpool - you have to do it yourself. + AliasPool(const Unsafe&, intrusive_ptr head_remat, size_t memory) : + head_remat(head_remat), + memory(memory), + last_used_time(std::chrono::system_clock::now()) { + } + // if it is evicted, then hold the evicted tensor group. + ecn_ptr ecn; + double cost(time_t current_time); + void evict(); + void register_external() { + ++external_count; + } + void release_external() { + --external_count; + if (external_count == 0) { + if (lock_count > 0) {return;} + TORCH_CHECK(lock_count == 0); + if (memory > 0 && (!ecn) && head_remat) { + evict(); + } + } + } + // if it was evicted, refresh it. otherwise do nothing. + // have to check so, because when we rematerialize a single tensor in an aliaspool, + // we will set it to non-evicted, and when we rematerialize it's tensor they will also reset this. 
+ void set_not_evicted(const intrusive_ptr& self); + void release_resources() final { + tensors.clear(); + neighbors.clear(); + head_remat.reset(); + } +}; + +struct CheckpointTensorCell : intrusive_ptr_target { + std::unique_ptr t; + bool defined = false; + bool is_undefined_tensor; + DispatchKeySet key_set_; + DispatchKeySet key_set() const { + TORCH_CHECK(defined); + return key_set_; + } + caffe2::TypeMeta dtype_; + caffe2::TypeMeta dtype() const { + TORCH_CHECK(defined); + return dtype_; + } + c10::optional optional_device_; + c10::optional optional_device() const { + TORCH_CHECK(defined); + return optional_device_; + } + // A Tensor is evictable iff it's AliasPool is evictable. + // A evictable tensor must have Rematerializer. + intrusive_ptr pool; + intrusive_ptr remat; + void evict() { + TORCH_CHECK(remat); + t.reset(); + } + void fill(const Tensor& t); + explicit CheckpointTensorCell(const Tensor& t, const intrusive_ptr& pool) : pool(pool) { + fill(t); + } + explicit CheckpointTensorCell(const Tensor& t, + const intrusive_ptr& pool, + const intrusive_ptr& remat) : + pool(pool), remat(remat) { + fill(t); + } + size_t memory() { + TORCH_CHECK(defined); + return pool->memory; + } + Tensor get() { + if (! t) { + TORCH_CHECK(remat); + remat->remat(); + } + TORCH_CHECK(t); + TORCH_CHECK(! t->key_set().has(DispatchKey::CheckpointTensorId)); + pool->last_used_time = std::chrono::system_clock::now(); + return *t; + } + void pin() { + get(); + pool->head_remat.reset(); + remat.reset(); + } + void release_resources() final { + t.reset(); + pool.reset(); + remat.reset(); + } +}; + +// An external reference. +// Each strong will have at most one external reference. +// By keeping such an invariant, whenever an external reference die, +// We know that the underlying strong is only used internally. +// Thus, when it die we can apply optimization like banishing/infinite staleness. +// We keep this invariant by only allowing CheckpointTensorImpl to make new External, +// When new CheckpointTensorImpl is constructed. 
+struct External : intrusive_ptr_target { + External(const strong& value) : value(value) { + value->pool->register_external(); + } + External(const Tensor& value) : + External(strong::make(value, + intrusive_ptr::make(Unsafe(), + intrusive_ptr(), + memory(value)))) { } + External(const Tensor& value, + const intrusive_ptr& pool, + const intrusive_ptr& remat) : + External(strong::make(value, pool, remat)) { } + strong value; + void release_resources() override; +}; + +inline DispatchKeySet convert_key_set(const DispatchKeySet& t) { + CHECK(!t.has(DispatchKey::Checkpoint)); + auto ret = t.add(DispatchKey::Checkpoint); + return ret; +} + +struct CheckpointTensorImpl : TensorImpl { + int id = gen_counter(); + static int counter; + static int gen_counter() { + return counter++; + } + std::string counter_name() const { + return std::string("x") + std::to_string(id); + } + + Ref> ref; + + void release_resources() final; + + explicit CheckpointTensorImpl(const Ref>& ref) : + TensorImpl(convert_key_set(ref->value->value->key_set()), + ref->value->value->dtype(), + ref->value->value->optional_device()), + ref(ref) { + if (key_set().has(DispatchKey::Autograd)) { + set_requires_grad(true); + } + } + + explicit CheckpointTensorImpl(const intrusive_ptr& e) : + CheckpointTensorImpl(Ref>::make(e)) { } + + explicit CheckpointTensorImpl(const Tensor& t); + + static Tensors make(const std::string& name, + const rematerialize_function_t& remat, + const Tensors& inputs); + + // mutate_idx indicate which of the inputs will get mutated. + static void mutate(const std::string& name, + const mutate_function_t& mutate, + const Tensors& inputs, + const std::vector& mutate_idx); + intrusive_ptr shallow_copy_and_detach(const VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override; + void shallow_copy_from(const c10::intrusive_ptr& impl) override; + int64_t dim() const override { + return ref->value->value->get().dim(); + } + int64_t numel() const override { + return ref->value->value->get().numel(); + } + IntArrayRef sizes() const override { + return ref->value->value->get().sizes(); + } + int64_t size(int64_t d) const override { + return ref->value->value->get().size(d); + } + IntArrayRef strides() const override { + return ref->value->value->get().strides(); + } + int64_t stride(int64_t d) const override { + return ref->value->value->get().stride(d); + } + bool has_storage() const override { + return false; + } +}; + +// CheckpointPool keep a list of AliasPool, and search over them to choose the best one to evict. 
+struct CheckpointPool { + std::vector> aps; + std::vector> exts; + std::random_device rd; + std::mt19937 gen = std::mt19937(rd()); + bool has_memory_budget = false; + long memory_budget; + void evict(); + void auto_evict(); + void clear_checkpointpool(); + void add(const intrusive_ptr&); + CheckpointPool(); +}; + +inline CheckpointTensorImpl* get_cpti(const Tensor& t) { + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + TORCH_CHECK(cpti != nullptr); + return cpti; +} + +inline Ref> cell_from_tensor(const Tensor& t) { + return get_cpti(t)->ref; +} + +} diff --git a/aten/src/ATen/Logger.h b/aten/src/ATen/Logger.h new file mode 100644 index 00000000000..7e3c01a9af9 --- /dev/null +++ b/aten/src/ATen/Logger.h @@ -0,0 +1,197 @@ +#pragma once + +#include +#include +#include <../../../third_party/json/single_include/nlohmann/json.hpp> + +namespace at { + +struct DTRLogger { + std::string time_prefix; + std::ofstream out; + static std::string get_time_prefix() { + std::time_t t = std::time(nullptr); + std::tm* tm = std::localtime(&t); + return + std::to_string(1900+tm->tm_year) + "-" + + std::to_string(1+tm->tm_mon) + "-" + + std::to_string(tm->tm_mday) + "-" + + std::to_string(tm->tm_hour) + "-" + + std::to_string(tm->tm_min) + "-" + + std::to_string(tm->tm_sec); + } + std::string get_filename(const std::string& name) { + return time_prefix + "-" + name + ".log"; + } + DTRLogger() : time_prefix(get_time_prefix()), out(get_filename("default")) { } + void log(const std::string& str) { + out << str << std::endl; + } + static DTRLogger& logger() { + static DTRLogger ret; + return ret; + } + +}; + +using json = nlohmann::json; +constexpr bool log_json = true; +const std::string INSTRUCTION = "INSTRUCTION"; +const std::string ANNOTATION = "ANNOTATION"; +const std::string RELEASE = "RELEASE"; +const std::string PIN = "PIN"; +const std::string TIME = "TIME"; +const std::string ARGS = "ARGS"; +const std::string MEMORY = "MEMORY"; +const std::string ALIAS = "ALIAS"; +const std::string NAME = "NAME"; +const std::string CONSTANT = "CONSTANT"; + +void DTRLogConstant(const std::string& name) { + if (log_json) { + json j; + j[INSTRUCTION] = CONSTANT; + j[NAME] = name; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(CONSTANT + " " + name); + } +} + +void DTRLogMemory(const std::string& name, size_t memory) { + if (log_json) { + json j; + j[INSTRUCTION] = MEMORY; + j[NAME] = name; + j[MEMORY] = std::to_string(memory); + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(name + " " + MEMORY + ": " + std::to_string(memory)); + } +} + +void DTRLogAlias(const std::string& name, int index) { + if (log_json) { + json j; + j[INSTRUCTION] = ALIAS; + j[NAME] = name; + j[ALIAS] = std::to_string(index); + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(name + " " + ALIAS + ": " + std::to_string(index)); + } +} + +void DTRLogCopyFrom(const std::string& to, const std::string& from) { + if (log_json) { + json j; + j[INSTRUCTION] = "COPY_FROM"; + j["DST"] = to; + j["SRC"] = from; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(to + " <- " + from); + } +} + +void DTRLogCopy(const std::string& new_name, const std::string& old_name) { + if (log_json) { + json j; + j[INSTRUCTION] = "COPY"; + j["DST"] = new_name; + j["SRC"] = old_name; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(new_name + " = " + old_name); + } +} + +void DTRLogMutate(const std::string& name, + const std::vector& args, + const 
std::vector& mutate, + const std::string& time) { + if (log_json) { + json j; + j[INSTRUCTION] = "MUTATE"; + j[NAME] = name; + j[ARGS] = args; + j["MUTATE"] = mutate; + j[TIME] = time; + DTRLogger::logger().log(j.dump()); + } else { + std::string log = name; + log += "("; + for (const auto& s : args) { + log += s; + log += ", "; + } + log += ") "; + log += " MUTATING: "; + log += "("; + for (const size_t i : mutate) { + log += std::to_string(i); + log += ", "; + } + log += ") "; + log += TIME; + log += ": "; + log += time; + DTRLogger::logger().log(log); + } +} + +void DTRLogRelease(const std::string& name) { + if (log_json) { + json j; + j[INSTRUCTION] = RELEASE; + j[NAME] = name; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(RELEASE + ": " + name); + } +} + +void DTRLogPin(const std::string& name) { + if (log_json) { + json j; + j[INSTRUCTION] = PIN; + j[NAME] = name; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(RELEASE + ": " + name); + } +} + +void DTRLogCall(const std::vector& res, + const std::string& name, + const std::vector& args, + const std::string& time) { + if (log_json) { + json j; + j[INSTRUCTION] = "CALL"; + j[NAME] = name; + j["RESULT"] = res; + j[ARGS] = args; + j[TIME] = time; + DTRLogger::logger().log(j.dump()); + } else { + std::string arg = name + "("; + for (const auto& s : args) { + arg += s; + arg += ", "; + } + arg += ")"; + std::string log = "("; + for (const auto& s: res) { + log += s; + log += ", "; + } + log += ") = "; + log += arg; + log += " TIME: "; + log += time; + DTRLogger::logger().log(log); + } +} + +} diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp index a43f16465db..71cef81aad8 100644 --- a/aten/src/ATen/native/Activation.cpp +++ b/aten/src/ATen/native/Activation.cpp @@ -828,4 +828,94 @@ Tensor& log_sigmoid_backward_out_cpu( DEFINE_DISPATCH(GeluKernel); DEFINE_DISPATCH(GeluBackwardKernel); +Tensor slice_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { + auto grad_input = at::zeros(input_sizes, grad.options()); + grad_input.slice(dim, start, end, step).copy_(grad); + return grad_input; +} + +Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { + auto grad_input = at::zeros(input_sizes, grad.options()); + grad_input.select(dim, index).copy_(grad); + return grad_input; +} + +std::vector split(const Tensor& self, int64_t split_size, int64_t dim) { + TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor"); + TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size); + int64_t dim_size = self.size(dim); + TORCH_CHECK(split_size > 0 || self.size(dim) == 0, + "split_size can only be 0 if dimension size is 0, " + "but got dimension size of ", dim_size); + // if split_size is 0 and dimension size is 0, there is 1 split. + int64_t num_splits = 1; + if (split_size != 0) { + // ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size + // (returns a single split). We might want to error here, but keep it for BC. + num_splits = std::max((dim_size + split_size - 1) / split_size, 1); + } + std::vector splits(num_splits); + int64_t last_split_size = split_size - (split_size * num_splits - dim_size); + + for (int64_t i = 0; i < num_splits; ++i) { + auto length = i < num_splits - 1 ? 
split_size : last_split_size; + splits[i] = self.narrow(dim, i * split_size, length); + } + return splits; +} + +std::vector split_with_sizes(const Tensor& self, IntArrayRef split_sizes, int64_t dim) { + TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor"); + int64_t dim_size = self.size(dim); + int64_t num_splits = split_sizes.size(); + std::vector splits(num_splits); + int64_t start_idx = 0; + int64_t i; + + for (i = 0; i < num_splits; ++i) { + auto length = split_sizes[i]; + TORCH_CHECK(length >= 0, + "split_with_sizes expects split_sizes have only non-negative ", + "entries, but got split_sizes=", split_sizes); + splits[i] = self.narrow(dim, start_idx, length); + start_idx += length; + } + TORCH_CHECK(start_idx == dim_size, + "split_with_sizes expects split_sizes to sum exactly to ", dim_size, + " (input tensor's size at dimension ", dim, "), ", "but got split_sizes=", split_sizes); + return splits; +} + +Tensor split_backward(c10::ArrayRef grads, + int64_t split_size, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { + dim = at::maybe_wrap_dim(dim, sizes.size()); + int64_t dim_size = sizes[dim]; + int64_t num_splits = grads.size(); + std::vector split_sizes(num_splits, split_size); + split_sizes[num_splits - 1] = split_size - (split_size * num_splits - dim_size); + return at::native::split_with_sizes_backward(grads, split_sizes, dim, sizes, options); +} + +Tensor split_with_sizes_backward(c10::ArrayRef grads, + IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { + dim = at::maybe_wrap_dim(dim, sizes.size()); + + // it's possible some of the grads are not defined (represents tensors of all 0s). + // Since at::cat can't handle those, let's define them + std::vector grads_all_defined(grads.size()); + for (size_t j = 0; j < grads.size(); ++j) { + if (grads[j].defined()) { + grads_all_defined[j] = grads[j]; + } else { + auto length = split_sizes[j]; + auto grad_size = sizes.vec(); + grad_size[dim] = length; + grads_all_defined[j] = at::zeros(grad_size, options); + } + } + + auto ret = at::cat(grads_all_defined, dim); + return ret; +} + }} // namespace at::native diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp new file mode 100644 index 00000000000..ac321deeea8 --- /dev/null +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -0,0 +1,1620 @@ +#include +#include + +namespace at { namespace native { + +Tensor checkpoint_add(const Tensor& a, const Tensor& b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::add(vec.at(0), vec.at(1), c)}; + }; + return CheckpointTensorImpl::make("add", rt, {a, b})[0]; +} + +Tensor checkpoint_t(at::Tensor const& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::t(vec.at(0))}; + }; + return CheckpointTensorImpl::make("t", rt, {a})[0]; +} + +Tensor checkpoint_add(at::Tensor const& a, c10::Scalar b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::add(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("add", rt, {a})[0]; +} + +Tensor& checkpoint_add_(Tensor& a, const Tensor& b, Scalar c) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).add_(vec.at(1), c); + }; + CheckpointTensorImpl::mutate("add_", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_mul(at::Tensor const& a, at::Tensor const& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return 
{at::mul(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("mul", rt, {a, b})[0]; +} + +Tensor& checkpoint_mul_(at::Tensor& a, at::Tensor const& b) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).mul_(vec.at(1)); + }; + CheckpointTensorImpl::mutate("mul_", mt, {a, b}, {0}); + return a; +} + +Tensor& checkpoint_mul_(at::Tensor& a, c10::Scalar b) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).mul_(b); + }; + CheckpointTensorImpl::mutate("mul_", mt, {a}, {0}); + return a; +} + +Tensor checkpoint_zeros_like(at::Tensor const& a, c10::TensorOptions const& b, c10::optional c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::zeros_like(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("zeros_like", rt, {a})[0]; +} + +Tensor checkpoint_ones_like(at::Tensor const& a, c10::TensorOptions const& b, c10::optional c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::ones_like(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("ones_like", rt, {a})[0]; +} + +Tensor checkpoint_addcmul(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, c10::Scalar d) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::addcmul(vec.at(0), vec.at(1), vec.at(2), d)}; + }; + return CheckpointTensorImpl::make("addcmul", rt, {a, b, c})[0]; +} + +Tensor& checkpoint_addcmul_(at::Tensor& a, at::Tensor const& b, at::Tensor const& c, c10::Scalar d) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).addcmul_(vec.at(1), vec.at(2), d); + }; + CheckpointTensorImpl::mutate("addcmul_", mt, {a, b, c}, {0}); + return a; +} + +Tensor checkpoint_abs(const Tensor& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::abs(vec.at(0))}; + }; + return CheckpointTensorImpl::make("abs", rt, {a})[0]; +} + +Tensor checkpoint_sqrt(const Tensor& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sqrt(vec.at(0))}; + }; + return CheckpointTensorImpl::make("sqrt", rt, {a})[0]; +} + +Tensor& checkpoint_addcdiv_(at::Tensor& a, at::Tensor const& b, at::Tensor const& c, c10::Scalar d) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).addcdiv_(vec.at(1), vec.at(2), d); + }; + CheckpointTensorImpl::mutate("addcdiv_", mt, {a, b, c}, {0}); + return a; +} + +Tensor checkpoint_addcdiv(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, c10::Scalar d) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::addcdiv(vec.at(0), vec.at(1), vec.at(2), d)}; + }; + return CheckpointTensorImpl::make("addcdiv", rt, {a, b, c})[0]; +} + +Tensor checkpoint_to(at::Tensor const& a, c10::TensorOptions const& b, bool c, bool d, c10::optional e) { + c10::TensorOptions b_ = b; + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {vec.at(0).to(b_, c, d, e)}; + }; + return CheckpointTensorImpl::make("to", rt, {a})[0]; +} + +Tensor checkpoint_div(const Tensor& a, const Tensor& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::div(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("div", rt, {a, b})[0]; +} + +Tensor& checkpoint_div_(Tensor& a, const Tensor& b) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).div_(vec.at(1)); + }; + CheckpointTensorImpl::mutate("div_", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_clone(at::Tensor const& a, 
c10::optional b) { + if (b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::clone(vec.at(0), b)}; + }; + return CheckpointTensorImpl::make("clone", rt, {a})[0]; + } + else { + return a; + } +} + +Tensor checkpoint_where(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::where(vec.at(0), vec.at(1), vec.at(2))}; + }; + return CheckpointTensorImpl::make("where", rt, {a, b, c})[0]; +} + +Tensor checkpoint_constant_pad_nd(Tensor const& a, c10::ArrayRef b, c10::Scalar c) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::constant_pad_nd(vec.at(0), b_, c)}; + }; + return CheckpointTensorImpl::make("constant_pad_nd", rt, {a})[0]; +} + +Tensor checkpoint_binary_cross_entropy(const Tensor& a, const Tensor& b, const Tensor& c, long d) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::binary_cross_entropy(vec.at(0), vec.at(1), vec.at(2), d)}; + }; + return CheckpointTensorImpl::make("binary_cross_entropy", rt, {a, b, c})[0]; +} + +Tensor& checkpoint_binary_cross_entropy_out(Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, long e) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::binary_cross_entropy_out(self, vec.at(1), vec.at(2), vec.at(3), e); + }; + CheckpointTensorImpl::mutate("binary_cross_entropy_out", mt, {a, b, c, d}, {0}); + return a; +} + +Tensor checkpoint_binary_cross_entropy_backward(const Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, long e) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::binary_cross_entropy_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), e)}; + }; + return CheckpointTensorImpl::make("binary_cross_entropy_backward", rt, {a, b, c, d})[0]; +} + +Tensor& checkpoint_binary_cross_entropy_backward_out(Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, const Tensor& e, long f) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::binary_cross_entropy_backward_out(self, vec.at(1), vec.at(2), vec.at(3), vec.at(4), f); + }; + CheckpointTensorImpl::mutate("binary_cross_entropy_backward_out", mt, {a, b, c, d, e}, {0}); + return a; +} + +Tensor checkpoint_embedding(const Tensor& a, const Tensor& b, long c, bool d, bool e) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::embedding(vec.at(0), vec.at(1), c, d, e)}; + }; + return CheckpointTensorImpl::make("embedding", rt, {a, b})[0]; +} + +Tensor checkpoint_embedding_backward(const Tensor& a, const Tensor& b, long c, long d, bool e, bool f) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::embedding_backward(vec.at(0), vec.at(1), c, d, e, f)}; + }; + return CheckpointTensorImpl::make("embedding", rt, {a, b})[0]; +} + +std::tuple +checkpoint_cudnn_batch_norm(const Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, const Tensor& e, bool f, double g, double h) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::cudnn_batch_norm(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), f, g, h); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret), std::get<3>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("cudnn_batch_norm", rt, {a, b, c, d, e}); + return {ret[0], ret[1], ret[2], ret[3]}; +} + +std::tuple 
checkpoint_cudnn_batch_norm_backward(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, at::Tensor const& d, at::Tensor const& e, at::Tensor const& f, at::Tensor const& g, double h, at::Tensor const& i) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::cudnn_batch_norm_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), vec.at(5), vec.at(6), h, vec.at(7)); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("cudnn_batch_norm_backward", rt, {a, b, c, d, e, f, g, i}); + return {ret[0], ret[1], ret[2]}; +} + +Tensor checkpoint_as_strided(const Tensor& a, c10::ArrayRef b, c10::ArrayRef c, c10::optional d) { + std::vector b_ = b.vec(), c_ = c.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::as_strided(vec.at(0), b_, c_, d)}; + }; + return CheckpointTensorImpl::make("as_strided", rt, {a})[0]; +} + +Tensor checkpoint__masked_scale(const Tensor& a, const Tensor& b, double c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::_masked_scale(vec.at(0), vec.at(1), c)}; + }; + return CheckpointTensorImpl::make("_masked_scale", rt, {a, b})[0]; +} + +Tensor checkpoint_cudnn_convolution(const Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, long f, bool g, bool h) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution(vec.at(0), vec.at(1), c_, d_, e_, f, g, h)}; + }; + return CheckpointTensorImpl::make("cudnn_convolution", rt, {a, b})[0]; +} + +Tensor checkpoint_cudnn_convolution_transpose(const Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution_transpose(vec.at(0), vec.at(1), c_, d_, e_, f_, g, h, i)}; + }; + return CheckpointTensorImpl::make("cudnn_convolution_transpose", rt, {a, b})[0]; +} + +std::tuple checkpoint_cudnn_convolution_backward(const Tensor& a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i, std::array j) { + std::vector d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::cudnn_convolution_backward(vec.at(0), vec.at(1), vec.at(2), d_, e_, f_, g, h, i, j); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("cudnn_convolution_backward", rt, {a, b, c}); + return {ret[0], ret[1]}; +} + +std::tuple checkpoint_cudnn_convolution_transpose_backward(const Tensor& a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, c10::ArrayRef g, long h, bool i, bool j, std::array k) { + std::vector d_ = d.vec(), e_ = e.vec(), f_ = f.vec(), g_ = g.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::cudnn_convolution_transpose_backward(vec.at(0), vec.at(1), vec.at(2), d_, e_, f_, g_, h, i, j, k); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("cudnn_convolution_transpose_backward", rt, {a, b, c}); + return {ret[0], ret[1]}; +} + +Tensor checkpoint_cudnn_convolution_backward_input(c10::ArrayRef a, const Tensor& b, const Tensor& c, c10::ArrayRef d, 
c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { + std::vector a_ = a.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution_backward_input(a_, vec.at(0), vec.at(1), d_, e_, f_, g, h, i)}; + }; + return CheckpointTensorImpl::make("cudnn_convolution_backward_input", rt, {b, c})[0]; +} + +Tensor checkpoint_cudnn_convolution_transpose_backward_input(const Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, long f, bool g, bool h) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution_transpose_backward_input(vec.at(0), vec.at(1), c_, d_, e_, f, g, h)}; + }; + return CheckpointTensorImpl::make("cudnn_convolution_transpose_backward_input", rt, {a, b})[0]; +} + +Tensor checkpoint_cudnn_convolution_backward_weight(c10::ArrayRef a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { + std::vector a_ = a.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution_backward_weight(a_, vec.at(0), vec.at(1), d_, e_, f_, g, h, i)}; + }; + return CheckpointTensorImpl::make("cudnn_convolution_backward_weight", rt, {b, c})[0]; +} + +Tensor checkpoint_cudnn_convolution_transpose_backward_weight(c10::ArrayRef a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { + std::vector a_ = a.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution_transpose_backward_weight(a_, vec.at(0), vec.at(1), d_, e_, f_, g, h, i)}; + }; + return CheckpointTensorImpl::make("cudnn_convolution_transpose_backward_weight", rt, {b, c})[0]; +} + +Tensor checkpoint_relu(const Tensor& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::relu(vec.at(0))}; + }; + return CheckpointTensorImpl::make("relu", rt, {a})[0]; +} + +Tensor& checkpoint_relu_(Tensor& a) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).relu_(); + }; + CheckpointTensorImpl::mutate("relu_", mt, {a}, {0}); + return a; +} + +Tensor checkpoint_log(at::Tensor const& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::log(vec.at(0))}; + }; + return CheckpointTensorImpl::make("log", rt, {a})[0]; +} + +Tensor& checkpoint_log_out(at::Tensor& a, at::Tensor const& b) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::log_out(a_, vec.at(1)); + }; + CheckpointTensorImpl::mutate("log_out", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_rsub(at::Tensor const& a, at::Tensor const& b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::rsub(vec.at(0), vec.at(1), c)}; + }; + return CheckpointTensorImpl::make("rsub", rt, {a, b})[0]; +} + +Tensor checkpoint_rsub(at::Tensor const& a, c10::Scalar b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::rsub(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("rsub", rt, {a})[0]; +} + +Tensor checkpoint_mul(at::Tensor const& a, c10::Scalar b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::mul(vec.at(0), b)}; + }; + return 
CheckpointTensorImpl::make("mul", rt, {a})[0]; +} + +std::tuple checkpoint_max_pool2d_with_indices_out(Tensor& a, Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, c10::ArrayRef g, bool h) { + std::vector d_ = d.vec(), e_ = e.vec(), f_ = f.vec(), g_ = g.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0), b_ = vec.at(1); + at::max_pool2d_with_indices_out(a_, b_, vec.at(2), d_, e_, f_, g_, h); + }; + CheckpointTensorImpl::mutate("max_pool2d_with_indices_out", mt, {a, b, c}, {0, 1}); + return {a, b}; +} + +Tensor checkpoint_avg_pool2d(const Tensor& a, c10::ArrayRef b, c10::ArrayRef c, c10::ArrayRef d, bool e, bool f, c10::optional g) { + std::vector b_ = b.vec(), c_ = c.vec(), d_ = d.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::avg_pool2d(vec.at(0), b_, c_, d_, e, f, g)}; + }; + return CheckpointTensorImpl::make("avg_pool2d", rt, {a})[0]; +} + +Tensor checkpoint_avg_pool2d_backward(const Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, bool f, bool g, c10::optional h) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::avg_pool2d_backward(vec.at(0), vec.at(1), c_, d_, e_, f, g, h)}; + }; + return CheckpointTensorImpl::make("avg_pool2d_backward", rt, {a, b})[0]; +} + +Tensor& checkpoint_avg_pool2d_out(Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, bool f, bool g, c10::optional h) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::avg_pool2d_out(a_, vec.at(1), c_, d_, e_, f, g, h); + }; + CheckpointTensorImpl::mutate("avg_pool2d_out", mt, {a, b}, {0}); + return a; +} + +Tensor& checkpoint_avg_pool2d_backward_grad_input(Tensor& a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, bool g, bool h, c10::optional i) { + std::vector d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::avg_pool2d_backward_out(a_, vec.at(1), vec.at(2), d_, e_, f_, g, h, i); + }; + CheckpointTensorImpl::mutate("avg_pool2d_backward_grad_input", mt, {a, b, c}, {0}); + return a; +} + +std::tuple checkpoint_max_pool2d_with_indices(const Tensor& a, c10::ArrayRef b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, bool f) { + std::vector b_ = b.vec(), c_ = c.vec(), d_ = d.vec(), e_ = e.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::max_pool2d_with_indices(vec.at(0), b_, c_, d_, e_, f); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("max_pool2d_with_indices", rt, {a}); + return {ret[0], ret[1]}; +} + +Tensor& checkpoint_max_pool2d_with_indices_backward_grad_input(Tensor& a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, c10::ArrayRef g, bool h, const Tensor& i) { + std::vector d_ = d.vec(), e_ = e.vec(), f_ = f.vec(), g_ = g.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::max_pool2d_with_indices_backward_out(a_, vec.at(1), vec.at(2), d_, e_, f_, g_, h, vec.at(3)); + }; + CheckpointTensorImpl::mutate("max_pool2d_with_indices_backward_grad_input", mt, {a, b, c, i}, {0}); + return a; +} + +Tensor checkpoint_max_pool2d_with_indices_backward(const Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d,
c10::ArrayRef e, c10::ArrayRef f, bool g, const Tensor& h) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::max_pool2d_with_indices_backward(vec.at(0), vec.at(1), c_, d_, e_, f_, g, vec.at(2))}; + }; + return CheckpointTensorImpl::make("max_pool2d_with_indices_backward", rt, {a, b, h})[0]; +} + +Tensor checkpoint_view(const Tensor& a, c10::ArrayRef b) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {vec.at(0).view(b_)}; + }; + return CheckpointTensorImpl::make("view", rt, {a})[0]; +} + +Tensor checkpoint_ne_Scalar(const Tensor& a, c10::Scalar b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::ne(vec.at(0), b)}; + }; + return CheckpointTensorImpl::make("ne_Scalar", rt, {a})[0]; +} + +Tensor& checkpoint_ne_Scalar_out(Tensor& a, const Tensor& b, c10::Scalar c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::ne_out(a_, vec.at(1), c); + }; + CheckpointTensorImpl::mutate("ne_Scalar_out", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_ne_Tensor(const Tensor& a, const Tensor& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::ne(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("ne_Tensor", rt, {a, b})[0]; +} + +Tensor& checkpoint_ne_Tensor_out(Tensor& a, const Tensor& b, const Tensor& c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::ne_out(a_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("ne_Tensor_out", mt, {a, b, c}, {0}); + return a; +} + +Tensor checkpoint_eq_Scalar(const Tensor& a, c10::Scalar b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::eq(vec.at(0), b)}; + }; + return CheckpointTensorImpl::make("eq_Scalar", rt, {a})[0]; +} + +Tensor& checkpoint_eq_Scalar_out(Tensor& a, const Tensor& b, c10::Scalar c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::eq_out(a_, vec.at(1), c); + }; + CheckpointTensorImpl::mutate("eq_Scalar_out", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_eq_Tensor(const Tensor& a, const Tensor& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::eq(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("eq_Tensor", rt, {a, b})[0]; +} + +Tensor& checkpoint_eq_Tensor_out(Tensor& a, const Tensor& b, const Tensor& c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::eq_out(a_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("eq_Tensor_out", mt, {a, b, c}, {0}); + return a; +} + +Tensor checkpoint_addmm(const Tensor& a, const Tensor& b, const Tensor& c, c10::Scalar d, c10::Scalar e) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::addmm(vec.at(0), vec.at(1), vec.at(2), d, e)}; + }; + return CheckpointTensorImpl::make("addmm", rt, {a, b, c})[0]; +} + +Tensor& checkpoint_addmm_out(Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, c10::Scalar e, c10::Scalar f) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::addmm_out(a_, vec.at(1), vec.at(2), vec.at(3), e, f); + }; + CheckpointTensorImpl::mutate("addmm_out", mt, {a, b, c, d}, {0}); + return a; +} + +Tensor& checkpoint_addmm_(Tensor& a, const Tensor& b, const Tensor& c, c10::Scalar d, c10::Scalar e) { +
mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + a_.addmm_(vec.at(1), vec.at(2), d, e); + }; + CheckpointTensorImpl::mutate("addmm_", mt, {a, b, c}, {0}); + return a; +} + +Tensor checkpoint_sigmoid(const Tensor& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sigmoid(vec.at(0))}; + }; + return CheckpointTensorImpl::make("sigmoid", rt, {a})[0]; +} + +Tensor& checkpoint_sigmoid_(Tensor& a) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + a_.sigmoid_(); + }; + CheckpointTensorImpl::mutate("sigmoid_", mt, {a}, {0}); + return a; +} + +Tensor checkpoint__log_softmax(const Tensor& a, long b, bool c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::_log_softmax(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("_log_softmax", rt, {a})[0]; +} + +Tensor checkpoint__log_softmax_backward_data(const Tensor& a, const Tensor& b, long c, const Tensor& d) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::_log_softmax_backward_data(vec.at(0), vec.at(1), c, vec.at(2))}; + }; + return CheckpointTensorImpl::make("_log_softmax_backward_data", rt, {a, b, d})[0]; +} + +std::tuple checkpoint_nll_loss_forward(const Tensor& a, const Tensor& b, const Tensor& c, long d, long e) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::nll_loss_forward(vec.at(0), vec.at(1), vec.at(2), d, e); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("nll_loss_forward", rt, {a, b, c}); + return {ret[0], ret[1]}; +} + +std::tuple checkpoint_nll_loss_forward_out(Tensor& a, Tensor& b, const Tensor& c, const Tensor& d, const Tensor& e, long f, long g) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + Tensor b_ = vec.at(1); + at::nll_loss_forward_out(a_, b_, vec.at(2), vec.at(3), vec.at(4), f, g); + }; + CheckpointTensorImpl::mutate("nll_loss_forward_out", mt, {a, b, c, d, e}, {0, 1}); + return {a, b}; +} + +Tensor checkpoint_nll_loss_backward(const Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, long e, long f, const Tensor& g) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::nll_loss_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), e, f, vec.at(4))}; + }; + return CheckpointTensorImpl::make("nll_loss_backward", rt, {a, b, c, d, g})[0]; +} + +Tensor& checkpoint_nll_loss_backward_grad_input(Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, const Tensor& e, long f, long g, const Tensor& h) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::nll_loss_backward_out(a_, vec.at(1), vec.at(2), vec.at(3), vec.at(4), f, g, vec.at(5)); + }; + CheckpointTensorImpl::mutate("nll_loss_backward_grad_input", mt, {a, b, c, d, e, h}, {0}); + return a; +} + +Tensor checkpoint_mm(const Tensor& a, const Tensor& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::mm(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("mm", rt, {a, b})[0]; +} + +Tensor& checkpoint_mm_out(Tensor& a, const Tensor& b, const Tensor& c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::mm_out(a_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("mm_out", mt, {a, b, c}, {0}); + return a; +} + +Tensor checkpoint_sum(const Tensor& a, c10::optional b) { + rematerialize_function_t rt = + [=](const
Tensors& vec) -> Tensors { + return {at::sum(vec.at(0), b)}; + }; + return CheckpointTensorImpl::make("sum", rt, {a})[0]; +} + +Tensor checkpoint_sum_dim_IntList(const Tensor& a, c10::ArrayRef b, bool c, c10::optional d) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sum(vec.at(0), b_, c, d)}; + }; + return CheckpointTensorImpl::make("sum_dim_IntList", rt, {a})[0]; +} + +Tensor checkpoint_threshold(const Tensor& a, c10::Scalar b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::threshold(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("threshold", rt, {a})[0]; +} + +Tensor& checkpoint_threshold_(Tensor& a, c10::Scalar b, c10::Scalar c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::threshold_(a_, b, c); + }; + CheckpointTensorImpl::mutate("threshold_", mt, {a}, {0}); + return a; +} + +Tensor& checkpoint_threshold_out(Tensor& a, const Tensor& b, c10::Scalar c, c10::Scalar d) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::threshold_out(a_, vec.at(1), c, d); + }; + CheckpointTensorImpl::mutate("threshold_out", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_threshold_backward(const Tensor& a, const Tensor& b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::threshold_backward(vec.at(0), vec.at(1), c)}; + }; + return CheckpointTensorImpl::make("threshold_backward", rt, {a, b})[0]; +} + +Tensor checkpoint_select(const Tensor& a, long b, long c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::select(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("select", rt, {a})[0]; +} + +Tensor checkpoint_select_backward(const Tensor& a, c10::ArrayRef b, long c, long d) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::select_backward(vec.at(0), b_, c, d)}; + }; + return CheckpointTensorImpl::make("select_backward", rt, {a})[0]; +} + +Tensor checkpoint_slice(const Tensor& a, long b, long c, long d, long e) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::slice(vec.at(0), b, c, d, e)}; + }; + return CheckpointTensorImpl::make("slice", rt, {a})[0]; +} + +Tensor checkpoint_slice_backward(const Tensor& a, c10::ArrayRef b, long c, long d, long e, long f) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::slice_backward(vec.at(0), b_, c, d, e, f)}; + }; + return CheckpointTensorImpl::make("slice_backward", rt, {a})[0]; +} + +Tensor& checkpoint_zero_(Tensor& a) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).zero_(); + }; + CheckpointTensorImpl::mutate("zero_", mt, {a}, {0}); + return a; +} + +Tensor& checkpoint_squeeze_(at::Tensor& a, at::Dimname b) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).squeeze_(b); + }; + CheckpointTensorImpl::mutate("squeeze_", mt, {a}, {0}); + return a; +} + +Tensor& checkpoint_squeeze_(at::Tensor& a) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).squeeze_(); + }; + CheckpointTensorImpl::mutate("squeeze_", mt, {a}, {0}); + return a; +} + +Tensor& checkpoint_squeeze_(at::Tensor& a, long b) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).squeeze_(b); + }; + CheckpointTensorImpl::mutate("squeeze_", mt, {a}, {0}); + return a; +} + +Tensor
checkpoint_sigmoid_backward(at::Tensor const& a, at::Tensor const& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sigmoid_backward(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("sigmoid_backward", rt, {a, b})[0]; +} + +Tensor& checkpoint_sigmoid_backward_out(at::Tensor& a, at::Tensor const& b, at::Tensor const& c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::sigmoid_backward_out(a_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("sigmoid_backward_out", mt, {a, b, c}, {0}); + return a; +} + +Tensor& checkpoint_sign_out(at::Tensor& a, at::Tensor const& b) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::sign_out(a_, vec.at(1)); + }; + CheckpointTensorImpl::mutate("sign_out", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_sign(const Tensor& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sign(vec.at(0))}; + }; + return CheckpointTensorImpl::make("sign", rt, {a})[0]; +} + +Tensor checkpoint_tanh(const Tensor& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::tanh(vec.at(0))}; + }; + return CheckpointTensorImpl::make("tanh", rt, {a})[0]; +} + +Tensor checkpoint_tanh_backward(at::Tensor const& a, at::Tensor const& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::tanh_backward(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("tanh_backward", rt, {a, b})[0]; +} + +Tensor& checkpoint_tanh_backward_out(at::Tensor& a, at::Tensor const& b, at::Tensor const& c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::tanh_backward_out(a_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("tanh_backward_out", mt, {a, b, c}, {0}); + return a; +} + +Tensor checkpoint_neg(at::Tensor const& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::neg(vec.at(0))}; + }; + return CheckpointTensorImpl::make("neg", rt, {a})[0]; +} + +Tensor checkpoint_sub(at::Tensor const& a, at::Tensor const& b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sub(vec.at(0), vec.at(1), c)}; + }; + return CheckpointTensorImpl::make("sub", rt, {a, b})[0]; +} + +Tensor& checkpoint_sub_(at::Tensor& a, at::Tensor const& b, c10::Scalar c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + self.sub_(vec.at(1), c); + }; + CheckpointTensorImpl::mutate("sub_", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_repeat(const at::Tensor& a, c10::ArrayRef b) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {vec.at(0).repeat(b_)}; + }; + return CheckpointTensorImpl::make("repeat", rt, {a})[0]; +} + +Tensor checkpoint_mean(const Tensor& self, c10::optional dtype) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::native::mean_cpu_gpu(vec[0], dtype)}; + }; + return CheckpointTensorImpl::make("mean", rt, {self})[0]; +} + +Tensor checkpoint_mean(const Tensor& self, IntArrayRef dim, bool keepdim, c10::optional dtype) { + std::vector dim_ = dim.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::native::mean_cpu_gpu(vec[0], dim_, keepdim, dtype)}; + }; + return CheckpointTensorImpl::make("mean.dim", rt, {self})[0]; +} + +Tensor checkpoint__cat(c10::ArrayRef a, long 
b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cat(vec, b)}; + }; + std::vector s; + for (const Tensor& t : a) { + s.push_back(t); + } + return CheckpointTensorImpl::make("_cat", rt, s)[0]; +} + +Tensor& checkpoint__cat_out(Tensor& a, c10::ArrayRef b, long c) { + std::vector args; + args.push_back(a); + for (const Tensor& t : b) { + args.push_back(t); + } + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor t = vec[0]; + at::cat_out(t, ArrayRef(vec.data() + 1, vec.size() - 1), c); + }; + CheckpointTensorImpl::mutate("_cat_out", mt, args, {0}); + return a; +} + +Tensor checkpoint_kl_div(at::Tensor const& a, at::Tensor const& b, long c, bool d) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::kl_div(vec.at(0), vec.at(1), c, d)}; + }; + return CheckpointTensorImpl::make("kl_div", rt, {a, b})[0]; +} + +Tensor checkpoint_kl_div_backward(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, long d, bool e) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::kl_div_backward(vec.at(0), vec.at(1), vec.at(2), d, e)}; + }; + return CheckpointTensorImpl::make("kl_div_backward", rt, {a, b, c})[0]; +} + +Tensor checkpoint_upsample_bilinear2d(at::Tensor const& self, c10::ArrayRef output_size, bool align_corners, c10::optional scales_h, c10::optional scales_w) { + std::vector output_size_ = output_size.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::upsample_bilinear2d(vec.at(0), output_size_, align_corners, scales_h, scales_w)}; + }; + return CheckpointTensorImpl::make("upsample_bilinear2d", rt, {self})[0]; +} + +Tensor& checkpoint_upsample_bilinear2d_out(at::Tensor& out, const at::Tensor& self, c10::ArrayRef output_size, bool align_corners, c10::optional scales_h, c10::optional scales_w) { + std::vector output_size_ = output_size.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out = vec.at(0); + at::upsample_bilinear2d_out(out, vec.at(1), output_size_, align_corners, scales_h, scales_w); + }; + CheckpointTensorImpl::mutate("binary_cross_entropy_out", mt, {out, self}, {0}); + return out; +} + +Tensor& checkpoint_upsample_bilinear2d_backward_out(at::Tensor& grad_input, const at::Tensor& grad_output, c10::ArrayRef output_size, c10::ArrayRef input_size, bool align_corners, c10::optional scales_h, c10::optional scales_w) { + std::vector output_size_ = output_size.vec(); + std::vector input_size_ = input_size.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor grad_input = vec.at(0); + at::upsample_bilinear2d_backward_out(grad_input, vec.at(1), output_size_, input_size_, align_corners, scales_h, scales_w); + }; + CheckpointTensorImpl::mutate("upsample_bilinear2d_backward_out", mt, {grad_input, grad_output}, {0}); + return grad_input; +} + +Tensor checkpoint_upsample_bilinear2d_backward(at::Tensor const& grad_output, c10::ArrayRef output_size, c10::ArrayRef input_size, bool align_corners, c10::optional scales_h, c10::optional scales_w) { + std::vector output_size_ = output_size.vec(); + std::vector input_size_ = input_size.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::upsample_bilinear2d_backward(vec.at(0), output_size_, input_size_, align_corners, scales_h, scales_w)}; + }; + return CheckpointTensorImpl::make("upsample_bilinear2d_backward", rt, {grad_output})[0]; +} + +Tensor& checkpoint_clamp_min_(Tensor& a, Scalar min) { + mutate_function_t mt = + 
[=](const Tensors& vec) { + Tensor self = vec.at(0); + at::clamp_min_(self, min); + }; + CheckpointTensorImpl::mutate("clamp_min_", mt, {a}, {0}); + return a; +} + +Tensor& checkpoint_clamp_min_out(Tensor& out, const Tensor& self, Scalar min) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out = vec.at(0); + at::clamp_min_out(out, vec.at(1), min); + }; + CheckpointTensorImpl::mutate("clamp_min__out", mt, {out, self}, {0}); + return out; +} + +Tensor checkpoint_binary_cross_entropy_with_logits(const Tensor& input, const Tensor& target, const Tensor& weight, const Tensor& pos_weight, int64_t reduction) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::binary_cross_entropy_with_logits(vec.at(0), vec.at(1), vec.at(2), vec.at(3), reduction)}; + }; + return CheckpointTensorImpl::make("binary_cross_entropy_with_logits", rt, {input, target, weight, pos_weight})[0]; +} + +Tensor checkpoint_binary_cross_entropy_with_logits_backward(const Tensor& grad, const Tensor& input, const Tensor& target, const Tensor& weight, const Tensor& pos_weight, int64_t reduction) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::binary_cross_entropy_with_logits_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), reduction)}; + }; + return CheckpointTensorImpl::make("binary_cross_entropy_with_logits_backward", rt, {grad, input, target, weight, pos_weight})[0]; +} + +std::tuple checkpoint__fused_dropout(const Tensor & self, double p, c10::optional g) { + // TODO: Figure out how to properly duplicate the generator; + // note that the commented-out code below results in a segfault! + // Ref> gen; + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + // Generator* cur = gen.t ? 
gen.t.get() : g; + // auto newG = cur->clone(); + // auto res = at::_fused_dropout(vec.at(0), p, cur); + // gen.t = newG; + auto res = at::_fused_dropout(vec.at(0), p, g); + return {std::get<0>(res), std::get<1>(res)}; + }; + auto res = CheckpointTensorImpl::make("_fused_droupout_", rt, {self}); + return {res[0], res[1]}; +} + +std::tuple checkpoint__thnn_fused_lstm_cell(const Tensor& input_gates, const Tensor& hidden_gates, const Tensor& cx, const Tensor& input_bias, const Tensor& hidden_bias) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto res = at::_thnn_fused_lstm_cell(vec.at(0), vec.at(1), vec.at(2), + vec.at(3), vec.at(4)); + return {std::get<0>(res), std::get<1>(res), std::get<2>(res)}; + }; + auto res = CheckpointTensorImpl::make("_thnn_fused_lstm_cell", rt, + {input_gates, hidden_gates, cx, input_bias, hidden_bias}); + return {res[0], res[1], res[2]}; +} + +std::tuple checkpoint__thnn_fused_lstm_cell_backward(const Tensor& grad_hy, const Tensor& grad_cy, const Tensor& cx, const Tensor& cy, const Tensor& workspace, bool has_bias) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto res = at::_thnn_fused_lstm_cell_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), has_bias); + return {std::get<0>(res), std::get<1>(res), + std::get<2>(res), std::get<3>(res), std::get<4>(res)}; + }; + auto res = CheckpointTensorImpl::make("_thnn_fused_lstm_cell_backward", rt, + {grad_hy, grad_cy, cx, cy, workspace}); + return {res[0], res[1], res[2], res[3], res[4]}; +} + +std::tuple checkpoint__thnn_fused_gru_cell(const Tensor& input_gates, const Tensor& hidden_gates, const Tensor& hx, const Tensor& input_bias, const Tensor& hidden_bias) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto res = at::_thnn_fused_gru_cell(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4)); + return {std::get<0>(res), std::get<1>(res)}; + }; + auto res = CheckpointTensorImpl::make("_thnn_fused_gru_cell", rt, + {input_gates, hidden_gates, hx, input_bias, hidden_bias}); + return {res[0], res[1]}; +} + +std::tuple checkpoint__thnn_fused_gru_cell_backward(const Tensor& grad_hy, const Tensor& workspace, bool has_bias) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto res = at::_thnn_fused_gru_cell_backward(vec.at(0), vec.at(1), has_bias); + return {std::get<0>(res), std::get<1>(res), + std::get<2>(res), std::get<3>(res), std::get<4>(res)}; + }; + auto res = CheckpointTensorImpl::make("_thnn_fused_gru_cell_backward", rt, + {grad_hy, workspace}); + return {res[0], res[1], res[2], res[3], res[4]}; +} + +Tensor& checkpoint_bitwise_and_out(Tensor& self, const Tensor& other, const Tensor& out) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::bitwise_and_out(self, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("bitwise_and_out", mt, {self, other, out}, {0}); + return self; +} + +Tensor& checkpoint_bitwise_and_out(Tensor& self, const Tensor& out, Scalar other) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::bitwise_and_out(self, vec.at(1), other); + }; + CheckpointTensorImpl::mutate("bitwise_and_out", mt, {self, out}, {0}); + return self; +} + +Tensor& checkpoint_fill_(Tensor& self, const Tensor& value) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::fill_(self, vec.at(1)); + }; + CheckpointTensorImpl::mutate("fill_", mt, {self, value}, {0}); + return self; +} + +Tensor& 
checkpoint_fill_(Tensor& self, Scalar value) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::fill_(self, value); + }; + CheckpointTensorImpl::mutate("fill_", mt, {self}, {0}); + return self; +} + +Tensor& checkpoint_masked_select_out(Tensor& self, const Tensor& mask, const Tensor& out) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::masked_select_out(self, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("masked_select_out", mt, {self, mask, out}, {0}); + return self; +} + +Tensor checkpoint_masked_select(const Tensor& self, const Tensor& mask) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::masked_select(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("masked_select", rt, {self, mask})[0]; +} + +Tensor checkpoint_index(const Tensor& self, ArrayRef indices) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto self = vec.at(0); + auto indices = std::vector(vec.begin() + 1, vec.end()); + return {at::index(self, indices)}; + }; + + std::vector s = {self}; + for (const Tensor& t: indices) { + s.push_back(t); + } + return CheckpointTensorImpl::make("index", rt, s)[0]; +} + +Tensor& checkpoint_index_put_(Tensor& self, ArrayRef indices, const Tensor& values, const bool accumulate) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + auto values = vec.at(1); + auto indices = std::vector(vec.begin() + 2, vec.end()); + at::index_put_(self, indices, values, accumulate); + }; + std::vector s = {self, values}; + for (const Tensor& t: indices) { + s.push_back(t); + } + CheckpointTensorImpl::mutate("index_put_", mt, s, {0}); + return self; +} + +Tensor checkpoint_bmm(const Tensor& self, const Tensor& mat2) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::bmm(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("bmm", rt, {self, mat2})[0]; +} + +Tensor checkpoint__softmax(const Tensor& self, long dim, bool half_to_float) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::_softmax(vec.at(0), dim, half_to_float)}; + }; + return CheckpointTensorImpl::make("_softmax", rt, {self})[0]; +} + +Tensor checkpoint__softmax_backward_data(const Tensor& grad_output, const Tensor& output, long dim, const Tensor& self) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::_softmax_backward_data(vec.at(0), vec.at(1), dim, vec.at(2))}; + }; + return CheckpointTensorImpl::make("_softmax_backward_data", rt, {grad_output, output, self})[0]; +} + +std::tuple +checkpoint_layer_norm(const Tensor& input, const Tensor& weight, const Tensor& bias, long M, long N, double eps) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::native_layer_norm(vec.at(0), vec.at(1), vec.at(2), M, N, eps); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("native_layer_norm", rt, {input, weight, bias}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple +checkpoint_layer_norm_backward(const Tensor& grad_out, const Tensor& input, const Tensor& mean, const Tensor& rstd, const Tensor& weight, long M, long N, std::array output_mask) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::native_layer_norm_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), M, N, output_mask); + return 
{std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("native_layer_norm_backward", rt, {grad_out, input, mean, rstd, weight}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple +checkpoint_topk(const Tensor& self, long k, long dim, bool largest, bool sorted) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::topk(vec.at(0), k, dim, largest, sorted); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("topk", rt, {self}); + return {ret[0], ret[1]}; +} + +std::tuple +checkpoint_topk_values(Tensor& values, Tensor& indices, const Tensor& self, long k, long dim, bool largest, bool sorted) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor values_ = vec.at(0); + Tensor indices_ = vec.at(1); + at::topk_out(values_, indices_, vec.at(2), k, dim, largest, sorted); + }; + CheckpointTensorImpl::mutate("topk_values", mt, {values, indices, self}, {0, 1}); + return {values, indices}; +} + +Tensor& checkpoint_masked_fill_(Tensor& self, const Tensor& mask, Scalar value) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self_ = vec.at(0); + self_.masked_fill_(vec.at(1), value); + }; + CheckpointTensorImpl::mutate("masked_fill_Scalar", mt, {self, mask}, {0}); + return {self}; +} + +Tensor& checkpoint_masked_fill_(Tensor& self, const Tensor& mask, const Tensor& value) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self_ = vec.at(0); + self_.masked_fill_(vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("masked_fill_Tensor", mt, {self, mask, value}, {0}); + return {self}; +} + +Tensor checkpoint_clamp(const Tensor& self, c10::optional min, c10::optional max) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::clamp(vec.at(0), min, max)}; + }; + return CheckpointTensorImpl::make("clamp", rt, {self})[0]; +} + +Tensor& checkpoint_clamp_(Tensor& self, c10::optional min, c10::optional max) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self_ = vec.at(0); + at::clamp_(self_, min, max); + }; + CheckpointTensorImpl::mutate("clamp_", mt, {self}, {0}); + return {self}; +} + +Tensor& checkpoint_clamp_out(Tensor& out, const Tensor& self, c10::optional min, c10::optional max) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::clamp_out(out_, vec.at(1), min, max); + }; + CheckpointTensorImpl::mutate("clamp_out", mt, {out, self}, {0}); + return {out}; +} + +std::tuple checkpoint_thnn_conv2d_forward_out(Tensor& output, Tensor& finput, Tensor& fgrad_input, const Tensor& self, const Tensor& weight, c10::ArrayRef kernel_size, const Tensor& bias, c10::ArrayRef stride, c10::ArrayRef padding) { + auto kernel_size_ = kernel_size.vec(); + auto stride_ = stride.vec(); + auto padding_ = padding.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor output_ = vec.at(0); + Tensor finput_ = vec.at(1); + Tensor fgrad_input_ = vec.at(2); + at::thnn_conv2d_forward_out(output_, finput_, fgrad_input_, vec.at(3), vec.at(4), kernel_size_, vec.at(5), stride_, padding_); + }; + CheckpointTensorImpl::mutate("thnn_conv2d_forward_out", mt, {output, finput, fgrad_input, self, weight, bias}, {0, 1, 2}); + return {output, finput, fgrad_input}; +} + +std::tuple checkpoint_thnn_conv2d_forward(const Tensor& self, const Tensor& weight, c10::ArrayRef kernel_size, const Tensor& bias, c10::ArrayRef stride, c10::ArrayRef padding) { + auto kernel_size_ = kernel_size.vec(); + auto 
stride_ = stride.vec(); + auto padding_ = padding.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::thnn_conv2d_forward(vec.at(0), vec.at(1), kernel_size_, vec.at(2), stride_, padding_); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("thnn_conv2d_forward", rt, {self, weight, bias}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple checkpoint_thnn_conv2d_backward_out(Tensor& grad_input, Tensor& grad_weight, Tensor& grad_bias, const Tensor& grad_output, const Tensor& self, const Tensor& weight, c10::ArrayRef kernel_size, c10::ArrayRef stride, c10::ArrayRef padding, const Tensor& finput, const Tensor& fgrad_input) { + auto kernel_size_ = kernel_size.vec(); + auto stride_ = stride.vec(); + auto padding_ = padding.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor grad_input_ = vec.at(0); + Tensor grad_weight_ = vec.at(1); + Tensor grad_bias_ = vec.at(2); + at::thnn_conv2d_backward_out(grad_input_, grad_weight_, grad_bias_, vec.at(3), vec.at(4), vec.at(5), kernel_size_, stride_, padding_, vec.at(6), vec.at(7)); + }; + CheckpointTensorImpl::mutate("thnn_conv2d_backward_out", mt, {grad_input, grad_weight, grad_bias, grad_output, self, weight, finput, fgrad_input}, {0, 1, 2}); + return {grad_input, grad_weight, grad_bias}; +} + +std::tuple checkpoint_thnn_conv2d_backward(const Tensor& grad_output, const Tensor& self, const Tensor& weight, c10::ArrayRef kernel_size, c10::ArrayRef stride, c10::ArrayRef padding, const Tensor& finput, const Tensor& fgrad_input, std::array output_mask) { + auto kernel_size_ = kernel_size.vec(); + auto stride_ = stride.vec(); + auto padding_ = padding.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::thnn_conv2d_backward(vec.at(0), vec.at(1), vec.at(2), kernel_size_, stride_, padding_, vec.at(3), vec.at(4), output_mask); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("thnn_conv2d_backward", rt, {grad_output, self, weight, finput, fgrad_input}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple checkpoint_native_batch_norm(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double momentum, double eps) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::native_batch_norm(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), training, momentum, eps); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("native_batch_norm", rt, {input, weight, bias, running_mean, running_var}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple checkpoint_native_batch_norm_out(Tensor& out, Tensor& save_mean, Tensor& save_invstd, const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double momentum, double eps) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + Tensor save_mean_ = vec.at(1); + Tensor save_invstd_ = vec.at(2); + at::native_batch_norm_out(out_, save_mean_, save_invstd_, vec.at(3), vec.at(4), vec.at(5), vec.at(6), vec.at(7), training, momentum, eps); + }; + CheckpointTensorImpl::mutate("native_batch_norm_out", mt, {out, save_mean, save_invstd, input, weight, bias, running_mean, running_var}, {0, 1, 2}); + return {out, save_mean, save_invstd}; +} + 
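+// [Editor's note] Illustrative sketch of the overload pattern used throughout
+// this file, added for exposition only: it is not referenced from
+// native_functions.yaml, and the choice of op ("exp") is an arbitrary
+// assumption. A pure operator is wrapped in a rematerialize_function_t so its
+// outputs can be recomputed from the raw inputs in `vec` after an eviction; an
+// in-place / out= operator is wrapped in a mutate_function_t together with the
+// positions of the arguments it mutates.
+Tensor checkpoint_example_exp(const Tensor& a) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      // vec holds the uncheckpointed inputs, in the order they were passed to make().
+      return {at::exp(vec.at(0))};
+    };
+  return CheckpointTensorImpl::make("example_exp", rt, {a})[0];
+}
+
+Tensor& checkpoint_example_exp_(Tensor& a) {
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor a_ = vec.at(0);
+      a_.exp_();
+    };
+  // The trailing {0} lists the positions of the arguments mutated in place.
+  CheckpointTensorImpl::mutate("example_exp_", mt, {a}, {0});
+  return a;
+}
+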
+std::tuple checkpoint_native_batch_norm_backward(const Tensor& grad_out, const Tensor& input, const Tensor& weight, const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd, bool train, double eps, std::array output_mask) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::native_batch_norm_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), vec.at(5), vec.at(6), train, eps, output_mask); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("native_batch_norm_backward", rt, {grad_out, input, weight, running_mean, running_var, save_mean, save_invstd}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple checkpoint__cudnn_ctc_loss(const Tensor& log_probs, const Tensor& targets, ArrayRef input_lengths, ArrayRef target_lengths, long blank, bool deterministic, bool zero_infinity) { + auto input_lengths_ = input_lengths.vec(); + auto target_lengths_ = target_lengths.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::_cudnn_ctc_loss(vec.at(0), vec.at(1), input_lengths_, target_lengths_, blank, deterministic, zero_infinity); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("_cudnn_ctc_loss", rt, {log_probs, targets}); + return {ret[0], ret[1]}; +} + +std::tuple checkpoint__ctc_loss(const Tensor& log_probs, const Tensor& targets, ArrayRef input_lengths, ArrayRef target_lengths, long blank, bool zero_infinity) { + auto input_lengths_ = input_lengths.vec(); + auto target_lengths_ = target_lengths.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::_ctc_loss(vec.at(0), vec.at(1), input_lengths_, target_lengths_, blank, zero_infinity); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("_ctc_loss", rt, {log_probs, targets}); + return {ret[0], ret[1]}; +} + +Tensor checkpoint__ctc_loss_backward(const Tensor& grad, const Tensor& log_probs, const Tensor& targets, ArrayRef input_lengths, ArrayRef target_lengths, const Tensor& neg_log_likelihood, const Tensor& log_alpha, long blank, bool zero_infinity) { + auto input_lengths_ = input_lengths.vec(); + auto target_lengths_ = target_lengths.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::_ctc_loss_backward(vec.at(0), vec.at(1), vec.at(2), input_lengths_, target_lengths_, vec.at(3), vec.at(4), blank, zero_infinity)}; + }; + return CheckpointTensorImpl::make("_ctc_loss_backward", rt, {grad, log_probs, targets, neg_log_likelihood, log_alpha})[0]; +} + +Tensor& checkpoint_hardtanh_backward_out(Tensor& grad_input, const Tensor& grad_output, const Tensor& self, Scalar min_val, Scalar max_val) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor grad_input_ = vec.at(0); + at::hardtanh_backward_out(grad_input_, vec.at(1), vec.at(2), min_val, max_val); + }; + CheckpointTensorImpl::mutate("hardtanh_backward_out", mt, {grad_input, grad_output, self}, {0}); + return {grad_input}; +} + +Tensor checkpoint_hardtanh_backward(const Tensor& grad_output, const Tensor& self, Scalar min_val, Scalar max_val) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::hardtanh_backward(vec.at(0), vec.at(1), min_val, max_val)}; + }; + return CheckpointTensorImpl::make("hardtanh_backward", rt, {grad_output, self})[0]; +} + +Tensor checkpoint_nonzero(const Tensor& self) { + 
rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::nonzero(vec.at(0))}; + }; + return CheckpointTensorImpl::make("nonzero", rt, {self})[0]; +} + +Tensor& checkpoint_nonzero_out(Tensor& out, const Tensor& self) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::nonzero_out(out_, vec.at(1)); + }; + CheckpointTensorImpl::mutate("nonzero_out", mt, {out, self}, {0}); + return {out}; +} + +Tensor checkpoint_lt(const Tensor& self, Scalar other) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::lt(vec.at(0), other)}; + }; + return CheckpointTensorImpl::make("lt_Scalar", rt, {self})[0]; +} + +Tensor& checkpoint_lt_out(Tensor& out, const Tensor& self, Scalar other) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::lt_out(out_, vec.at(1), other); + }; + CheckpointTensorImpl::mutate("lt_Scalar_out", mt, {out, self}, {0}); + return {out}; +} + +Tensor checkpoint_lt(const Tensor& self, const Tensor& other) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::lt(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("lt_Tensor", rt, {self, other})[0]; +} + +Tensor& checkpoint_lt_out(Tensor& out, const Tensor& self, const Tensor& other) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::lt_out(out_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("lt_Tensor_out", mt, {out, self, other}, {0}); + return {out}; +} + +Tensor checkpoint_any(const Tensor& self) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::any(vec.at(0))}; + }; + return CheckpointTensorImpl::make("any", rt, {self})[0]; +} + +bool checkpoint__use_cudnn_ctc_loss(const Tensor& log_probs, const Tensor& targets, ArrayRef input_lengths, ArrayRef target_lengths, long blank) { + return at::_use_cudnn_ctc_loss(decheckpoint(log_probs), decheckpoint(targets), input_lengths, target_lengths, blank); +} + +bool checkpoint_equal(const Tensor& self, const Tensor& other) { + // there can't possibly be a reason to rematerialize + // a single bool so we'll just compute it now + return at::equal(decheckpoint(self), decheckpoint(other)); +} + +Scalar checkpoint__local_scalar_dense(at::Tensor const& a) { + return at::_local_scalar_dense(decheckpoint(a)); +} + +Tensor checkpoint_split_with_sizes_backward(c10::ArrayRef a, c10::ArrayRef b, long c, c10::ArrayRef d, c10::TensorOptions const& e) { + std::vector a_ = a.vec(); + std::vector d_ = d.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::split_with_sizes_backward(vec, b, c, d_, e)}; + }; + return CheckpointTensorImpl::make("split_with_sizes_backward", rt, a_)[0]; +} + +std::vector checkpoint_split_with_sizes(at::Tensor const& a, c10::ArrayRef b, long c) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return at::split_with_sizes(vec.at(0), b_, c); + }; + return CheckpointTensorImpl::make("split_with_sizes", rt, {a}); +} + +Tensor checkpoint_split_backward(c10::ArrayRef a, long b, long c, c10::ArrayRef d, const c10::TensorOptions& e) { + std::vector a_ = a.vec(); + std::vector d_ = d.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::split_backward(vec, b, c, d_, e)}; + }; + return CheckpointTensorImpl::make("split_backward", rt, a_)[0]; +} + +std::vector checkpoint_split(const at::Tensor& a, long b, long c) { 
+ rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return at::split(vec.at(0), b, c); + }; + return CheckpointTensorImpl::make("split", rt, {a}); +} + +Tensor checkpoint_expand(at::Tensor const& a, c10::ArrayRef b, bool c) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {vec.at(0).expand(b_, c)}; + }; + return CheckpointTensorImpl::make("expand", rt, {a})[0]; +} + +Tensor checkpoint_diag(at::Tensor const& self, long diagonal) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::diag(vec.at(0), diagonal)}; + }; + return CheckpointTensorImpl::make("diag", rt, {self})[0]; +} + +Tensor& checkpoint_diag_out(at::Tensor& out, const Tensor& self, long diagonal) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::diag_out(out_, vec.at(1), diagonal); + }; + CheckpointTensorImpl::mutate("diag_out", mt, {out, self}, {0}); + return {out}; +} + +Tensor checkpoint_mv(at::Tensor const& self, at::Tensor const& vec) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::mv(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("mv", rt, {self, vec})[0]; +} + +Tensor& checkpoint_mv_out(at::Tensor& out, const Tensor& self, const Tensor& vec) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::mv_out(out_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("mv_out", mt, {out, self, vec}, {0}); + return {out}; +} + +}} diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 1854938a30d..4ee19b6964d 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1047,30 +1047,6 @@ Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_ return result; } -std::vector split(const Tensor& self, int64_t split_size, int64_t dim) { - TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor"); - TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size); - int64_t dim_size = self.size(dim); - TORCH_CHECK(split_size > 0 || self.size(dim) == 0, - "split_size can only be 0 if dimension size is 0, " - "but got dimension size of ", dim_size); - // if split_size is 0 and dimension size is 0, there is 1 split. - int64_t num_splits = 1; - if (split_size != 0) { - // ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size - // (returns a single split). We might want to error here, but keep it for BC. - num_splits = std::max((dim_size + split_size - 1) / split_size, 1); - } - std::vector splits(num_splits); - int64_t last_split_size = split_size - (split_size * num_splits - dim_size); - - for (int64_t i = 0; i < num_splits; ++i) { - auto length = i < num_splits - 1 ? 
split_size : last_split_size; - splits[i] = self.narrow(dim, i * split_size, length); - } - return splits; -} - std::vector unsafe_split(const Tensor& self, int64_t split_size, int64_t dim) { auto result = at::native::split(self, split_size, dim); for (auto& t : result) { @@ -1079,28 +1055,6 @@ std::vector unsafe_split(const Tensor& self, int64_t split_size, int64_t return result; } -std::vector split_with_sizes(const Tensor& self, IntArrayRef split_sizes, int64_t dim) { - TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor"); - int64_t dim_size = self.size(dim); - int64_t num_splits = split_sizes.size(); - std::vector splits(num_splits); - int64_t start_idx = 0; - int64_t i; - - for (i = 0; i < num_splits; ++i) { - auto length = split_sizes[i]; - TORCH_CHECK(length >= 0, - "split_with_sizes expects split_sizes have only non-negative ", - "entries, but got split_sizes=", split_sizes); - splits[i] = self.narrow(dim, start_idx, length); - start_idx += length; - } - TORCH_CHECK(start_idx == dim_size, - "split_with_sizes expects split_sizes to sum exactly to ", dim_size, - " (input tensor's size at dimension ", dim, "), ", "but got split_sizes=", split_sizes); - return splits; -} - std::vector unsafe_split_with_sizes(const Tensor& self, IntArrayRef split_sizes, int64_t dim) { auto result = at::native::split_with_sizes(self, split_sizes, dim); for (auto& t : result) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index dcf349324e2..bdf36210523 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5,6 +5,60 @@ # representing ScalarType's. They are now superseded by usage of # `aten::to()`. The ops remain here for backward compatibility purposes. +- func: checkpoint(Tensor self) -> Tensor + variants: method + +# convert a tensor to a checkpoint tensor. +# if the input is already a checkpoint tensor, +# checkpoint will fail, while try_checkpoint will +# simply return the input. +- func: try_checkpoint(Tensor self) -> Tensor + variants: method + +- func: is_checkpoint(Tensor self) -> bool + variants: method + +# convert a checkpointed tensor back into a normal tensor. +# uncheckpoint assumes the input is checkpointed and will fail otherwise. +# decheckpoint returns the input unchanged if it is not checkpointed. +- func: uncheckpoint(Tensor self) -> Tensor + variants: method + +- func: decheckpoint(Tensor self) -> Tensor + variants: method + +- func: pin(Tensor(a!) self) -> () + variants: method + +- func: new_log(str logname) -> () + variants: function + +- func: annotate_log(str logname) -> () + variants: function + +- func: toggle_log(bool use_log) -> () + variants: function + +- func: clear_checkpointpool() -> () + variants: function + +- func: set_memory_budget(int budget) -> () + variants: function + +- func: unset_memory_budget() -> () + variants: function + +- func: reset_compute_time() -> () + variants: function + +- func: compute_time() -> int + variants: function + +# Temporary type cast operators. These are needed to trace type-casts now since +# Type's are not supported in the IR. Instead, we call down to these +# specialized operators for each datatype. +# TODO: remove when we have Type support in the IR + # DEPRECATED.
DO NOT USE - func: _cast_Byte(Tensor self, bool non_blocking=False) -> Tensor use_c10_dispatcher: full @@ -131,11 +185,13 @@ use_c10_dispatcher: full dispatch: CUDA: _use_cudnn_ctc_loss + Checkpoint: checkpoint__use_cudnn_ctc_loss - func: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) use_c10_dispatcher: full dispatch: CUDA: _cudnn_ctc_loss + Checkpoint: checkpoint__cudnn_ctc_loss - func: _use_cudnn_rnn_flatten_weight() -> bool use_c10_dispatcher: full @@ -168,12 +224,15 @@ variants: function dispatch: CUDA: fused_dropout_cuda + Checkpoint: checkpoint__fused_dropout + supports_named_tensor: True - func: _masked_scale(Tensor self, Tensor mask, float scale) -> Tensor use_c10_dispatcher: full variants: function dispatch: CUDA: masked_scale_cuda + Checkpoint: checkpoint__masked_scale - func: _sobol_engine_draw(Tensor quasi, int n, Tensor sobolstate, int dimension, int num_generated, ScalarType? dtype) -> (Tensor, Tensor) use_c10_dispatcher: full @@ -220,6 +279,10 @@ - func: abs(Tensor self) -> Tensor use_c10_dispatcher: full variants: function, method + supports_named_tensor: True + dispatch: + CUDA, CPU: abs + Checkpoint: checkpoint_abs - func: abs_(Tensor(a!) self) -> Tensor(a!) use_c10_dispatcher: full @@ -326,6 +389,8 @@ CPU, CUDA: add SparseCPU, SparseCUDA: add_sparse MkldnnCPU: mkldnn_add + Checkpoint: checkpoint_add + supports_named_tensor: True - func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) use_c10_dispatcher: full @@ -334,6 +399,8 @@ CPU, CUDA: add_ SparseCPU, SparseCUDA: add_sparse_ MkldnnCPU: mkldnn_add_ + Checkpoint: checkpoint_add_ + supports_named_tensor: True - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -363,6 +430,10 @@ - func: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor use_c10_dispatcher: full variants: function, method + supports_named_tensor: True + dispatch: + CPU, CUDA: add + Checkpoint: checkpoint_add - func: add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) use_c10_dispatcher: full @@ -533,6 +604,7 @@ dispatch: CPU, CUDA: as_strided_tensorimpl QuantizedCPU, QuantizedCUDA: as_strided_qtensorimpl + Checkpoint: checkpoint_as_strided device_guard: False - func: as_strided_(Tensor(a!) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a!) @@ -677,6 +749,7 @@ dispatch: CPU: binary_cross_entropy_cpu CUDA: binary_cross_entropy_cuda + Checkpoint: checkpoint_binary_cross_entropy - func: binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -684,6 +757,7 @@ dispatch: CPU: binary_cross_entropy_out_cpu CUDA: binary_cross_entropy_out_cuda + Checkpoint: checkpoint_binary_cross_entropy_out - func: binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor use_c10_dispatcher: full @@ -692,6 +766,7 @@ dispatch: CPU: binary_cross_entropy_backward_cpu CUDA: binary_cross_entropy_backward_cuda + Checkpoint: checkpoint_binary_cross_entropy_backward - func: binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) 
python_module: nn @@ -699,14 +774,23 @@ dispatch: CPU: binary_cross_entropy_backward_out_cpu CUDA: binary_cross_entropy_backward_out_cuda + Checkpoint: checkpoint_binary_cross_entropy_backward_out - func: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor use_c10_dispatcher: full variants: function + dispatch: + CPU: binary_cross_entropy_with_logits + CUDA: binary_cross_entropy_with_logits + Checkpoint: checkpoint_binary_cross_entropy_with_logits - func: binary_cross_entropy_with_logits_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor use_c10_dispatcher: full variants: function + dispatch: + CPU: binary_cross_entropy_with_logits_backward + CUDA: binary_cross_entropy_with_logits_backward + Checkpoint: checkpoint_binary_cross_entropy_with_logits_backward - func: bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor use_c10_dispatcher: full @@ -789,12 +873,7 @@ CUDA: bmm_cuda SparseCPU: bmm_sparse_cpu SparseCUDA: bmm_sparse_cuda - -- func: _bmm(Tensor self, Tensor mat2, *, bool deterministic=False) -> Tensor - use_c10_dispatcher: full - variants: function - dispatch: - SparseCUDA: _bmm_sparse_cuda + Checkpoint: checkpoint_bmm - func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) variants: function @@ -804,6 +883,12 @@ SparseCPU: bmm_out_sparse_cpu SparseCUDA: bmm_out_sparse_cuda +- func: _bmm(Tensor self, Tensor mat2, *, bool deterministic=False) -> Tensor + use_c10_dispatcher: full + variants: function + dispatch: + SparseCUDA: _bmm_sparse_cuda + - func: _bmm.out(Tensor self, Tensor mat2, *, bool deterministic=False, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: @@ -858,12 +943,17 @@ dispatch: CPU, CUDA: clamp QuantizedCPU: clamp_quantized_cpu + Checkpoint: checkpoint_clamp - func: clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) use_c10_dispatcher: full variants: function, method - func: clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) + supports_named_tensor: True + dispatch: + CPU, CUDA: clamp_out + Checkpoint: checkpoint_clamp_out - func: clamp_max(Tensor self, Scalar max) -> Tensor use_c10_dispatcher: full @@ -882,8 +972,15 @@ - func: clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!) use_c10_dispatcher: full variants: function, method + dispatch: + CPU, CUDA: clamp_min_ + Checkpoint: checkpoint_clamp_min_ - func: clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!) + supports_named_tensor: True + dispatch: + CPU, CUDA: clamp_min_out + Checkpoint: checkpoint_clamp_min_out # clip is an alias for clamp - func: clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor @@ -914,6 +1011,10 @@ - func: constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> Tensor use_c10_dispatcher: full variants: function + dispatch: + CPU: constant_pad_nd + CUDA: constant_pad_nd + Checkpoint: checkpoint_constant_pad_nd - func: contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a) use_c10_dispatcher: full @@ -1020,12 +1121,14 @@ use_c10_dispatcher: full dispatch: CUDA: cudnn_batch_norm + Checkpoint: checkpoint_cudnn_batch_norm # NB: You can only use this if you used cudnn_batch_norm training=True - func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? 
save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor) use_c10_dispatcher: full dispatch: CUDA: cudnn_batch_norm_backward + Checkpoint: checkpoint_cudnn_batch_norm_backward - func: cudnn_convolution.deprecated(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor use_c10_dispatcher: full @@ -1041,21 +1144,25 @@ use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution + Checkpoint: checkpoint_cudnn_convolution - func: cudnn_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_backward_input + Checkpoint: checkpoint_cudnn_convolution_backward_input - func: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32, bool[2] output_mask) -> (Tensor, Tensor) use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_backward + Checkpoint: checkpoint_cudnn_convolution_backward - func: cudnn_convolution_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_backward_weight + Checkpoint: checkpoint_cudnn_convolution_backward_weight - func: cudnn_convolution_transpose.deprecated(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor use_c10_dispatcher: full @@ -1071,6 +1178,7 @@ use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_transpose + Checkpoint: checkpoint_cudnn_convolution_transpose # NB: output_padding not strictly needed here, but it's helpful for the float # backwards @@ -1078,16 +1186,19 @@ use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_transpose_backward + Checkpoint: checkpoint_cudnn_convolution_transpose_backward - func: cudnn_convolution_transpose_backward_input(Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_transpose_backward_input + Checkpoint: checkpoint_cudnn_convolution_transpose_backward_input - func: cudnn_convolution_transpose_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor use_c10_dispatcher: full dispatch: CUDA: cudnn_convolution_transpose_backward_weight + Checkpoint: checkpoint_cudnn_convolution_transpose_backward_weight # NB: input is special cased in a way I don't quite understand - func: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output @@ -1168,12 +1279,14 @@ dispatch: CPU: ctc_loss_cpu CUDA: ctc_loss_gpu + Checkpoint: checkpoint__ctc_loss - func: _ctc_loss_backward(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor use_c10_dispatcher: full dispatch: CPU: ctc_loss_backward_cpu CUDA: ctc_loss_backward_gpu + Checkpoint: checkpoint__ctc_loss_backward - func: diag_embed(Tensor 
self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor use_c10_dispatcher: full @@ -1200,6 +1313,8 @@ dispatch: CPU, CUDA: div SparseCPU, SparseCUDA: div_sparse + Checkpoint: checkpoint_div + supports_named_tensor: True - func: div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) use_c10_dispatcher: full @@ -1207,6 +1322,8 @@ dispatch: CPU, CUDA: div_ SparseCPU, SparseCUDA: div_sparse_ + Checkpoint: checkpoint_div_ + supports_named_tensor: True - func: div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -1236,9 +1353,17 @@ - func: embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor use_c10_dispatcher: full + dispatch: + CPU: embedding + CUDA: embedding + Checkpoint: checkpoint_embedding - func: embedding_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor use_c10_dispatcher: full + dispatch: + CPU: embedding_backward + CUDA: embedding_backward + Checkpoint: checkpoint_embedding_backward - func: embedding_dense_backward(Tensor grad_output, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor use_c10_dispatcher: full @@ -1417,6 +1542,11 @@ use_c10_dispatcher: full variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. device_guard: False + supports_named_tensor: True + dispatch: + CPU: expand + CUDA: expand + Checkpoint: checkpoint_expand - func: expand_as(Tensor(a) self, Tensor other) -> Tensor(a) use_c10_dispatcher: full @@ -1461,10 +1591,18 @@ - func: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) use_c10_dispatcher: full variants: function, method + dispatch: + CPU: fill_ + CUDA: fill_ + Checkpoint: checkpoint_fill_ - func: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) use_c10_dispatcher: full variants: function, method + dispatch: + CPU: fill_ + CUDA: fill_ + Checkpoint: checkpoint_fill_ - func: floor(Tensor self) -> Tensor use_c10_dispatcher: full @@ -1667,6 +1805,10 @@ - func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor variants: function, method + dispatch: + CPU: index + CUDA: index + Checkpoint: checkpoint_index # NB: This function is special-cased in tools/autograd/gen_variable_type.py # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp: # - Tensor Tensor::index(ArrayRef indices) @@ -1688,6 +1830,10 @@ - func: index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!) 
variants: function, method + dispatch: + CPU: index_put_ + CUDA: index_put_ + Checkpoint: checkpoint_index_put_ # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp: # - Tensor & Tensor::index_put_(ArrayRef indices, Tensor const & rhs) # - Tensor & Tensor::index_put_(ArrayRef indices, Scalar v) @@ -1765,12 +1911,16 @@ - func: kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor use_c10_dispatcher: full + dispatch: + CPU, CUDA: kl_div + Checkpoint: checkpoint_kl_div - func: kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor use_c10_dispatcher: full dispatch: CPU: kl_div_backward_cpu CUDA: kl_div_backward_cuda + Checkpoint: checkpoint_kl_div_backward - func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) use_c10_dispatcher: full @@ -1794,12 +1944,14 @@ dispatch: CPU: layer_norm_cpu CUDA: layer_norm_cuda + Checkpoint: checkpoint_layer_norm - func: native_layer_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, int M, int N, bool[3] output_mask) -> (Tensor, Tensor, Tensor) use_c10_dispatcher: full dispatch: CPU: layer_norm_backward_cpu CUDA: layer_norm_backward_cuda + Checkpoint: checkpoint_layer_norm_backward - func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor use_c10_dispatcher: full @@ -1846,6 +1998,10 @@ - func: log(Tensor self) -> Tensor use_c10_dispatcher: full variants: function, method + dispatch: + CPU: log + CUDA: log + Checkpoint: checkpoint_log - func: log_(Tensor(a!) self) -> Tensor(a!) use_c10_dispatcher: full @@ -1854,6 +2010,7 @@ - func: log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU, CUDA: log_out + Checkpoint: checkpoint_log_out - func: log10(Tensor self) -> Tensor use_c10_dispatcher: full @@ -1932,12 +2089,14 @@ dispatch: CPU: log_softmax_cpu CUDA: log_softmax_cuda + Checkpoint: checkpoint__log_softmax - func: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor use_c10_dispatcher: full dispatch: CPU: log_softmax_backward_cpu CUDA: log_softmax_backward_cuda + Checkpoint: checkpoint__log_softmax_backward_data - func: _logcumsumexp(Tensor self, int dim) -> Tensor use_c10_dispatcher: full @@ -2059,6 +2218,7 @@ dispatch: CPU, CUDA: mean_cpu_gpu QuantizedCPU: mean_quantized_cpu + Checkpoint: checkpoint_mean - func: mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor use_c10_dispatcher: full @@ -2066,6 +2226,7 @@ dispatch: CPU, CUDA: mean_cpu_gpu QuantizedCPU: mean_quantized_cpu + Checkpoint: checkpoint_mean - func: mean.out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -2211,12 +2372,16 @@ CPU: mm_cpu CUDA: mm_cuda SparseCPU, SparseCUDA: _sparse_mm + Checkpoint: checkpoint_mm + supports_named_tensor: True - func: mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: mm_cpu_out CUDA: mm_out_cuda SparseCPU, SparseCUDA: _sparse_mm_out + Checkpoint: checkpoint_mm_out + supports_named_tensor: True - func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor use_c10_dispatcher: full @@ -2239,6 +2404,8 @@ CPU, CUDA: mul SparseCPU, SparseCUDA: mul_sparse MkldnnCPU: mkldnn_mul + Checkpoint: checkpoint_mul + supports_named_tensor: True - func: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
use_c10_dispatcher: full @@ -2247,6 +2414,8 @@ CPU, CUDA: mul_ SparseCPU, SparseCUDA: mul_sparse_ MkldnnCPU: mkldnn_mul_ + Checkpoint: checkpoint_mul_ + supports_named_tensor: True - func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -2259,10 +2428,21 @@ - func: mul.Scalar(Tensor self, Scalar other) -> Tensor use_c10_dispatcher: full variants: function, method + dispatch: + CPU: mul + CUDA: mul + SparseCPU: mul + SparseCUDA: mul + MkldnnCPU: mul + Checkpoint: checkpoint_mul - func: mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) use_c10_dispatcher: full variants: method + dispatch: + CPU: mul_ + CUDA: mul_ + Checkpoint: checkpoint_mul_ - func: mv(Tensor self, Tensor vec) -> Tensor use_c10_dispatcher: full @@ -2270,8 +2450,13 @@ dispatch: CPU, CUDA: mv SparseCPU, SparseCUDA: mv_sparse + Checkpoint: checkpoint_mv - func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + CPU, CUDA: mv_out + Checkpoint: checkpoint_mv_out + supports_named_tensor: True - func: mvlgamma(Tensor self, int p) -> Tensor use_c10_dispatcher: full @@ -2304,10 +2489,12 @@ CPU: batch_norm_cpu CUDA: batch_norm_cuda MkldnnCPU: mkldnn_batch_norm + Checkpoint: checkpoint_native_batch_norm - func: native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) dispatch: CUDA: batch_norm_cuda_out + Checkpoint: checkpoint_native_batch_norm_out - func: batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor) use_c10_dispatcher: full @@ -2339,6 +2526,7 @@ dispatch: CPU: batch_norm_backward_cpu CUDA: batch_norm_backward_cuda + Checkpoint: checkpoint_native_batch_norm_backward - func: batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor) use_c10_dispatcher: full @@ -2388,6 +2576,10 @@ - func: ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor use_c10_dispatcher: full + supports_named_tensor: True + dispatch: + CPU, CUDA: ones_like + Checkpoint: checkpoint_ones_like - func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor use_c10_dispatcher: full @@ -2584,6 +2776,10 @@ - func: neg(Tensor self) -> Tensor use_c10_dispatcher: full variants: function, method + dispatch: + CPU: neg + CUDA: neg + Checkpoint: checkpoint_neg - func: neg_(Tensor(a!) self) -> Tensor(a!) use_c10_dispatcher: full @@ -2605,6 +2801,10 @@ - func: repeat(Tensor self, int[] repeats) -> Tensor use_c10_dispatcher: full variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. + dispatch: + CPU: repeat + CUDA: repeat + Checkpoint: checkpoint_repeat - func: repeat_interleave.Tensor(Tensor repeats) -> Tensor use_c10_dispatcher: full @@ -2661,6 +2861,8 @@ CPU, CUDA: relu MkldnnCPU: mkldnn_relu QuantizedCPU: relu_quantized_cpu + Checkpoint: checkpoint_relu + supports_named_tensor: True - func: relu_(Tensor(a!) self) -> Tensor(a!) 
use_c10_dispatcher: full @@ -2669,6 +2871,7 @@ CPU, CUDA: relu_ MkldnnCPU: mkldnn_relu_ QuantizedCPU: relu_quantized_cpu_ + Checkpoint: checkpoint_relu_ - func: prelu(Tensor self, Tensor weight) -> Tensor use_c10_dispatcher: full @@ -2726,6 +2929,17 @@ use_c10_dispatcher: full variants: function, method device_guard: False + supports_named_tensor: True + dispatch: + CPU, CUDA: select + Checkpoint: checkpoint_select + +- func: select_backward(Tensor grad, int[] sizes, int dim, int index) -> Tensor + variants: function + device_guard: False + dispatch: + CPU, CUDA: select_backward + Checkpoint: checkpoint_select_backward - func: selu(Tensor self) -> Tensor use_c10_dispatcher: full @@ -2761,6 +2975,7 @@ CPU, CUDA: sigmoid QuantizedCPU: sigmoid_quantized_cpu MkldnnCPU: mkldnn_sigmoid + Checkpoint: checkpoint_sigmoid - func: sigmoid_(Tensor(a!) self) -> Tensor(a!) use_c10_dispatcher: full @@ -2768,6 +2983,7 @@ dispatch: CPU, CUDA: sigmoid_ MkldnnCPU: mkldnn_sigmoid_ + Checkpoint: checkpoint_sigmoid_ - func: sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) @@ -2842,6 +3058,19 @@ use_c10_dispatcher: full variants: function, method device_guard: False + supports_named_tensor: True + dispatch: + CPU: slice + CUDA: slice + Checkpoint: checkpoint_slice + +- func: slice_backward(Tensor grad, int[] input_sizes, int dim, int start, int end, int step) -> Tensor + variants: function, method + device_guard: False + dispatch: + CPU: slice_backward + CUDA: slice_backward + Checkpoint: checkpoint_slice_backward - func: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) use_c10_dispatcher: full @@ -2865,12 +3094,14 @@ CPU: softmax_cpu CUDA: softmax_cuda MkldnnCPU: mkldnn_softmax + Checkpoint: checkpoint__softmax - func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor use_c10_dispatcher: full dispatch: CPU: softmax_backward_cpu CUDA: softmax_backward_cuda + Checkpoint: checkpoint__softmax_backward_data - func: unsafe_split.Tensor(Tensor self, int split_size, int dim=0) -> Tensor[] use_c10_dispatcher: full @@ -2881,6 +3112,18 @@ use_c10_dispatcher: full variants: function, method device_guard: False + supports_named_tensor: True + dispatch: + CPU: split + CUDA: split + Checkpoint: checkpoint_split + +- func: split_backward(Tensor[] grads, int split_size, int dim, int[] sizes, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor + variants: function + dispatch: + CPU: split_backward + CUDA: split_backward + Checkpoint: checkpoint_split_backward - func: unsafe_split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[] use_c10_dispatcher: full @@ -2891,6 +3134,18 @@ use_c10_dispatcher: full variants: function, method device_guard: False + supports_named_tensor: True + dispatch: + CPU: split_with_sizes + CUDA: split_with_sizes + Checkpoint: checkpoint_split_with_sizes + +- func: split_with_sizes_backward(Tensor[] grads, int[] split_sizes, int dim, int[] sizes, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor + variants: function + dispatch: + CPU: split_with_sizes_backward + CUDA: split_with_sizes_backward + Checkpoint: checkpoint_split_with_sizes_backward - func: squeeze(Tensor(a) self) -> Tensor(a) use_c10_dispatcher: full @@ -2910,15 +3165,27 @@ use_c10_dispatcher: full variants: method device_guard: False + dispatch: + CPU: squeeze_ + CUDA: squeeze_ + Checkpoint: checkpoint_squeeze_ - func: squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!) 
use_c10_dispatcher: full variants: method device_guard: False + dispatch: + CPU: squeeze_ + CUDA: squeeze_ + Checkpoint: checkpoint_squeeze_ - func: squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!) variants: method device_guard: False + dispatch: + CPU: squeeze_ + CUDA: squeeze_ + Checkpoint: checkpoint_squeeze_ - func: sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor use_c10_dispatcher: full @@ -2975,10 +3242,20 @@ - func: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor use_c10_dispatcher: full variants: function, method + supports_named_tensor: True + dispatch: + CPU: sum + CUDA: sum + Checkpoint: checkpoint_sum - func: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor use_c10_dispatcher: full variants: function, method + supports_named_tensor: True + dispatch: + CPU: sum + CUDA: sum + Checkpoint: checkpoint_sum_dim_IntList - func: sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method @@ -3005,6 +3282,10 @@ - func: sqrt(Tensor self) -> Tensor use_c10_dispatcher: full variants: function, method + dispatch: + CPU: sqrt + CUDA: sqrt + Checkpoint: checkpoint_sqrt - func: sqrt_(Tensor(a!) self) -> Tensor(a!) use_c10_dispatcher: full @@ -3065,6 +3346,10 @@ use_c10_dispatcher: full device_guard: False variants: function, method + supports_named_tensor: True + dispatch: + CPU, CUDA: t + Checkpoint: checkpoint_t - func: t_(Tensor(a!) self) -> Tensor(a!) use_c10_dispatcher: full @@ -3087,6 +3372,7 @@ dispatch: CPU, CUDA: tanh QuantizedCPU: tanh_quantized_cpu + Checkpoint: checkpoint_tanh - func: tanh_(Tensor(a!) self) -> Tensor(a!) use_c10_dispatcher: full @@ -3106,6 +3392,7 @@ CPU: threshold CUDA: threshold_cuda QuantizedCPU: threshold_quantized_cpu + Checkpoint: checkpoint_threshold - func: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!) use_c10_dispatcher: full @@ -3113,11 +3400,13 @@ dispatch: CPU: threshold_ CUDA: threshold__cuda + Checkpoint: checkpoint_threshold_ - func: threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: threshold_out CUDA: threshold_out_cuda + Checkpoint: checkpoint_threshold_out - func: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor use_c10_dispatcher: full @@ -3125,6 +3414,7 @@ dispatch: CPU: threshold_backward CUDA: threshold_backward_cuda + Checkpoint: checkpoint_threshold_backward - func: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) use_c10_dispatcher: full @@ -3345,6 +3635,10 @@ - func: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor use_c10_dispatcher: full variants: function, method + dispatch: + CPU: where + CUDA: where + Checkpoint: checkpoint_where - func: where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor use_c10_dispatcher: full @@ -3402,6 +3696,10 @@ - func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor use_c10_dispatcher: full + supports_named_tensor: True + dispatch: + CPU, CUDA: zeros_like + Checkpoint: checkpoint_zeros_like - func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor use_c10_dispatcher: full @@ -3567,6 +3865,8 @@ SparseCPU, SparseCUDA: clone_sparse MkldnnCPU: mkldnn_clone QuantizedCPU, QuantizedCUDA: quantized_clone + Checkpoint: checkpoint_clone + supports_named_tensor: True - func: resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!) use_c10_dispatcher: full @@ -3591,6 +3891,7 @@ CPU, CUDA: zero_ SparseCPU, SparseCUDA: zero_sparse_ MkldnnCPU: mkldnn_zero_ + Checkpoint: checkpoint_zero_ - func: sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -3603,6 +3904,8 @@ dispatch: CPU, CUDA: sub SparseCPU, SparseCUDA: sub_sparse + Checkpoint: checkpoint_sub + supports_named_tensor: True - func: sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) use_c10_dispatcher: full @@ -3630,6 +3933,10 @@ - func: subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) use_c10_dispatcher: full variants: method + dispatch: + CPU, CUDA, SparseCPU, SparseCUDA: subtract_ + Checkpoint: checkpoint_subtract_ + supports_named_tensor: True # For C++ only, until we have conversion from C++ numbers to Tensor - func: subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor @@ -3643,6 +3950,10 @@ - func: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor use_c10_dispatcher: full variants: function + supports_named_tensor: True + dispatch: + CPU, CUDA, SparseCPU, SparseCUDA: rsub + Checkpoint: checkpoint_rsub - func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -3659,6 +3970,10 @@ - func: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor use_c10_dispatcher: full variants: function + supports_named_tensor: True + dispatch: + CPU, CUDA, SparseCPU, SparseCUDA: rsub + Checkpoint: checkpoint_rsub # Functionally the same as addmm, but we give it a different derivative formula # that doesn't propagate gradients to non-present entries on sparse. @@ -3671,6 +3986,8 @@ CUDA: addmm_out_cuda SparseCPU: addmm_out_sparse_dense_cpu SparseCUDA: addmm_out_sparse_dense_cuda + Checkpoint: checkpoint_addmm_out + supports_named_tensor: True - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor use_c10_dispatcher: full @@ -3680,6 +3997,8 @@ CUDA: addmm_cuda SparseCPU: addmm_sparse_dense_cpu SparseCUDA: addmm_sparse_dense_cuda + Checkpoint: checkpoint_addmm + supports_named_tensor: True - func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) use_c10_dispatcher: full @@ -3691,6 +4010,9 @@ # broadcasting SparseCPU: s_addmm_sparse_dense_cpu_ SparseCUDA: s_addmm_sparse_dense_cuda_ + Checkpoint: checkpoint_addmm_ + supports_named_tensor: True + # NOTE [ Sparse: autograd and API ] # @@ -4136,6 +4458,10 @@ use_c10_dispatcher: full variants: method device_guard: False + supports_named_tensor: True + dispatch: + CPU, CUDA: to + Checkpoint: checkpoint_to - func: to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? 
memory_format=None) -> Tensor use_c10_dispatcher: full @@ -4196,6 +4522,7 @@ dispatch: CPU: _local_scalar_dense_cpu CUDA: _local_scalar_dense_cuda + Checkpoint: checkpoint__local_scalar_dense variants: function # Fused RNN kernels @@ -4203,11 +4530,13 @@ use_c10_dispatcher: full dispatch: CUDA: _thnn_fused_lstm_cell_cuda + Checkpoint: checkpoint__thnn_fused_lstm_cell - func: _thnn_fused_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) use_c10_dispatcher: full dispatch: CUDA: _thnn_fused_lstm_cell_backward_cuda + Checkpoint: checkpoint__thnn_fused_lstm_cell_backward - func: _thnn_differentiable_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor input_gates, Tensor hidden_gates, Tensor? input_bias, Tensor? hidden_bias, Tensor cx, Tensor cy) -> (Tensor, Tensor, Tensor, Tensor, Tensor) use_c10_dispatcher: full @@ -4216,11 +4545,13 @@ use_c10_dispatcher: full dispatch: CUDA: _thnn_fused_gru_cell_cuda + Checkpoint: checkpoint__thnn_fused_gru_cell - func: _thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) use_c10_dispatcher: full dispatch: CUDA: _thnn_fused_gru_cell_backward_cuda + Checkpoint: checkpoint__thnn_fused_gru_cell_backward - func: _thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) use_c10_dispatcher: full @@ -4350,6 +4681,8 @@ dispatch: CPU: masked_fill__cpu CUDA: masked_fill__cuda + Checkpoint: checkpoint_masked_fill_ + supports_named_tensor: True - func: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor use_c10_dispatcher: full @@ -4361,6 +4694,8 @@ dispatch: CPU: masked_fill__cpu CUDA: masked_fill__cuda + Checkpoint: checkpoint_masked_fill_ + supports_named_tensor: True - func: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor use_c10_dispatcher: full @@ -4384,6 +4719,7 @@ dispatch: CPU, CUDA, QuantizedCPU, QuantizedCUDA: view MkldnnCPU: mkldnn_view + Checkpoint: checkpoint_view - func: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) use_c10_dispatcher: full @@ -4542,11 +4878,13 @@ variants: function dispatch: CPU, CUDA: bitwise_and_out + Checkpoint: checkpoint_bitwise_and_out - func: bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: CPU, CUDA: bitwise_and_out + Checkpoint: checkpoint_bitwise_and_out - func: bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor use_c10_dispatcher: full @@ -4826,6 +5164,10 @@ - func: addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) use_c10_dispatcher: full variants: method + supports_named_tensor: True + dispatch: + CPU, CUDA: addcdiv_ + Checkpoint: checkpoint_addcdiv_ - func: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!) variants: method @@ -4857,10 +5199,14 @@ dispatch: CPU: diag_cpu_out CUDA: diag_cuda_out + Checkpoint: checkpoint_diag_out - func: diag(Tensor self, int diagonal=0) -> Tensor use_c10_dispatcher: full variants: method, function + dispatch: + CPU, CUDA: diag + Checkpoint: checkpoint_diag - func: cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) 
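Each `Checkpoint: checkpoint_<op>` entry in the hunks above and below routes that operator to a wrapper this patch defines in aten/src/ATen/native/Checkpoint.cpp. Those wrappers are not shown in this part of the diff; as a rough sketch only (the name `checkpoint_abs_sketch` and its body are illustrative assumptions, not the patch's actual implementation), a Checkpoint-backend entry point has roughly this shape:

    // Hypothetical sketch only -- the real wrappers live in
    // aten/src/ATen/native/Checkpoint.cpp and go through CheckpointTensorImpl.
    #include <ATen/ATen.h>

    at::Tensor checkpoint_abs_sketch(const at::Tensor& self) {
      // decheckpoint() is added by this patch (see torch/csrc/utils/tensor_numpy.cpp);
      // it returns the underlying, non-checkpointed tensor.
      at::Tensor base = self.decheckpoint();
      // Re-dispatch to the regular CPU/CUDA kernel on the unwrapped tensor.
      at::Tensor result = at::abs(base);
      // A real wrapper would additionally wrap `result` back into a
      // CheckpointTensorImpl together with a rematerializer closure so it can
      // be evicted and recomputed later; that machinery is omitted here.
      return result;
    }

Because the wrapper is registered under the Checkpoint dispatch key, it only runs for tensors that carry that key, i.e. tensors created via .checkpoint() as in test.py below.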
@@ -4909,6 +5255,7 @@ dispatch: CPU, CUDA: ne_out QuantizedCPU: ne_out_quantized_cpu + Checkpoint: checkpoint_ne_Scalar_out - func: ne.Scalar(Tensor self, Scalar other) -> Tensor use_c10_dispatcher: full @@ -4916,11 +5263,13 @@ dispatch: CPU, CUDA: ne QuantizedCPU: ne_quantized_cpu + Checkpoint: checkpoint_ne_Scalar - func: ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU, CUDA: ne_out QuantizedCPU: ne_out_quantized_cpu + Checkpoint: checkpoint_ne_Tensor_out - func: ne.Tensor(Tensor self, Tensor other) -> Tensor use_c10_dispatcher: full @@ -4928,11 +5277,13 @@ dispatch: CPU, CUDA: ne QuantizedCPU: ne_quantized_cpu + Checkpoint: checkpoint_ne_Tensor - func: eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU, CUDA: eq_out QuantizedCPU: eq_out_quantized_cpu + Checkpoint: checkpoint_eq_Scalar_out - func: eq.Scalar(Tensor self, Scalar other) -> Tensor use_c10_dispatcher: full @@ -4940,11 +5291,13 @@ dispatch: CPU, CUDA: eq QuantizedCPU: eq_quantized_cpu + Checkpoint: checkpoint_eq_Scalar - func: eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU, CUDA: eq_out QuantizedCPU: eq_out_quantized_cpu + Checkpoint: checkpoint_eq_Tensor_out - func: eq.Tensor(Tensor self, Tensor other) -> Tensor use_c10_dispatcher: full @@ -4952,6 +5305,7 @@ dispatch: CPU, CUDA: eq QuantizedCPU: eq_quantized_cpu + Checkpoint: checkpoint_eq_Tensor - func: ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -5029,6 +5383,7 @@ dispatch: CPU, CUDA: lt_out QuantizedCPU: lt_out_quantized_cpu + Checkpoint: checkpoint_lt_out - func: lt.Scalar(Tensor self, Scalar other) -> Tensor use_c10_dispatcher: full @@ -5036,11 +5391,13 @@ dispatch: CPU, CUDA: lt QuantizedCPU: lt_quantized_cpu + Checkpoint: checkpoint_lt - func: lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU, CUDA: lt_out QuantizedCPU: lt_out_quantized_cpu + Checkpoint: checkpoint_lt_out - func: lt.Tensor(Tensor self, Tensor other) -> Tensor use_c10_dispatcher: full @@ -5048,6 +5405,7 @@ dispatch: CPU, CUDA: lt QuantizedCPU: lt_quantized_cpu + Checkpoint: checkpoint_lt - func: take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -5084,6 +5442,8 @@ dispatch: CPU: masked_select_out_cpu CUDA: masked_select_out_cuda + Checkpoint: checkpoint_masked_select_out + supports_named_tensor: True - func: masked_select(Tensor self, Tensor mask) -> Tensor use_c10_dispatcher: full @@ -5091,11 +5451,14 @@ dispatch: CPU: masked_select_cpu CUDA: masked_select_cuda + Checkpoint: checkpoint_masked_select + supports_named_tensor: True - func: nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: legacy::cpu::_th_nonzero_out CUDA: legacy::cuda::_th_nonzero_out + Checkpoint: checkpoint_nonzero_out - func: nonzero(Tensor self) -> Tensor use_c10_dispatcher: full @@ -5103,6 +5466,7 @@ dispatch: CPU: legacy::cpu::_th_nonzero CUDA: legacy::cuda::_th_nonzero + Checkpoint: checkpoint_nonzero - func: nonzero_numpy(Tensor self) -> Tensor[] use_c10_dispatcher: full @@ -5136,12 +5500,20 @@ - func: addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) use_c10_dispatcher: full variants: method + supports_named_tensor: True + dispatch: + CPU, CUDA: addcmul_ + Checkpoint: checkpoint_addcmul_ - func: addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) 
- func: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor use_c10_dispatcher: full variants: method, function + supports_named_tensor: True + dispatch: + CPU, CUDA: addcdiv + Checkpoint: checkpoint_addcdiv - func: lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR) dispatch: @@ -5389,6 +5761,10 @@ - func: sign(Tensor self) -> Tensor use_c10_dispatcher: full variants: function, method + supports_named_tensor: True + dispatch: + CPU, CUDA: sign + Checkpoint: checkpoint_sign - func: sign_(Tensor(a!) self) -> Tensor(a!) use_c10_dispatcher: full @@ -5397,6 +5773,7 @@ - func: sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU, CUDA: sign_out + Checkpoint: checkpoint_sign_out - func: signbit(Tensor self) -> Tensor use_c10_dispatcher: full @@ -5612,6 +5989,7 @@ dispatch: CPU: topk_out_cpu CUDA: legacy::cuda::_th_topk_out + Checkpoint: checkpoint_topk_values - func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) use_c10_dispatcher: full @@ -5619,6 +5997,7 @@ dispatch: CPU, CUDA: topk QuantizedCPU: topk_quantized_cpu + Checkpoint: checkpoint_topk - func: all(Tensor self) -> Tensor use_c10_dispatcher: full @@ -5630,6 +6009,7 @@ dispatch: CPU, CUDA: any SparseCPU, SparseCUDA: any_sparse + Checkpoint: checkpoint_any - func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -5664,6 +6044,8 @@ CPU: cpu_equal CUDA: cuda_equal QuantizedCPU: equal_quantized_cpu + Checkpoint: checkpoint_equal + supports_named_tensor: True - func: pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -5782,12 +6164,13 @@ CPU: _cat_cpu CUDA: cat_cuda QuantizedCPU: cat_quantized_cpu + Checkpoint: checkpoint__cat - func: _cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: _cat_out_cpu CUDA: cat_out_cuda - QuantizedCPU: cat_out_quantized_cpu + Checkpoint: checkpoint__cat_out - func: _foreach_add.Scalar(Tensor[] tensors, Scalar scalar) -> Tensor[] use_c10_dispatcher: full @@ -5956,6 +6339,7 @@ dispatch: CPU: nll_loss_forward_out_cpu CUDA: legacy::cuda::_thnn_nll_loss_forward_out + Checkpoint: checkpoint_nll_loss_forward_out - func: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight) use_c10_dispatcher: full @@ -5963,12 +6347,14 @@ dispatch: CPU: nll_loss_forward_cpu CUDA: legacy::cuda::_thnn_nll_loss_forward + Checkpoint: checkpoint_nll_loss_forward - func: nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: CPU: nll_loss_backward_out_cpu CUDA: legacy::cuda::_thnn_nll_loss_backward_out + Checkpoint: checkpoint_nll_loss_backward_grad_input - func: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight) -> Tensor use_c10_dispatcher: full @@ -5976,6 +6362,7 @@ dispatch: CPU: nll_loss_backward_cpu CUDA: legacy::cuda::_thnn_nll_loss_backward + Checkpoint: checkpoint_nll_loss_backward - func: nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, int ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) 
python_module: nn @@ -6127,10 +6514,15 @@ python_module: nn dispatch: CPU, CUDA: hardtanh_backward_out + Checkpoint: checkpoint_hardtanh_backward_out - func: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor use_c10_dispatcher: full python_module: nn + dispatch: + CPU: hardtanh_backward + CUDA: hardtanh_backward + Checkpoint: checkpoint_hardtanh_backward - func: hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> Tensor(a!) use_c10_dispatcher: full @@ -6386,6 +6778,7 @@ CPU: avg_pool2d_out_cpu CUDA: avg_pool2d_out_cuda MkldnnCPU: mkldnn_avg_pool2d_out + Checkpoint: checkpoint_avg_pool2d_out - func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor use_c10_dispatcher: full @@ -6395,12 +6788,14 @@ CUDA: avg_pool2d_cuda MkldnnCPU: mkldnn_avg_pool2d QuantizedCPU: avg_pool2d_quantized_cpu + Checkpoint: checkpoint_avg_pool2d - func: avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: CPU: avg_pool2d_backward_out_cpu CUDA: avg_pool2d_backward_out_cuda + Checkpoint: checkpoint_avg_pool2d_backward_grad_input - func: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor use_c10_dispatcher: full @@ -6408,6 +6803,7 @@ dispatch: CPU: avg_pool2d_backward_cpu CUDA: avg_pool2d_backward_cuda + Checkpoint: checkpoint_avg_pool2d_backward - func: avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -6500,6 +6896,7 @@ dispatch: CPU: max_pool2d_with_indices_out_cpu CUDA: max_pool2d_with_indices_out_cuda + Checkpoint: checkpoint_max_pool2d_with_indices_out # Return: (Tensor output, Tensor indices) - func: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) @@ -6508,12 +6905,15 @@ dispatch: CPU: max_pool2d_with_indices_cpu CUDA: max_pool2d_with_indices_cuda + Checkpoint: checkpoint_max_pool2d_with_indices + supports_named_tensor: True - func: max_pool2d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: CPU: max_pool2d_with_indices_backward_out_cpu CUDA: max_pool2d_with_indices_backward_out_cuda + Checkpoint: checkpoint_max_pool2d_with_indices_backward_grad_input - func: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor use_c10_dispatcher: full @@ -6521,6 +6921,7 @@ dispatch: CPU: max_pool2d_with_indices_backward_cpu CUDA: max_pool2d_with_indices_backward_cuda + Checkpoint: checkpoint_max_pool2d_with_indices_backward # Return: (Tensor output, Tensor indices) - func: max_pool3d_with_indices.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!)) @@ -6764,6 +7165,7 @@ dispatch: CPU: upsample_bilinear2d_out_cpu CUDA: upsample_bilinear2d_out_cuda + Checkpoint: checkpoint_upsample_bilinear2d_out - func: upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor use_c10_dispatcher: full @@ -6772,12 +7174,14 @@ CPU: upsample_bilinear2d_cpu CUDA: upsample_bilinear2d_cuda QuantizedCPU: upsample_bilinear2d_quantized_cpu + Checkpoint: checkpoint_upsample_bilinear2d - func: upsample_bilinear2d_backward.grad_input(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: CPU: upsample_bilinear2d_backward_out_cpu CUDA: upsample_bilinear2d_backward_out_cuda + Checkpoint: checkpoint_upsample_bilinear2d_backward_out - func: upsample_bilinear2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor use_c10_dispatcher: full @@ -6785,6 +7189,7 @@ dispatch: CPU: upsample_bilinear2d_backward_cpu CUDA: upsample_bilinear2d_backward_cuda + Checkpoint: checkpoint_upsample_bilinear2d_backward - func: upsample_bicubic2d.out(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -7023,10 +7428,15 @@ python_module: nn dispatch: CPU, CUDA: sigmoid_backward_out + Checkpoint: checkpoint_sigmoid_backward_out - func: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor use_c10_dispatcher: full python_module: nn + dispatch: + CPU: sigmoid_backward + CUDA: sigmoid_backward + Checkpoint: checkpoint_sigmoid_backward - func: logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn @@ -7041,10 +7451,15 @@ python_module: nn dispatch: CPU, CUDA: tanh_backward_out + Checkpoint: checkpoint_tanh_backward_out - func: tanh_backward(Tensor grad_output, Tensor output) -> Tensor use_c10_dispatcher: full python_module: nn + dispatch: + CPU: tanh_backward + CUDA: tanh_backward + Checkpoint: checkpoint_tanh_backward # What's a thnn_conv_ versus a slow_conv_? # @@ -7128,6 +7543,7 @@ dispatch: CPU: slow_conv2d_forward_out_cpu CUDA: legacy::cuda::_thnn_conv2d_forward_out + Checkpoint: checkpoint_thnn_conv2d_forward_out - func: thnn_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> (Tensor output, Tensor finput, Tensor fgrad_input) use_c10_dispatcher: full @@ -7135,12 +7551,14 @@ dispatch: CPU: slow_conv2d_forward_cpu CUDA: legacy::cuda::_thnn_conv2d_forward + Checkpoint: checkpoint_thnn_conv2d_forward - func: thnn_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight, Tensor(c!)? 
grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) python_module: nn dispatch: CPU: slow_conv2d_backward_out_cpu CUDA: slow_conv2d_backward_out_cuda + Checkpoint: checkpoint_thnn_conv2d_backward_out - func: thnn_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) use_c10_dispatcher: full @@ -7148,6 +7566,7 @@ dispatch: CPU: slow_conv2d_backward_cpu CUDA: slow_conv2d_backward_cuda + Checkpoint: checkpoint_thnn_conv2d_backward - func: thnn_conv_depthwise2d.out(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) python_module: nn diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index 44f0c9561a4..98320dd6387 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -310,6 +310,9 @@ class CAFFE2_API Tensor { /// Returns a `Tensor`'s device. Device device() const; + /// Returns a `Tensor`'s device, if it has one. + c10::optional<Device> optional_device() const; + /// Returns a `Tensor`'s device index. int64_t get_device() const; diff --git a/aten/src/ATen/templates/TensorMethods.cpp b/aten/src/ATen/templates/TensorMethods.cpp index 064f5911cb1..aed046b8b63 100644 --- a/aten/src/ATen/templates/TensorMethods.cpp +++ b/aten/src/ATen/templates/TensorMethods.cpp @@ -60,6 +60,10 @@ Device Tensor::device() const { return impl_->device(); } +c10::optional<Device> Tensor::optional_device() const { + return impl_->optional_device(); +} + int64_t Tensor::get_device() const { // NB: this is not a native function to avoid dispatching overhead.
return impl_->get_device(); diff --git a/c10/core/Backend.h b/c10/core/Backend.h index ab0c45ad1fc..6afb9ce6a85 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -39,11 +39,48 @@ enum class Backend { Vulkan, QuantizedCPU, QuantizedCUDA, + Checkpoint, Undefined, MkldnnCPU, NumOptions }; +// TODO: This probably shouldn't actually be static inline +static inline const char* toString(Backend b) { + switch (b) { + case Backend::CPU: + return "CPU"; + case Backend::CUDA: + return "CUDA"; + case Backend::HIP: + return "HIP"; + case Backend::FPGA: + return "FPGA"; + case Backend::MSNPU: + return "MSNPU"; + case Backend::XLA: + return "XLA"; + case Backend::SparseCPU: + return "SparseCPU"; + case Backend::SparseCUDA: + return "SparseCUDA"; + case Backend::SparseHIP: + return "SparseHIP"; + case Backend::MkldnnCPU: + return "MkldnnCPU"; + case Backend::Vulkan: + return "Vulkan"; + case Backend::QuantizedCPU: + return "QuantizedCPU"; + case Backend::QuantizedCUDA: + return "QuantizedCUDA"; + case Backend::Checkpoint: + return "Checkpoint"; + default: + return "UNKNOWN_BACKEND"; + } +} + static inline Backend toSparse(Backend b) { switch (b) { case Backend::CPU: @@ -59,7 +96,7 @@ static inline Backend toSparse(Backend b) { case Backend::SparseHIP: return Backend::SparseHIP; default: - throw std::runtime_error("Unknown backend"); + throw std::runtime_error(std::string("Unknown backend: ") + toString(b)); } } @@ -88,7 +125,7 @@ static inline Backend toDense(Backend b) { case Backend::QuantizedCUDA: return Backend::QuantizedCUDA; default: - throw std::runtime_error("Unknown backend"); + throw std::runtime_error(std::string("Unknown backend: ") + toString(b)); } } @@ -121,6 +158,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) { return Backend::QuantizedCUDA; } else if (t == DispatchKey::Undefined) { return Backend::Undefined; + } else if (t == DispatchKey::Checkpoint) { + return Backend::Checkpoint; } else { AT_ERROR("Unrecognized tensor type ID: ", t); } @@ -154,10 +193,12 @@ static inline DispatchKey backendToDispatchKey(Backend b) { return DispatchKey::QuantizedCPU; case Backend::QuantizedCUDA: return DispatchKey::QuantizedCUDA; + case Backend::Checkpoint: + return DispatchKey::Checkpoint; case Backend::Undefined: return DispatchKey::Undefined; default: - throw std::runtime_error("Unknown backend"); + throw std::runtime_error(std::string("Unknown backend: ") + toString(b)); } } @@ -191,7 +232,7 @@ static inline DeviceType backendToDeviceType(Backend b) { case Backend::Undefined: AT_ERROR("Undefined backend is not a valid device type"); default: - AT_ERROR("Unknown backend"); + AT_ERROR(std::string("Unknown backend: ") + toString(b)); } } @@ -223,7 +264,7 @@ static inline Backend backendToCPU(Backend b) { case Backend::Undefined: return Backend::Undefined; default: - AT_ERROR("Unknown backend"); + AT_ERROR(std::string("Unknown backend: ") + toString(b)); } } @@ -243,7 +284,7 @@ static inline Backend backendToCUDA(Backend b) { case Backend::Undefined: return Backend::Undefined; default: - AT_ERROR("Unknown backend"); + AT_ERROR(std::string("Unknown backend: ") + toString(b)); } } @@ -267,40 +308,6 @@ static inline Backend backendToHIP(Backend b) { } } -// TODO: This probably shouldn't actually be static inline -static inline const char* toString(Backend b) { - switch (b) { - case Backend::CPU: - return "CPU"; - case Backend::CUDA: - return "CUDA"; - case Backend::HIP: - return "HIP"; - case Backend::FPGA: - return "FPGA"; - case Backend::MSNPU: - return "MSNPU"; - case 
Backend::XLA: - return "XLA"; - case Backend::SparseCPU: - return "SparseCPU"; - case Backend::SparseCUDA: - return "SparseCUDA"; - case Backend::SparseHIP: - return "SparseHIP"; - case Backend::MkldnnCPU: - return "MkldnnCPU"; - case Backend::Vulkan: - return "Vulkan"; - case Backend::QuantizedCPU: - return "QuantizedCPU"; - case Backend::QuantizedCUDA: - return "QuantizedCUDA"; - default: - return "UNKNOWN_BACKEND"; - } -} - static inline bool isSparse(Backend b) { switch (b) { case Backend::SparseCPU: diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index be8efbbe09e..ab2a19f75c0 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -79,6 +79,8 @@ const char* toString(DispatchKey t) { return "AutogradPrivateUse3"; case DispatchKey::AutogradOther: return "AutogradOther"; + case DispatchKey::Checkpoint: + return "Checkpoint"; case DispatchKey::BackendSelect: return "BackendSelect"; case DispatchKey::Named: diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index b0c75a261c0..8c2600be6cf 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -222,6 +222,9 @@ enum class DispatchKey : uint8_t { AutogradPrivateUse2, AutogradPrivateUse3, + // Checkpoint must go after Autograd, so that Autograd attaches its AD hooks outside of CheckpointTensor. + Checkpoint, + Tracer, // Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 80917e5745a..c792dc8a192 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -493,6 +493,10 @@ return *device_opt_; } + c10::optional<Device> optional_device() const { + return device_opt_; + } + Layout layout() const { // NB: This method is not virtual and avoid dispatches for perf. if (is_sparse()) { @@ -910,6 +914,12 @@ * compatible with SparseCUDA.
*/ inline bool has_compatible_shallow_copy_type(DispatchKeySet from) { + if (key_set_ == from) { + return true; + } + if (key_set_.has(DispatchKey::Checkpoint) || from.has(DispatchKey::Checkpoint)) { + return false; + } auto is_dense = [](DispatchKeySet ts) { return ts.has(DispatchKey::CPU) || ts.has(DispatchKey::CUDA) || @@ -920,7 +930,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { ts.has(DispatchKey::SparseCUDA) || ts.has(DispatchKey::SparseHIP); }; - return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) || (is_sparse(key_set_) && is_sparse(from)); + return (is_dense(key_set_) && is_dense(from)) || (is_sparse(key_set_) && is_sparse(from)); } /** diff --git a/test.py b/test.py new file mode 100644 index 00000000000..9a6ee104db0 --- /dev/null +++ b/test.py @@ -0,0 +1,4 @@ +import torch +x = torch.Tensor([1]).checkpoint() +y = x +z = y diff --git a/third_party/json b/third_party/json new file mode 160000 index 00000000000..19843b038ca --- /dev/null +++ b/third_party/json @@ -0,0 +1 @@ +Subproject commit 19843b038caa463164d6f89ea1b2765fae7552e9 diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp index 4109d1e7998..1cda83ce2ff 100644 --- a/tools/autograd/templates/Functions.cpp +++ b/tools/autograd/templates/Functions.cpp @@ -769,18 +769,6 @@ Tensor index_select_backward(Tensor grad, int64_t dim, Tensor indices, IntArrayR return at::zeros(sizes, grad.options()).scatter_(dim, indices, grad); } -Tensor slice_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { - auto grad_input = at::zeros(input_sizes, grad.options()); - grad_input.slice(dim, start, end, step).copy_(grad); - return grad_input; -} - -Tensor select_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { - auto grad_input = at::zeros(input_sizes, grad.options()); - grad_input.select(dim, index).copy_(grad); - return grad_input; -} - Tensor trace_backward(const Tensor & grad, IntArrayRef sizes) { if (sizes.size() != 2) { throw std::runtime_error("expected matrix input"); @@ -901,38 +889,6 @@ Tensor cholesky_inverse_backward(Tensor grad, Tensor L, bool upper, Tensor inver return grad_L; } -Tensor split_with_sizes_backward(const std::vector &grads, - IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { - dim = at::maybe_wrap_dim(dim, sizes.size()); - - // it's possible some of the grads are not defined (represents tensors of all 0s). 
- // Since at::cat can't handle those, let's define them - std::vector grads_all_defined(grads.size()); - for (size_t j = 0; j < grads.size(); ++j) { - if (grads[j].defined()) { - grads_all_defined[j] = grads[j]; - } else { - auto length = split_sizes[j]; - auto grad_size = sizes.vec(); - grad_size[dim] = length; - grads_all_defined[j] = at::zeros(grad_size, options); - } - } - - auto ret = at::cat(grads_all_defined, dim); - return ret; -} - -Tensor split_backward(const std::vector &grads, - int64_t split_size, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { - dim = at::maybe_wrap_dim(dim, sizes.size()); - int64_t dim_size = sizes[dim]; - int64_t num_splits = grads.size(); - std::vector split_sizes(num_splits, split_size); - split_sizes[num_splits - 1] = split_size - (split_size * num_splits - dim_size); - return split_with_sizes_backward(grads, split_sizes, dim, sizes, options); -} - Tensor max_pool_double_backward(const Tensor & grad, const Tensor & indices, int dim) { AT_ASSERT(indices.dim() >= dim); auto size = indices.sizes().slice(0, indices.dim() - dim).vec(); diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index 879618f652b..7b96cd41044 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -73,7 +73,8 @@ static std::vector seq_to_aten_shape(PyObject *py_seq) { return result; } -PyObject* tensor_to_numpy(const at::Tensor& tensor) { +PyObject* tensor_to_numpy(const at::Tensor& tensor_) { + Tensor tensor = tensor_.decheckpoint(); if (tensor.device().type() != DeviceType::CPU) { throw TypeError( "can't convert %s device type tensor to numpy. Use Tensor.cpu() to " diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 916fab5e6ee..f7e8385e863 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -4082,9 +4082,10 @@ def multi_head_attention_forward(query, # type: Tensor q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) if k is not None: - k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + # extremely bizarre, but got errors of dimension mismatch here and below unless I changed view -> reshape + k = k.contiguous().reshape(-1, bsz * num_heads, head_dim).transpose(0, 1) if v is not None: - v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + v = v.contiguous().reshape(-1, bsz * num_heads, head_dim).transpose(0, 1) if static_k is not None: assert static_k.size(0) == bsz * num_heads @@ -4135,7 +4136,8 @@ def multi_head_attention_forward(query, # type: Tensor attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] - attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + # also had to change view to reshape + attn_output = attn_output.transpose(0, 1).contiguous().reshape(tgt_len, bsz, embed_dim) attn_output = linear(attn_output, out_proj_weight, out_proj_bias) if need_weights:
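The Backend and DispatchKey changes above (c10/core/Backend.h, c10/core/DispatchKey.cpp, c10/core/DispatchKey.h) can be exercised on their own. A minimal sketch, assuming a PyTorch build with this patch applied; it simply round-trips the new Checkpoint entries through the conversion helpers shown in the diff:

    // Assumes a PyTorch build with this patch applied.
    #include <c10/core/Backend.h>
    #include <iostream>

    int main() {
      // DispatchKey::Checkpoint <-> Backend::Checkpoint are added by this patch.
      c10::Backend b = c10::dispatchKeyToBackend(c10::DispatchKey::Checkpoint);
      std::cout << c10::toString(b) << std::endl;                            // "Checkpoint"
      std::cout << c10::toString(c10::backendToDispatchKey(b)) << std::endl; // "Checkpoint"
      return 0;
    }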