diff --git a/.gitmodules b/.gitmodules index 3ae80c83792..6be1789ef76 100644 --- a/.gitmodules +++ b/.gitmodules @@ -122,3 +122,6 @@ ignore = dirty path = third_party/XNNPACK url = https://github.com/google/XNNPACK.git +[submodule "third_party/json"] + path = third_party/json + url = git@github.com:nlohmann/json.git diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp new file mode 100644 index 00000000000..72ad9e63fda --- /dev/null +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -0,0 +1,658 @@ +#include +#include +#include + +#include +#include +#include +#include + +namespace at { + +using Clock = std::chrono::high_resolution_clock; +using Time = Clock::time_point; +using Duration = Clock::duration; +using FinalTime = std::chrono::nanoseconds; + +struct PerfStats; + +struct Timer { + std::string name; + Time start; + Timer(std::string name, Time start) : name(name), start(start) {} + Timer() {} + ~Timer(); +}; + +constexpr bool stats = true; + +struct PerfStats { + using TimerStats = std::tuple; + Time start; + std::unordered_map calls; + std::vector timers; + + PerfStats() : start(Clock::now()), calls(0), timers() {} + + /*Timer track(std::string name) { + if (stats) { + auto it = this->calls.find(name); + if (it != this->calls.end()) { + it->second += 1; + } else { + this->calls.insert({name, 0}); + } + + return Timer(name, Clock::now()); + } + return Timer(); + }*/ + void track(const char*) { } + + ~PerfStats() { + if (!stats) { return; } + auto start = std::get<1>(this->timers[0]); + auto now = Clock::now(); + std::cout << "All done. Here are some perf stats fresh off the preses." << std::endl; + std::unordered_map durations; + + Duration total = now - this->start; + + // For now simple strategy, count up all the time taken + // by each "tagged call site". + for (auto timer : timers) { + auto name = std::get<0>(timer); + Duration duration = std::get<3>(timer); + auto it = durations.find(name); + if (it != durations.end()) { + it->second += duration; + } else { + durations.insert({name, duration}); + } + } + + std::vector> data; + + // Convert the durations + for (auto d : durations) { + // auto duration = std::chrono::duration_cast(d.second); + data.push_back(d); + } + + std::sort(data.begin(), data.end(), + [](const std::pair & a, const std::pair & b) -> bool { + return a.second > b.second; + }); + + for (auto d : data) { + auto duration = std::chrono::duration_cast(d.second); + auto total_duration = std::chrono::duration_cast(total); + double percentage = ((double)duration.count())/((double)total_duration.count()) * 100; + auto call_count = this->calls.find(d.first); + TORCH_CHECK(call_count != this->calls.end()); + std::cout << "CallSite: " << d.first << " CallCount: " << call_count->second << " Cost: " << duration.count() << "ns" << " (%" << percentage << ")" << std::endl; + } + } +}; + +static PerfStats STATS = PerfStats(); + +size_t memory_sum = 0; +size_t memory_max = 0; +size_t memory_count = 0; + +void reset_memory_stat() { + memory_sum = 0; + memory_max = 0; + memory_count = 0; +} + +inline size_t memory(const Tensor& t) { + if (! 
t.has_storage()) { + return 0; + } + auto& storage = t.storage(); + size_t res = storage.numel() * storage.itemsize(); + memory_sum += res; + memory_max = std::max(memory_max, res); + memory_count += 1; + return res; +} + +Timer::~Timer() { + Time now = Clock::now(); + Duration elapsed = now - start; + PerfStats::TimerStats stats = { name , start, now, elapsed }; + STATS.timers.push_back(stats); +} + +CheckpointPool pool; +void CheckpointPool::add(const intrusive_ptr& p) { + if (p->memory > 0 && (memory_count == 0 || p->memory >= 0.01 * double(memory_sum/memory_count))) { + aps.push_back(weak_intrusive_ptr(p)); + } +} + +long current_memory() { + STATS.track("current_memory"); + auto device_stat = c10::cuda::CUDACachingAllocator::getDeviceStats(0); + return device_stat.allocated_bytes[0].current; +} + +void CheckpointPool::auto_evict() { + STATS.track("CheckpointPool::auto_evict"); + if (has_memory_budget) { + while (current_memory() > memory_budget) { + evict(); + } + } +} + +void CheckpointPool::evict() { + STATS.track("CheckpointPool::evict"); + TORCH_CHECK(aps.size() > 0); + bool shrinked = false; + int evict_idx = -1; + double evict_cost = INFINITY; + time_t current_time = std::chrono::system_clock::now(); + auto remove_from_aps = [&](size_t i) { + aps[i] = aps[aps.size() - 1]; + aps.pop_back(); + }; + std::uniform_int_distribution<> distrib(1, 1 * std::max(1, static_cast(std::sqrt(aps.size())))); + // sampling a random independent subset of all evictable tensors to find the cheapest tensor to evict. + for (size_t i = 0; i < aps.size();) { + auto cannot_evict = [&]() { + shrinked = true; + remove_from_aps(i); + }; + auto ap_strong = aps[i].lock(); + if (!ap_strong.defined()) { + cannot_evict(); + } + else if (ap_strong->ecn) { + cannot_evict(); + } + else { + if (ap_strong->evictable()) { + double cost = ap_strong->cost(current_time); + if (cost < evict_cost) { + evict_cost = cost; + evict_idx = i; + } + } + i += distrib(gen); + } + } + if (evict_idx == -1) { + TORCH_CHECK(shrinked); + } else { + auto evict_from_idx = [&](size_t idx) { + auto ap_strong = aps[idx].lock(); + TORCH_CHECK(ap_strong.defined()); + ap_strong->evict(); + remove_from_aps(evict_idx); + }; + evict_from_idx(evict_idx); + } +} + +CheckpointPool::CheckpointPool() { } + +bool use_log = false; +long compute_time_ = 0; + +namespace native { + +Tensor checkpoint(const Tensor& t) { + STATS.track("checkpoint"); + auto cpti = intrusive_ptr::make(t.detach()); + if (use_log) { + DTRLogConstant(cpti->counter_name()); + DTRLogMemory(cpti->counter_name(), cpti->ref->value->value->memory()); + } + return Tensor(cpti); +} + +Tensor uncheckpoint(const Tensor& t) { + STATS.track("uncheckpoint"); + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + TORCH_CHECK(cpti != nullptr); + return cpti->ref->value->value->get(); +} + +void pin(const Tensor& t) { + STATS.track("pin"); + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + TORCH_CHECK(cpti != nullptr); + cpti->ref->value->value->pin(); +} + +Tensor decheckpoint(const Tensor& t) { + STATS.track("decheckpoint"); + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + return cpti ? cpti->ref->value->value->get() : t; +} + +bool is_checkpoint(const Tensor& t) { + STATS.track("is_checkpoint"); + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + return cpti != nullptr; +} + +Tensor try_checkpoint(const Tensor& t) { + STATS.track("try_checkpiont"); + return is_checkpoint(t) ? 
t : checkpoint(t); +} + +void new_log(std::string str) { + DTRLogger::logger().out = std::ofstream(DTRLogger::logger().get_filename(str)); +} + +void annotate_log(std::string str) { + if (!use_log) { return; } + if (log_json) { + json j; + j[INSTRUCTION] = "ANNOTATE"; + j[ANNOTATION] = str; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log("# " + str); + } +} + +void toggle_log(bool b) { + use_log = b; +} + +void clear_checkpointpool() { + while (likely(!pool.exts.empty())) { + if (auto e = pool.exts.back().lock()) { + e->value->pin(); + } + pool.exts.pop_back(); + } +} + +void unset_memory_budget() { + pool.has_memory_budget = false; +} + +void set_memory_budget(long budget) { + pool.memory_budget = budget; + pool.has_memory_budget = true; +} + +void reset_compute_time() { + compute_time_ = 0; +} + +long compute_time() { + return compute_time_; +} + +} + +[[inline]] +Tensor uncheckpoint(const strong& input) { + return input->get(); +} + +Tensors uncheckpoint(const strongs& inputs) { + STATS.track("uncheckpoint"); + Tensors ret; + ret.reserve(inputs.size()); + for (const strong& input : inputs) { + // inlined manually + ret.push_back(input->get()); + } + return ret; +}; + +Tensors try_checkpoint(const Tensors& inputs) { + STATS.track("try_checkpoint"); + Tensors ret; + ret.reserve(inputs.size()); + for (const Tensor& input : inputs) { + ret.push_back(at::native::try_checkpoint(input)); + } + return ret; +} + +CheckpointInfo merge_cpi(CheckpointInfo l, CheckpointInfo r) { + STATS.track("merge_cpi"); + return CheckpointInfo(l.compute_cost + r.compute_cost); +} + +void AliasPool::evict() { + STATS.track("AliasPool::evict"); + TORCH_CHECK(!ecn); + ecn = head_remat->get_ecn(); + auto ecns = neighbor_ecn(); + for (const auto& necn : ecns) { + merge(merge_cpi, ecn, necn); + } + auto b4 = current_memory(); + TORCH_CHECK(memory > 0); + TORCH_CHECK(lock_count == 0); + TORCH_CHECK(!is_evicted); + is_evicted = true; + for (const weak& w : tensors) { + if (auto cell = w.lock()) { + cell->evict(); + } + } + TORCH_CHECK(current_memory() < b4); + // somehow it is still evicting unevictable stuff. +} + +double AliasPool::cost(time_t current_time) { + auto cpi = head_remat->get_cpi(); + auto ecns = neighbor_ecn(); + for (const auto& necn : ecns) { + cpi = merge_cpi(cpi, get_t(necn)); + } + return cpi.cost(memory, (current_time - last_used_time).count()); +} + +void External::release_resources() { + value->pool->release_external(); + value.reset(); +} + +void Rematerializer::remat() { + // TODO: refactor using RAII for exception safety. + for (const strong& s : inputs) { + s->pool->lock(); + } + Tensors ts = uncheckpoint(inputs); + time_t pre = std::chrono::system_clock::now(); + auto ret = func(ts); + time_t post = std::chrono::system_clock::now(); + pool.auto_evict(); + compute_time_ += (post - pre).count(); + TORCH_CHECK(ret.size() == outputs.size()); + for (size_t i = 0; i < outputs.size(); ++i) { + if (auto output_cell = outputs[i].lock()) { + output_cell->fill(ret[i]); + } + } + ecn.reset(); + for (const strong& s : inputs) { + s->pool->unlock(); + } +} + +ecn_ptr Rematerializer::get_ecn() { + if (!ecn) { + ecn = ecn_ptr::make(CheckpointInfo(compute_cost)); + } + return ecn; +} + +CheckpointInfo Rematerializer::get_cpi() { + return CheckpointInfo(ecn ? 
duration_t(0) : compute_cost); +} + +std::set AliasPool::neighbor_ecn() { + STATS.track("AliasPool::neighbor_ecn"); + std::set ptr_set; + int size = neighbors.size(); + for (size_t i = 0; i < size;) { + if (auto cptc = neighbors[i].lock()) { + if (cptc->pool->ecn) { + ptr_set.insert(cptc->pool->ecn); + } + ++i; + } else { + neighbors[i] = neighbors[size - 1]; + size = size - 1; + } + } + if (size < neighbors.size()) { + neighbors.erase(neighbors.begin() + size); + } + return ptr_set; +} + +void AliasPool::set_not_evicted(const intrusive_ptr& self) { + if (unlikely(is_evicted)) { + STATS.track("AliasPool::set_not_evicted(inside)"); + is_evicted = false; + if (ecn) { + TORCH_CHECK(head_remat); + auto cpi = get_t(ecn); + update_t(ecn, CheckpointInfo(cpi.compute_cost - head_remat->compute_cost)); + ecn.reset(); + } + pool.add(self); + } +} + +void CheckpointTensorCell::fill(const Tensor& t) { + STATS.track("CheckpointTensorCell::fill"); + if (!(this->t)) { + this->t = std::make_unique(t.detach()); + pool->set_not_evicted(pool); + if (!defined) { + defined = true; + is_undefined_tensor = !t.defined(); + key_set_ = t.key_set(); + dtype_ = t.dtype(); + optional_device_ = t.optional_device(); + } + } +} + +intrusive_ptr CheckpointTensorImpl::shallow_copy_and_detach(const VariableVersion& version_counter, + bool allow_tensor_metadata_change) const { + auto ret = intrusive_ptr::make(ref); + if (use_log) { + DTRLogCopy(ret->counter_name(), counter_name()); + } + return ret; +} + +void CheckpointTensorImpl::shallow_copy_from(const c10::intrusive_ptr& impl) { + STATS.track("CheckpointTensorCell::shallow_copy_from"); + TORCH_CHECK(impl->key_set().has(DispatchKey::CheckpointTensorId)); + auto* cpti = dynamic_cast(impl.get()); + TORCH_CHECK(cpti != nullptr); + ref->value = cpti->ref->value; + if (use_log) { + DTRLogCopyFrom(counter_name(), cpti->counter_name()); + } +} + +int CheckpointTensorImpl::counter = 0; + +bool is_alias(const Tensor& l, const Tensor& r) { + return l.defined() && r.defined() && l.is_alias_of(r); +} + +// return an index for alias. +// we dont care which one because they all lead to the same alias pool. +// return -1 for no alias. +// may god forgive my sin. +int get_alias(const Tensors& ts, const Tensor& t) { + if (t.defined()) { + for (size_t i = 0; i < ts.size(); ++i) { + if (ts[i].defined() && t.is_alias_of(ts[i])) { + return i; + } + } + } + return -1; +} + +struct MakeRawResult { + std::vector> outputs; + std::vector aliases; + duration_t time; + intrusive_ptr rematerializer; +}; + +void add_neighbor(const strong& l, const strong& r) { + l->pool->neighbors.push_back(weak(r)); + r->pool->neighbors.push_back(weak(l)); +} + +// remat take a single vector of tensors, +// while there are two vector, one storing nonconstants and one storing constants. +// the constants are small and they will not be considered for eviction. +// however, we have to stitch the two vectors together to pass it in remat. +// the size_t in constants decide the location to stitch them in, while input_values fill in the rest. 
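+// In this version make_raw takes a single vector of checkpointed inputs. It locks the
+// inputs' AliasPools (so they cannot be evicted mid-call), uncheckpoints them to raw
+// tensors, runs the rematerialize function while timing it, and calls auto_evict to stay
+// under the memory budget. Every raw output is then wrapped in an External holding a
+// CheckpointTensorCell: outputs that alias an input reuse that input's AliasPool (adding
+// the op's cost to its head rematerializer, if any), while fresh outputs get a new
+// AliasPool registered with the global CheckpointPool. Finally, input/output pairs that
+// do not alias each other are recorded as neighbors for the eviction heuristic.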
+MakeRawResult make_raw(const rematerialize_function_t& remat_f, + const strongs& inputs) { + STATS.track("make_raw"); + for (const strong& s : inputs) { + s->pool->lock(); + } + Tensors raw_inputs = uncheckpoint(inputs); + time_t pre = std::chrono::system_clock::now(); + auto raw_outputs = remat_f(raw_inputs); + time_t post = std::chrono::system_clock::now(); + pool.auto_evict(); + compute_time_ += (post - pre).count(); + std::vector> outputs; + std::vector aliases; + weaks weak_outputs; + auto remat = intrusive_ptr::make(Unsafe(), remat_f, inputs, post - pre); + + for (const Tensor& t : raw_outputs) { + intrusive_ptr alias_pool; + int alias = get_alias(raw_inputs, t); + if (alias == -1) { + auto m = memory(t); + alias_pool = intrusive_ptr::make(Unsafe(), remat, m); + pool.add(alias_pool); + } + else { + alias_pool = inputs[alias]->pool; + if (alias_pool->head_remat) { + alias_pool->head_remat->compute_cost += (post - pre); + } + } + auto e = intrusive_ptr::make(t, alias_pool, remat); + pool.exts.push_back(weak_intrusive_ptr(e)); + alias_pool->tensors.push_back(weak(e->value)); + outputs.push_back(e); + aliases.push_back(alias); + weak_outputs.push_back(weak(outputs.back()->value)); + } + remat->outputs = weak_outputs; + for (size_t i = 0; i < inputs.size(); ++i) { + for (size_t j = 0; j < outputs.size(); ++j) { + if (!is_alias(raw_inputs[i], raw_outputs[j])) { + add_neighbor(inputs[i], outputs[j]->value); + } + } + } + for (const strong& s : inputs) { + s->pool->unlock(); + } + return {outputs, aliases, post - pre, remat}; +} + +std::string from_time(duration_t t) { + return std::to_string(std::chrono::nanoseconds(t).count()); +} + +Tensors CheckpointTensorImpl::make(const std::string& name, + const rematerialize_function_t& remat, + const Tensors& inputs) { + STATS.track("CheckPointTensorImpl::make"); + Tensors checkpointed_inputs = try_checkpoint(inputs); + auto input_size = checkpointed_inputs.size(); + + strongs input_values; + input_values.reserve(input_size); + + std::vector args; + args.reserve(input_size); + + for (const Tensor& t: checkpointed_inputs) { + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + TORCH_CHECK(cpti); + input_values.push_back(cpti->ref->value->value); + if (use_log) { + args.push_back(cpti->counter_name()); + } + } + + auto ret = make_raw(remat, input_values); + + Tensors tensors; + tensors.reserve(ret.outputs.size()); + + for (const auto& t: ret.outputs) { + auto cp = Tensor(intrusive_ptr::make(t)); + tensors.push_back(cp); + } + + if (use_log) { + std::vector res; + res.reserve(ret.outputs.size()); + + for (const auto& tensor : tensors) { + res.push_back(get_cpti(tensor)->counter_name()); + } + + DTRLogCall(res, name, args, from_time(ret.time)); + for (size_t i = 0; i < tensors.size(); ++i) { + Tensor t = tensors[i]; + auto cpti = get_cpti(t); + DTRLogMemory(cpti->counter_name(), cpti->ref->value->value->memory()); + DTRLogAlias(cpti->counter_name(), ret.aliases[i]); + } + } + + return tensors; +} + +// TODO: check that mutated value does not have alias. 
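+// Mutation is handled by copy-on-write: the recorded rematerializer clones every input
+// listed in mutate_idx, applies the in-place function to the clones, and returns them as
+// outputs, so the underlying CheckpointTensorCells stay purely functional. The Refs of
+// the mutated inputs are then rebound to the freshly produced cells.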
+void CheckpointTensorImpl::mutate(const std::string& name,
+                                  const mutate_function_t& mutate,
+                                  const Tensors& inputs,
+                                  const std::vector<size_t>& mutate_idx) {
+  auto remat = [=](const Tensors& t) -> Tensors {
+    Tensors new_input_values = t;
+    for (size_t idx: mutate_idx) {
+      new_input_values[idx] = t[idx].clone();
+    }
+    mutate(new_input_values);
+    return new_input_values;
+  };
+  Tensors checkpointed_inputs = try_checkpoint(inputs);
+  strongs input_values;
+  std::vector<std::string> args;
+  for (const Tensor& t: checkpointed_inputs) {
+    auto* cpti = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl());
+    TORCH_CHECK(cpti);
+    input_values.push_back(cpti->ref->value->value);
+    if (use_log) {
+      args.push_back(cpti->counter_name());
+    }
+  }
+  auto ret = make_raw(remat, input_values);
+  const auto& modified = ret.outputs;
+  for (size_t idx: mutate_idx) {
+    cell_from_tensor(inputs[idx])->value = modified[idx];
+  }
+  if (use_log) {
+    DTRLogMutate(name, args, mutate_idx, from_time(ret.time));
+  }
+}
+
+void CheckpointTensorImpl::release_resources() {
+  if (use_log) {
+    DTRLogRelease(counter_name());
+  }
+  ref.reset();
+}
+
+CheckpointTensorImpl::CheckpointTensorImpl(const Tensor& t) : CheckpointTensorImpl(intrusive_ptr<External>::make(t)) {
+  pool.exts.push_back(weak_intrusive_ptr<External>(ref->value));
+}
+
+}
diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h
new file mode 100644
index 00000000000..3e0f3e4d203
--- /dev/null
+++ b/aten/src/ATen/CheckpointTensorImpl.h
@@ -0,0 +1,456 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#define likely(x) __builtin_expect(!!(x), 1)
+#define unlikely(x) __builtin_expect(!!(x), 0)
+#define TORCH_CHECK(a, ...) // profile mode
+
+// System Description:
+// Every Tensor is managed by a CheckpointTensor,
+// which describes how it is computed (the function and the inputs)
+// and might optionally hold the tensor value.
+// The tensor value might be dropped, and when requested later, recomputed and cached again.
+
+// Corner Cases:
+// A CheckpointTensor might be constant.
+// In this case it is unevictable.
+// An operator might return multiple outputs.
+// In this case the computation info (rematerializer) is shared between all of them,
+// and when the function gets computed again all values get cached.
+// An operator might not return a value, but only mutate an input value.
+// To handle this, we COW the operator, and wrap the CheckpointTensor with a Ref.
+// By doing this the inner CheckpointTensor is kept purely functional.
+// An operator might try to mutate an uncheckpointed tensor.
+// We do not support this and will error.
+// An operator might create aliases.
+// We track aliases in AliasPool.
+// Each AliasPool holds a set of tensors that are aliases of each other.
+// An operator might try to create an alias to an unevictable tensor.
+// In such a case the output tensor is unevictable.
+// An operator might try to mutate a Tensor with aliases.
+// We do not support this case and will error if a Tensor has any live alias.
+// However, it could be done without a major redesign of the system -
+// each AliasPool would hold weak pointers to the External references.
+// When an alias mutation occurs,
+// we make a rematerialize_function that takes in the base tensor (the tensor the others alias from)
+// and outputs all the new values of the aliases, then updates the Refs.
+// Of course, the cleaner way is to not support this.
+// Shame on those who use this feature.
+
+// Memory Safety:
+// The objects here will have lots of back edges.
+// In order to collect memory when computation is completed,
+// we require that all strong pointers point from value to input.
+// This ensures that everything will be released if there is no external reference whatsoever.
+
+// Optimization:
+// We treat tensors that have no external reference differently -
+// they will never be used externally again, so we assume their next use time is infinite,
+// and if such a tensor does not have any evicted neighbor it will get evicted immediately.
+
+// Note: to code fast I do not use RAII and just assume the code will not try to recover from exceptions.
+// It should be easy to fix though.
+
+namespace at {
+
+// TODO: using a pool allocator might make more sense - no need to allocate and delete each pointer individually.
+template<typename T>
+struct EquivalentClassNode : intrusive_ptr_target {
+  explicit EquivalentClassNode(const T& t) : t_unsafe(t) { }
+  mutable intrusive_ptr<EquivalentClassNode> parent;
+  bool is_root() {
+    return !parent;
+  }
+  void release_resources() override {
+    parent.reset();
+  }
+  T t_unsafe;
+};
+
+template<typename T>
+T& get_t(const intrusive_ptr<EquivalentClassNode<T>>& n) {
+  return find_root(n)->t_unsafe;
+}
+
+template<typename T>
+static void update_t(const intrusive_ptr<EquivalentClassNode<T>>& n, const T& t) {
+  find_root(n)->t_unsafe = t;
+}
+
+template<typename T>
+intrusive_ptr<EquivalentClassNode<T>> find_root(const intrusive_ptr<EquivalentClassNode<T>>& n) {
+  if (n->is_root()) {
+    return n;
+  } else {
+    n->parent = find_root(n->parent);
+    return n->parent;
+  }
+}
+
+template<typename T>
+intrusive_ptr<EquivalentClassNode<T>> merge(const std::function<T(const T&, const T&)>& merge_t,
+                                            const intrusive_ptr<EquivalentClassNode<T>>& lhs,
+                                            const intrusive_ptr<EquivalentClassNode<T>>& rhs) {
+  auto l = find_root(lhs);
+  auto r = find_root(rhs);
+  if (l == r) {
+    return l;
+  }
+  l->parent = r;
+  r->t_unsafe = merge_t(l->t_unsafe, r->t_unsafe);
+  return r;
+}
+
+size_t memory(const Tensor& t);
+
+template<typename T>
+struct RefCell final : intrusive_ptr_target {
+  mutable T value;
+  void release_resources() final {
+    static_release_resources(value);
+  }
+  RefCell(const T& t) : value(t) { }
+};
+
+template<typename T>
+using Ref = intrusive_ptr<RefCell<T>>;
+
+template<typename T>
+void static_release_resources(intrusive_ptr<T>& ptr) {
+  ptr.reset();
+}
+
+class CheckpointTensorCell;
+using strong = intrusive_ptr<CheckpointTensorCell>;
+using strongs = std::vector<strong>;
+using weak = weak_intrusive_ptr<CheckpointTensorCell>;
+using weaks = std::vector<weak>;
+using Tensors = std::vector<Tensor>;
+using rematerialize_function_t = std::function<Tensors(const Tensors&)>;
+using mutate_function_t = std::function<void(const Tensors&)>;
+
+using time_t = std::chrono::time_point<std::chrono::system_clock>;
+using duration_t = std::chrono::system_clock::duration;
+struct CheckpointInfo {
+  duration_t compute_cost;
+  // @ZACH: Floating Point instability?
+  double cost(size_t memory, size_t staleness) const {
+    TORCH_CHECK(memory > 0);
+    TORCH_CHECK(staleness > 0);
+    return compute_cost.count() / static_cast<double>(memory * staleness);
+  }
+  CheckpointInfo(duration_t compute_cost) :
+    compute_cost(compute_cost) {
+  }
+};
+
+// An ecn represents an evicted tensor group:
+// a set of evicted tensors, where any two evicted tensors that are input -> output of each other
+// must belong to the same ecn.
+// Note: we try to support removal from an ecn by subtracting compute_cost and memory.
+// This will create spurious connections, but that should be fine empirically.
+// Below is an example of a spurious connection:
+// a -> b, a -> c
+// a, b, c got evicted, so they belong to a single ecn.
+// a got rematerialized.
+// b, c still belong to a single ecn although there is no connection.
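+// For intuition, a worked example of the eviction score with made-up numbers (not taken
+// from any real run): if a candidate storage occupies roughly 4e6 bytes, was last used
+// 2e6 ns ago, and recomputing it together with the evicted neighbors in its ecn would take
+// 3e6 ns, its score is cost = compute / (memory * staleness) = 3e6 / (4e6 * 2e6).
+// CheckpointPool::evict picks the sampled candidate with the smallest such score, so cheap,
+// large, and stale tensors are evicted first.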
+using ecn_ptr = intrusive_ptr>; + +struct Unsafe { }; + +// The rematerializer could be called to reinvoke an operator. +// Tensor point to remat which point to Tensor. +// To build the cycle remat support a default constructor, +// And allow you to fill in the member later. +struct Rematerializer : intrusive_ptr_target { + rematerialize_function_t func; + strongs inputs; + weaks outputs; + duration_t compute_cost; + // when some output in here get evicted, they should belong to this ecn. + // a rematerializer have to track this, + // because when multiple output of a rematerializer get evicted, + // we only want to count the compute cost once. + ecn_ptr ecn; + Rematerializer(const Unsafe&, + const rematerialize_function_t& func, + const strongs& inputs, + duration_t compute_cost) : + func(func), + inputs(inputs), + compute_cost(compute_cost) { + } + void release_resources() final { + func = rematerialize_function_t(); + inputs.clear(); + outputs.clear(); + } + void remat(); + ecn_ptr get_ecn(); + CheckpointInfo get_cpi(); +}; + +// Track all Tensor that share the same Storage. +// This is the atomic level of eviction - when evicting, everything here will get evicted. +// When an AliasPool is evicted, the Storage of the underlying tensor must be freed. +// Additionally, the AliasPool contain weak pointer to all children of tensors, +// in order to compute the score of evicting a Storage. +struct AliasPool : intrusive_ptr_target { + weaks tensors; + weaks neighbors; + std::set neighbor_ecn(); + // get() might hold some raw Tensor, rendering them unevictable. + // it is likely that get() will run out of memory, and when it does so, it will try to evict. + // so, it is crucial that we dont try to evict those tensors - doing so will not evict anything. + // lock_count count how many time a tensor is referenced by get. + size_t lock_count = 0; + size_t external_count = 0; + void lock() { + ++lock_count; + } + void unlock() { + --lock_count; + } + intrusive_ptr head_remat; + bool evictable() const { + return lock_count == 0 && head_remat; + } + // if it is not evictable it must not be evicted. + bool is_evicted = false; + size_t memory; + time_t last_used_time; + // An aliaspool cant register itself to the checkpointpool - you have to do it yourself. + AliasPool(const Unsafe&, intrusive_ptr head_remat, size_t memory) : + head_remat(head_remat), + memory(memory), + last_used_time(std::chrono::system_clock::now()) { + } + // if it is evicted, then hold the evicted tensor group. + ecn_ptr ecn; + double cost(time_t current_time); + void evict(); + void register_external() { + ++external_count; + } + void release_external() { + --external_count; + if (external_count == 0) { + if (lock_count > 0) {return;} + TORCH_CHECK(lock_count == 0); + if (memory > 0 && (!ecn) && head_remat) { + evict(); + } + } + } + // if it was evicted, refresh it. otherwise do nothing. + // have to check so, because when we rematerialize a single tensor in an aliaspool, + // we will set it to non-evicted, and when we rematerialize it's tensor they will also reset this. 
+ void set_not_evicted(const intrusive_ptr& self); + void release_resources() final { + tensors.clear(); + neighbors.clear(); + head_remat.reset(); + } +}; + +struct CheckpointTensorCell : intrusive_ptr_target { + std::unique_ptr t; + bool defined = false; + bool is_undefined_tensor; + DispatchKeySet key_set_; + DispatchKeySet key_set() const { + TORCH_CHECK(defined); + return key_set_; + } + caffe2::TypeMeta dtype_; + caffe2::TypeMeta dtype() const { + TORCH_CHECK(defined); + return dtype_; + } + c10::optional optional_device_; + c10::optional optional_device() const { + TORCH_CHECK(defined); + return optional_device_; + } + // A Tensor is evictable iff it's AliasPool is evictable. + // A evictable tensor must have Rematerializer. + intrusive_ptr pool; + intrusive_ptr remat; + void evict() { + TORCH_CHECK(remat); + t.reset(); + } + void fill(const Tensor& t); + explicit CheckpointTensorCell(const Tensor& t, const intrusive_ptr& pool) : pool(pool) { + fill(t); + } + explicit CheckpointTensorCell(const Tensor& t, + const intrusive_ptr& pool, + const intrusive_ptr& remat) : + pool(pool), remat(remat) { + fill(t); + } + size_t memory() { + TORCH_CHECK(defined); + return pool->memory; + } + Tensor get() { + if (! t) { + TORCH_CHECK(remat); + remat->remat(); + } + TORCH_CHECK(t); + TORCH_CHECK(! t->key_set().has(DispatchKey::CheckpointTensorId)); + pool->last_used_time = std::chrono::system_clock::now(); + return *t; + } + void pin() { + get(); + pool->head_remat.reset(); + remat.reset(); + } + void release_resources() final { + t.reset(); + pool.reset(); + remat.reset(); + } +}; + +// An external reference. +// Each strong will have at most one external reference. +// By keeping such an invariant, whenever an external reference die, +// We know that the underlying strong is only used internally. +// Thus, when it die we can apply optimization like banishing/infinite staleness. +// We keep this invariant by only allowing CheckpointTensorImpl to make new External, +// When new CheckpointTensorImpl is constructed. 
+struct External : intrusive_ptr_target { + External(const strong& value) : value(value) { + value->pool->register_external(); + } + External(const Tensor& value) : + External(strong::make(value, + intrusive_ptr::make(Unsafe(), + intrusive_ptr(), + memory(value)))) { } + External(const Tensor& value, + const intrusive_ptr& pool, + const intrusive_ptr& remat) : + External(strong::make(value, pool, remat)) { } + strong value; + void release_resources() override; +}; + +inline DispatchKeySet convert_key_set(const DispatchKeySet& t) { + CHECK(!t.has(DispatchKey::CheckpointTensorId)); + auto ret = t.add(DispatchKey::CheckpointTensorId); + CHECK(!ret.has(DispatchKey::VariableTensorId)); + return ret; +} + +struct CheckpointTensorImpl : TensorImpl { + int id = gen_counter(); + static int counter; + static int gen_counter() { + return counter++; + } + std::string counter_name() const { + return std::string("x") + std::to_string(id); + } + + Ref> ref; + + void release_resources() final; + + explicit CheckpointTensorImpl(const Ref>& ref) : + TensorImpl(convert_key_set(ref->value->value->key_set()), + ref->value->value->dtype(), + ref->value->value->optional_device()), + ref(ref) { } + + explicit CheckpointTensorImpl(const intrusive_ptr& e) : + CheckpointTensorImpl(Ref>::make(e)) { } + + explicit CheckpointTensorImpl(const Tensor& t); + + static Tensors make(const std::string& name, + const rematerialize_function_t& remat, + const Tensors& inputs); + + // mutate_idx indicate which of the inputs will get mutated. + static void mutate(const std::string& name, + const mutate_function_t& mutate, + const Tensors& inputs, + const std::vector& mutate_idx); + intrusive_ptr shallow_copy_and_detach(const VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override; + void shallow_copy_from(const c10::intrusive_ptr& impl) override; + int64_t dim() const override { + return ref->value->value->get().dim(); + } + int64_t numel() const override { + return ref->value->value->get().numel(); + } + IntArrayRef sizes() const override { + return ref->value->value->get().sizes(); + } + int64_t size(int64_t d) const override { + return ref->value->value->get().size(d); + } + IntArrayRef strides() const override { + return ref->value->value->get().strides(); + } + int64_t stride(int64_t d) const override { + return ref->value->value->get().stride(d); + } + bool has_storage() const override { + return false; + } +}; + +// CheckpointPool keep a list of AliasPool, and search over them to choose the best one to evict. 
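+// Rather than scoring every candidate, CheckpointPool::evict walks aps with a random stride
+// drawn uniformly from [1, sqrt(|aps|)], so each eviction scores roughly a sqrt-sized random
+// sample; dead or already-merged (ecn) entries are pruned along the way, and the cheapest
+// evictable candidate found in the sample is evicted.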
+struct CheckpointPool { + std::vector> aps; + std::vector> exts; + std::random_device rd; + std::mt19937 gen = std::mt19937(rd()); + bool has_memory_budget = false; + long memory_budget; + void evict(); + void auto_evict(); + void clear_checkpointpool(); + void add(const intrusive_ptr&); + CheckpointPool(); +}; + +inline CheckpointTensorImpl* get_cpti(const Tensor& t) { + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + TORCH_CHECK(cpti != nullptr); + return cpti; +} + +inline Ref> cell_from_tensor(const Tensor& t) { + return get_cpti(t)->ref; +} + +} diff --git a/aten/src/ATen/Logger.h b/aten/src/ATen/Logger.h new file mode 100644 index 00000000000..7e3c01a9af9 --- /dev/null +++ b/aten/src/ATen/Logger.h @@ -0,0 +1,197 @@ +#pragma once + +#include +#include +#include <../../../third_party/json/single_include/nlohmann/json.hpp> + +namespace at { + +struct DTRLogger { + std::string time_prefix; + std::ofstream out; + static std::string get_time_prefix() { + std::time_t t = std::time(nullptr); + std::tm* tm = std::localtime(&t); + return + std::to_string(1900+tm->tm_year) + "-" + + std::to_string(1+tm->tm_mon) + "-" + + std::to_string(tm->tm_mday) + "-" + + std::to_string(tm->tm_hour) + "-" + + std::to_string(tm->tm_min) + "-" + + std::to_string(tm->tm_sec); + } + std::string get_filename(const std::string& name) { + return time_prefix + "-" + name + ".log"; + } + DTRLogger() : time_prefix(get_time_prefix()), out(get_filename("default")) { } + void log(const std::string& str) { + out << str << std::endl; + } + static DTRLogger& logger() { + static DTRLogger ret; + return ret; + } + +}; + +using json = nlohmann::json; +constexpr bool log_json = true; +const std::string INSTRUCTION = "INSTRUCTION"; +const std::string ANNOTATION = "ANNOTATION"; +const std::string RELEASE = "RELEASE"; +const std::string PIN = "PIN"; +const std::string TIME = "TIME"; +const std::string ARGS = "ARGS"; +const std::string MEMORY = "MEMORY"; +const std::string ALIAS = "ALIAS"; +const std::string NAME = "NAME"; +const std::string CONSTANT = "CONSTANT"; + +void DTRLogConstant(const std::string& name) { + if (log_json) { + json j; + j[INSTRUCTION] = CONSTANT; + j[NAME] = name; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(CONSTANT + " " + name); + } +} + +void DTRLogMemory(const std::string& name, size_t memory) { + if (log_json) { + json j; + j[INSTRUCTION] = MEMORY; + j[NAME] = name; + j[MEMORY] = std::to_string(memory); + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(name + " " + MEMORY + ": " + std::to_string(memory)); + } +} + +void DTRLogAlias(const std::string& name, int index) { + if (log_json) { + json j; + j[INSTRUCTION] = ALIAS; + j[NAME] = name; + j[ALIAS] = std::to_string(index); + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(name + " " + ALIAS + ": " + std::to_string(index)); + } +} + +void DTRLogCopyFrom(const std::string& to, const std::string& from) { + if (log_json) { + json j; + j[INSTRUCTION] = "COPY_FROM"; + j["DST"] = to; + j["SRC"] = from; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(to + " <- " + from); + } +} + +void DTRLogCopy(const std::string& new_name, const std::string& old_name) { + if (log_json) { + json j; + j[INSTRUCTION] = "COPY"; + j["DST"] = new_name; + j["SRC"] = old_name; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(new_name + " = " + old_name); + } +} + +void DTRLogMutate(const std::string& name, + const std::vector& args, + const 
std::vector& mutate, + const std::string& time) { + if (log_json) { + json j; + j[INSTRUCTION] = "MUTATE"; + j[NAME] = name; + j[ARGS] = args; + j["MUTATE"] = mutate; + j[TIME] = time; + DTRLogger::logger().log(j.dump()); + } else { + std::string log = name; + log += "("; + for (const auto& s : args) { + log += s; + log += ", "; + } + log += ") "; + log += " MUTATING: "; + log += "("; + for (const size_t i : mutate) { + log += std::to_string(i); + log += ", "; + } + log += ") "; + log += TIME; + log += ": "; + log += time; + DTRLogger::logger().log(log); + } +} + +void DTRLogRelease(const std::string& name) { + if (log_json) { + json j; + j[INSTRUCTION] = RELEASE; + j[NAME] = name; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(RELEASE + ": " + name); + } +} + +void DTRLogPin(const std::string& name) { + if (log_json) { + json j; + j[INSTRUCTION] = PIN; + j[NAME] = name; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(RELEASE + ": " + name); + } +} + +void DTRLogCall(const std::vector& res, + const std::string& name, + const std::vector& args, + const std::string& time) { + if (log_json) { + json j; + j[INSTRUCTION] = "CALL"; + j[NAME] = name; + j["RESULT"] = res; + j[ARGS] = args; + j[TIME] = time; + DTRLogger::logger().log(j.dump()); + } else { + std::string arg = name + "("; + for (const auto& s : args) { + arg += s; + arg += ", "; + } + arg += ")"; + std::string log = "("; + for (const auto& s: res) { + log += s; + log += ", "; + } + log += ") = "; + log += arg; + log += " TIME: "; + log += time; + DTRLogger::logger().log(log); + } +} + +} diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py index e64bbd891a8..5a1707c6406 100644 --- a/aten/src/ATen/gen.py +++ b/aten/src/ATen/gen.py @@ -353,6 +353,8 @@ def generate_storage_type_and_tensor(backend, density, declarations, per_op_regi if env['DeviceType'] == 'CPU': top_env['cpu_type_headers'].append( '#include "ATen/{}.h"'.format(env['Type'])) + elif env['DeviceType'] == 'Checkpoint': + pass else: assert env['DeviceType'] == 'CUDA' top_env['cuda_type_headers'].append( @@ -411,6 +413,8 @@ def declare_outputs(): fname = gen_per_op_registration_filename(whitelisted_op) file_manager.will_write(fname) + file_manager.will_write("CheckpointType.h") + file_manager.will_write("CheckpointType.cpp") def filter_by_extension(files, *extensions): filtered_files = [] @@ -478,6 +482,8 @@ def generate_outputs(): generate_storage_type_and_tensor( backend, density, declarations, per_op_registrations) + generate_storage_type_and_tensor('Checkpoint', 'Dense', declarations, per_op_registrations) + core_files = { 'TensorBody.h': TENSOR_H, 'TensorMethods.h': TENSOR_METHODS_H, diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp index c97aa5a941b..b31c46ec8fd 100644 --- a/aten/src/ATen/native/Activation.cpp +++ b/aten/src/ATen/native/Activation.cpp @@ -712,4 +712,94 @@ Tensor& log_sigmoid_backward_out_cpu( DEFINE_DISPATCH(GeluKernel); DEFINE_DISPATCH(GeluBackwardKernel); +Tensor slice_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { + auto grad_input = at::zeros(input_sizes, grad.options()); + grad_input.slice(dim, start, end, step).copy_(grad); + return grad_input; +} + +Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { + auto grad_input = at::zeros(input_sizes, grad.options()); + grad_input.select(dim, index).copy_(grad); + return grad_input; +} + +std::vector 
split(const Tensor& self, int64_t split_size, int64_t dim) { + TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor"); + TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size); + int64_t dim_size = self.size(dim); + TORCH_CHECK(split_size > 0 || self.size(dim) == 0, + "split_size can only be 0 if dimension size is 0, " + "but got dimension size of ", dim_size); + // if split_size is 0 and dimension size is 0, there is 1 split. + int64_t num_splits = 1; + if (split_size != 0) { + // ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size + // (returns a single split). We might want to error here, but keep it for BC. + num_splits = std::max((dim_size + split_size - 1) / split_size, 1); + } + std::vector splits(num_splits); + int64_t last_split_size = split_size - (split_size * num_splits - dim_size); + + for (int64_t i = 0; i < num_splits; ++i) { + auto length = i < num_splits - 1 ? split_size : last_split_size; + splits[i] = self.narrow(dim, i * split_size, length); + } + return splits; +} + +std::vector split_with_sizes(const Tensor& self, IntArrayRef split_sizes, int64_t dim) { + TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor"); + int64_t dim_size = self.size(dim); + int64_t num_splits = split_sizes.size(); + std::vector splits(num_splits); + int64_t start_idx = 0; + int64_t i; + + for (i = 0; i < num_splits; ++i) { + auto length = split_sizes[i]; + TORCH_CHECK(length >= 0, + "split_with_sizes expects split_sizes have only non-negative ", + "entries, but got split_sizes=", split_sizes); + splits[i] = self.narrow(dim, start_idx, length); + start_idx += length; + } + TORCH_CHECK(start_idx == dim_size, + "split_with_sizes expects split_sizes to sum exactly to ", dim_size, + " (input tensor's size at dimension ", dim, "), ", "but got split_sizes=", split_sizes); + return splits; +} + +Tensor split_backward(c10::ArrayRef grads, + int64_t split_size, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { + dim = at::maybe_wrap_dim(dim, sizes.size()); + int64_t dim_size = sizes[dim]; + int64_t num_splits = grads.size(); + std::vector split_sizes(num_splits, split_size); + split_sizes[num_splits - 1] = split_size - (split_size * num_splits - dim_size); + return at::native::split_with_sizes_backward(grads, split_sizes, dim, sizes, options); +} + +Tensor split_with_sizes_backward(c10::ArrayRef grads, + IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { + dim = at::maybe_wrap_dim(dim, sizes.size()); + + // it's possible some of the grads are not defined (represents tensors of all 0s). 
+ // Since at::cat can't handle those, let's define them + std::vector grads_all_defined(grads.size()); + for (size_t j = 0; j < grads.size(); ++j) { + if (grads[j].defined()) { + grads_all_defined[j] = grads[j]; + } else { + auto length = split_sizes[j]; + auto grad_size = sizes.vec(); + grad_size[dim] = length; + grads_all_defined[j] = at::zeros(grad_size, options); + } + } + + auto ret = at::cat(grads_all_defined, dim); + return ret; +} + }} // namespace at::native diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp new file mode 100644 index 00000000000..61a21569534 --- /dev/null +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -0,0 +1,1620 @@ +#include +#include + +namespace at { namespace native { + +Tensor checkpoint_add(const Tensor& a, const Tensor& b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::add(vec.at(0), vec.at(1), c)}; + }; + return CheckpointTensorImpl::make("add", rt, {a, b})[0]; +} + +Tensor checkpoint_t(at::Tensor const& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::t(vec.at(0))}; + }; + return CheckpointTensorImpl::make("t", rt, {a})[0]; +} + +Tensor checkpoint_add(at::Tensor const& a, c10::Scalar b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::add(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("add", rt, {a})[0]; +} + +Tensor& checkpoint_add_(Tensor& a, const Tensor& b, Scalar c) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).add_(vec.at(1), c); + }; + CheckpointTensorImpl::mutate("add_", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_mul(at::Tensor const& a, at::Tensor const& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::mul(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("mul", rt, {a, b})[0]; +} + +Tensor& checkpoint_mul_(at::Tensor& a, at::Tensor const& b) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).mul_(vec.at(1)); + }; + CheckpointTensorImpl::mutate("mul_", mt, {a, b}, {0}); + return a; +} + +Tensor& checkpoint_mul_(at::Tensor& a, c10::Scalar b) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).mul_(b); + }; + CheckpointTensorImpl::mutate("mul_", mt, {a}, {0}); + return a; +} + +Tensor checkpoint_zeros_like(at::Tensor const& a, c10::TensorOptions const& b, c10::optional c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::zeros_like(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("zeros_like", rt, {a})[0]; +} + +Tensor checkpoint_ones_like(at::Tensor const& a, c10::TensorOptions const& b, c10::optional c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::ones_like(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("ones_like", rt, {a})[0]; +} + +Tensor checkpoint_addcmul(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, c10::Scalar d) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::addcmul(vec.at(0), vec.at(1), vec.at(2), d)}; + }; + return CheckpointTensorImpl::make("addcmul", rt, {a, b, c})[0]; +} + +Tensor& checkpoint_addcmul_(at::Tensor& a, at::Tensor const& b, at::Tensor const& c, c10::Scalar d) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).addcmul_(vec.at(1), vec.at(2), d); + }; + CheckpointTensorImpl::mutate("addcmul_", mt, {a, b, c}, {0}); + return a; +} + +Tensor 
checkpoint_abs(const Tensor& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::abs(vec.at(0))}; + }; + return CheckpointTensorImpl::make("abs", rt, {a})[0]; +} + +Tensor checkpoint_sqrt(const Tensor& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sqrt(vec.at(0))}; + }; + return CheckpointTensorImpl::make("sqrt", rt, {a})[0]; +} + +Tensor& checkpoint_addcdiv_(at::Tensor& a, at::Tensor const& b, at::Tensor const& c, c10::Scalar d) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).addcdiv_(vec.at(1), vec.at(2), d); + }; + CheckpointTensorImpl::mutate("addcdiv_", mt, {a, b, c}, {0}); + return a; +} + +Tensor checkpoint_addcdiv(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, c10::Scalar d) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::addcdiv(vec.at(0), vec.at(1), vec.at(2), d)}; + }; + return CheckpointTensorImpl::make("addcdiv", rt, {a, b, c})[0]; +} + +Tensor checkpoint_to(at::Tensor const& a, c10::TensorOptions const& b, bool c, bool d, c10::optional e) { + c10::TensorOptions b_ = b; + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {vec.at(0).to(b_, c, d, e)}; + }; + return CheckpointTensorImpl::make("to", rt, {a})[0]; +} + +Tensor checkpoint_div(const Tensor& a, const Tensor& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::div(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("div", rt, {a, b})[0]; +} + +Tensor& checkpoint_div_(Tensor& a, const Tensor& b) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).div_(vec.at(1)); + }; + CheckpointTensorImpl::mutate("div_", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_clone(at::Tensor const& a, c10::optional b) { + if (b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::clone(vec.at(0), b)}; + }; + return CheckpointTensorImpl::make("clone", rt, {a})[0]; + } + else { + return a; + } +} + +Tensor checkpoint_where(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::where(vec.at(0), vec.at(1), vec.at(2))}; + }; + return CheckpointTensorImpl::make("where", rt, {a, b, c})[0]; +} + +Tensor checkpoint_constant_pad_nd(Tensor const& a, c10::ArrayRef b, c10::Scalar c) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::constant_pad_nd(vec.at(0), b_, c)}; + }; + return CheckpointTensorImpl::make("constant_pad_nd", rt, {a})[0]; +} + +Tensor checkpoint_binary_cross_entropy(const Tensor& a, const Tensor& b, const Tensor& c, long d) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::binary_cross_entropy(vec.at(0), vec.at(1), vec.at(2), d)}; + }; + return CheckpointTensorImpl::make("binary_cross_entropy", rt, {a, b, c})[0]; +} + +Tensor& checkpoint_binary_cross_entropy_out(Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, long e) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::binary_cross_entropy_out(self, vec.at(1), vec.at(2), vec.at(3), e); + }; + CheckpointTensorImpl::mutate("binary_cross_entropy_out", mt, {a, b, c, d}, {0}); + return a; +} + +Tensor checkpoint_binary_cross_entropy_backward(const Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, long e) { + rematerialize_function_t rt = + [=](const 
Tensors& vec) -> Tensors { + return {at::binary_cross_entropy_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), e)}; + }; + return CheckpointTensorImpl::make("binary_cross_entropy_backward", rt, {a, b, c, d})[0]; +} + +Tensor& checkpoint_binary_cross_entropy_backward_out(Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, const Tensor& e, long f) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::binary_cross_entropy_backward_out(self, vec.at(1), vec.at(2), vec.at(3), vec.at(4), f); + }; + CheckpointTensorImpl::mutate("binary_cross_entropy_backward_out", mt, {a, b, c, d, e}, {0}); + return a; +} + +Tensor checkpoint_embedding(const Tensor& a, const Tensor& b, long c, bool d, bool e) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::embedding(vec.at(0), vec.at(1), c, d, e)}; + }; + return CheckpointTensorImpl::make("embedding", rt, {a, b})[0]; +} + +Tensor checkpoint_embedding_backward(const Tensor& a, const Tensor& b, long c, long d, bool e, bool f) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::embedding_backward(vec.at(0), vec.at(1), c, d, e, f)}; + }; + return CheckpointTensorImpl::make("embedding", rt, {a, b})[0]; +} + +std::tuple +checkpoint_cudnn_batch_norm(const Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, const Tensor& e, bool f, double g, double h) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::cudnn_batch_norm(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), f, g, h); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret), std::get<3>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("cudnn_batch_norm", rt, {a, b, c, d, e}); + return {ret[0], ret[1], ret[2], ret[3]}; +} + +std::tuple checkpoint_cudnn_batch_norm_backward(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, at::Tensor const& d, at::Tensor const& e, at::Tensor const& f, at::Tensor const& g, double h, at::Tensor const& i) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::cudnn_batch_norm_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), vec.at(5), vec.at(6), h, vec.at(7)); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("cudnn_batch_norm_backward", rt, {a, b, c, d, e, f, g, i}); + return {ret[0], ret[1], ret[2]}; +} + +Tensor checkpoint_as_strided(const Tensor& a, c10::ArrayRef b, c10::ArrayRef c, c10::optional d) { + std::vector b_ = b.vec(), c_ = c.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::as_strided(vec.at(0), b_, c_, d)}; + }; + return CheckpointTensorImpl::make("as_strided", rt, {a})[0]; +} + +Tensor checkpoint__masked_scale(const Tensor& a, const Tensor& b, double c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::_masked_scale(vec.at(0), vec.at(1), c)}; + }; + return CheckpointTensorImpl::make("_masked_scale", rt, {a, b})[0]; +} + +Tensor checkpoint_cudnn_convolution(const Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, long f, bool g, bool h) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution(vec.at(0), vec.at(1), c_, d_, e_, f, g, h)}; + }; + return CheckpointTensorImpl::make("cudnn_convolution", rt, {a, b})[0]; +} + +Tensor 
checkpoint_cudnn_convolution_transpose(const Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution_transpose(vec.at(0), vec.at(1), c_, d_, e_, f_, g, h, i)}; + }; + return CheckpointTensorImpl::make("cudnn_convolution_transpose", rt, {a, b})[0]; +} + +std::tuple checkpoint_cudnn_convolution_backward(const Tensor& a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i, std::array j) { + std::vector d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::cudnn_convolution_backward(vec.at(0), vec.at(1), vec.at(2), d_, e_, f_, g, h, i, j); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("cudnn_convolution_backward", rt, {a, b, c}); + return {ret[0], ret[1]}; +} + +std::tuple checkpoint_cudnn_convolution_transpose_backward(const Tensor& a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, c10::ArrayRef g, long h, bool i, bool j, std::array k) { + std::vector d_ = d.vec(), e_ = e.vec(), f_ = f.vec(), g_ = g.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::cudnn_convolution_transpose_backward(vec.at(0), vec.at(1), vec.at(2), d_, e_, f_, g_, h, i, j, k); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("cudnn_convolution_transpose_backward", rt, {a, b, c}); + return {ret[0], ret[1]}; +} + +Tensor checkpoint_cudnn_convolution_backward_input(c10::ArrayRef a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { + std::vector a_ = a.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution_backward_input(a_, vec.at(0), vec.at(1), d_, e_, f_, g, h, i)}; + }; + return CheckpointTensorImpl::make("cudnn_convolution_backward_input", rt, {b, c})[0]; +} + +Tensor checkpoint_cudnn_convolution_transpose_backward_input(const Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, long f, bool g, bool h) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution_transpose_backward_input(vec.at(0), vec.at(1), c_, d_, e_, f, g, h)}; + }; + return CheckpointTensorImpl::make("cudnn_convolution_transpose_backward_input", rt, {a, b})[0]; +} + +Tensor checkpoint_cudnn_convolution_backward_weight(c10::ArrayRef a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { + std::vector a_ = a.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution_backward_weight(a_, vec.at(0), vec.at(1), d_, e_, f_, g, h, i)}; + }; + return CheckpointTensorImpl::make("cudnn_convolution_backward_weight", rt, {b, c})[0]; +} + +Tensor checkpoint_cudnn_convolution_transpose_backward_weight(c10::ArrayRef a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { + std::vector a_ = a.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + 
rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution_transpose_backward_weight(a_, vec.at(0), vec.at(1), d_, e_, f_, g, h, i)}; + }; + return CheckpointTensorImpl::make("cudnn_convolution_transpose_backward_weight", rt, {b, c})[0]; +} + +Tensor checkpoint_relu(const Tensor& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::relu(vec.at(0))}; + }; + return CheckpointTensorImpl::make("relu", rt, {a})[0]; +} + +Tensor& checkpoint_relu_(Tensor& a) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).relu_(); + }; + CheckpointTensorImpl::mutate("relu_", mt, {a}, {0}); + return a; +} + +Tensor checkpoint_log(at::Tensor const& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::log(vec.at(0))}; + }; + return CheckpointTensorImpl::make("log", rt, {a})[0]; +} + +Tensor& checkpoint_log_out(at::Tensor& a, at::Tensor const& b) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::log_out(a_, vec.at(1)); + }; + CheckpointTensorImpl::mutate("log_out", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_rsub(at::Tensor const& a, at::Tensor const& b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::rsub(vec.at(0), vec.at(1), c)}; + }; + return CheckpointTensorImpl::make("rsub", rt, {a, b})[0]; +} + +Tensor checkpoint_rsub(at::Tensor const& a, c10::Scalar b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::rsub(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("rsub", rt, {a})[0]; +} + +Tensor checkpoint_mul(at::Tensor const& a, c10::Scalar b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::mul(vec.at(0), b)}; + }; + return CheckpointTensorImpl::make("mul", rt, {a})[0]; +} + +std::tuple checkpoint_max_pool2d_with_indices_out(Tensor& a, Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, c10::ArrayRef g, bool h) { + std::vector d_ = d.vec(), e_ = e.vec(), f_ = f.vec(), g_ = g.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0), b_ = vec.at(1); + at::max_pool2d_with_indices_out(a_, b_, vec.at(2), d_, e_, f_, g_, h); + }; + CheckpointTensorImpl::mutate("max_pool2d_with_indices_out", mt, {a, b, c}, {0, 1}); + return {a, b}; +} + +Tensor checkpoint_avg_pool2d(const Tensor& a, c10::ArrayRef b, c10::ArrayRef c, c10::ArrayRef d, bool e, bool f, c10::optional g) { + std::vector b_ = b.vec(), c_ = c.vec(), d_ = d.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::avg_pool2d(vec.at(0), b_, c_, d_, e, f, g)}; + }; + return CheckpointTensorImpl::make("avg_pool2d", rt, {a})[0]; +} + +Tensor checkpoint_avg_pool2d_backward(const Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, bool f, bool g, c10::optional h) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::avg_pool2d_backward(vec.at(0), vec.at(1), c_, d_, e_, f, g, h)}; + }; + return CheckpointTensorImpl::make("avg_pool2d_backward", rt, {a, b})[0]; +} + +Tensor& checkpoint_avg_pool2d_out(Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, bool f, bool g, c10::optional h) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = 
vec.at(0);
+    at::avg_pool2d_out(a_, vec.at(1), c_, d_, e_, f, g, h);
+  };
+  CheckpointTensorImpl::mutate("avg_pool2d_out", mt, {a, b}, {0});
+  return a;
+}
+
+Tensor& checkpoint_avg_pool2d_backward_grad_input(Tensor& a, const Tensor& b, const Tensor& c, c10::ArrayRef<long> d, c10::ArrayRef<long> e, c10::ArrayRef<long> f, bool g, bool h, c10::optional<long> i) {
+  std::vector<long> d_ = d.vec(), e_ = e.vec(), f_ = f.vec();
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor a_ = vec.at(0);
+      at::avg_pool2d_backward_out(a_, vec.at(1), vec.at(2), d_, e_, f_, g, h, i);
+    };
+  CheckpointTensorImpl::mutate("avg_pool2d_backward_grad_input", mt, {a, b, c}, {0});
+  return a;
+}
+
+std::tuple<Tensor, Tensor> checkpoint_max_pool2d_with_indices(const Tensor& a, c10::ArrayRef<long> b, c10::ArrayRef<long> c, c10::ArrayRef<long> d, c10::ArrayRef<long> e, bool f) {
+  std::vector<long> b_ = b.vec(), c_ = c.vec(), d_ = d.vec(), e_ = e.vec();
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      auto ret = at::max_pool2d_with_indices(vec.at(0), b_, c_, d_, e_, f);
+      return {std::get<0>(ret), std::get<1>(ret)};
+    };
+  auto ret = CheckpointTensorImpl::make("max_pool2d_with_indices", rt, {a});
+  return {ret[0], ret[1]};
+}
+
+Tensor& checkpoint_max_pool2d_with_indices_backward_grad_input(Tensor& a, const Tensor& b, const Tensor& c, c10::ArrayRef<long> d, c10::ArrayRef<long> e, c10::ArrayRef<long> f, c10::ArrayRef<long> g, bool h, const Tensor& i) {
+  std::vector<long> d_ = d.vec(), e_ = e.vec(), f_ = f.vec(), g_ = g.vec();
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor a_ = vec.at(0);
+      at::max_pool2d_with_indices_backward_out(a_, vec.at(1), vec.at(2), d_, e_, f_, g_, h, vec.at(3));
+    };
+  CheckpointTensorImpl::mutate("max_pool2d_with_indices_backward_grad_input", mt, {a, b, c, i}, {0});
+  return a;
+}
+
+Tensor checkpoint_max_pool2d_with_indices_backward(const Tensor& a, const Tensor& b, c10::ArrayRef<long> c, c10::ArrayRef<long> d, c10::ArrayRef<long> e, c10::ArrayRef<long> f, bool g, const Tensor& h) {
+  std::vector<long> c_ = c.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec();
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::max_pool2d_with_indices_backward(vec.at(0), vec.at(1), c_, d_, e_, f_, g, vec.at(2))};
+    };
+  return CheckpointTensorImpl::make("max_pool2d_with_indices_backward", rt, {a, b, h})[0];
+}
+
+Tensor checkpoint_view(const Tensor& a, c10::ArrayRef<long> b) {
+  std::vector<long> b_ = b.vec();
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {vec.at(0).view(b_)};
+    };
+  return CheckpointTensorImpl::make("view", rt, {a})[0];
+}
+
+Tensor checkpoint_ne_Scalar(const Tensor& a, c10::Scalar b) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::ne(vec.at(0), b)};
+    };
+  return CheckpointTensorImpl::make("ne_Scalar", rt, {a})[0];
+}
+
+Tensor& checkpoint_ne_Scalar_out(Tensor& a, const Tensor& b, c10::Scalar c) {
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor a_ = vec.at(0);
+      at::ne_out(a_, vec.at(1), c);
+    };
+  CheckpointTensorImpl::mutate("ne_Scalar_out", mt, {a, b}, {0});
+  return a;
+}
+
+Tensor checkpoint_ne_Tensor(const Tensor& a, const Tensor& b) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::ne(vec.at(0), vec.at(1))};
+    };
+  return CheckpointTensorImpl::make("ne_Tensor", rt, {a, b})[0];
+}
+
+Tensor& checkpoint_ne_Tensor_out(Tensor& a, const Tensor& b, const Tensor& c) {
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor a_ = vec.at(0);
+      at::ne_out(a_, vec.at(1), vec.at(2));
+    };
+  CheckpointTensorImpl::mutate("ne_Tensor_out", mt, {a, b, c}, {0});
+  return a;
+}
+
+Tensor checkpoint_eq_Scalar(const Tensor& a, c10::Scalar b) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::eq(vec.at(0), b)};
+    };
+  return CheckpointTensorImpl::make("eq_Scalar", rt, {a})[0];
+}
+
+Tensor& checkpoint_eq_Scalar_out(Tensor& a, const Tensor& b, c10::Scalar c) {
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor a_ = vec.at(0);
+      at::eq_out(a_, vec.at(1), c);
+    };
+  CheckpointTensorImpl::mutate("eq_Scalar_out", mt, {a, b}, {0});
+  return a;
+}
+
+Tensor checkpoint_eq_Tensor(const Tensor& a, const Tensor& b) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::eq(vec.at(0), vec.at(1))};
+    };
+  return CheckpointTensorImpl::make("eq_Tensor", rt, {a, b})[0];
+}
+
+Tensor& checkpoint_eq_Tensor_out(Tensor& a, const Tensor& b, const Tensor& c) {
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor a_ = vec.at(0);
+      at::eq_out(a_, vec.at(1), vec.at(2));
+    };
+  CheckpointTensorImpl::mutate("eq_Tensor_out", mt, {a, b, c}, {0});
+  return a;
+}
+
+Tensor checkpoint_addmm(const Tensor& a, const Tensor& b, const Tensor& c, c10::Scalar d, c10::Scalar e) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::addmm(vec.at(0), vec.at(1), vec.at(2), d, e)};
+    };
+  return CheckpointTensorImpl::make("addmm", rt, {a, b, c})[0];
+}
+
+Tensor& checkpoint_addmm_out(Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, c10::Scalar e, c10::Scalar f) {
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor a_ = vec.at(0);
+      // every tensor argument goes through vec so it is uncheckpointed before use
+      at::addmm_out(a_, vec.at(1), vec.at(2), vec.at(3), e, f);
+    };
+  CheckpointTensorImpl::mutate("addmm_out", mt, {a, b, c, d}, {0});
+  return a;
+}
+
+Tensor& checkpoint_addmm_(Tensor& a, const Tensor& b, const Tensor& c, c10::Scalar d, c10::Scalar e) {
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor a_ = vec.at(0);
+      a_.addmm_(vec.at(1), vec.at(2), d, e);
+    };
+  CheckpointTensorImpl::mutate("addmm_", mt, {a, b, c}, {0});
+  return a;
+}
+
+Tensor checkpoint_sigmoid(const Tensor& a) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::sigmoid(vec.at(0))};
+    };
+  return CheckpointTensorImpl::make("sigmoid", rt, {a})[0];
+}
+
+Tensor& checkpoint_sigmoid_(Tensor& a) {
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor a_ = vec.at(0);
+      a_.sigmoid_();
+    };
+  CheckpointTensorImpl::mutate("sigmoid_", mt, {a}, {0});
+  return a;
+}
+
+Tensor checkpoint__log_softmax(const Tensor& a, long b, bool c) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::_log_softmax(vec.at(0), b, c)};
+    };
+  return CheckpointTensorImpl::make("_log_softmax", rt, {a})[0];
+}
+
+Tensor checkpoint__log_softmax_backward_data(const Tensor& a, const Tensor& b, long c, const Tensor& d) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::_log_softmax_backward_data(vec.at(0), vec.at(1), c, vec.at(2))};
+    };
+  return CheckpointTensorImpl::make("_log_softmax_backward_data", rt, {a, b, d})[0];
+}
+
+std::tuple<Tensor, Tensor> checkpoint_nll_loss_forward(const Tensor& a, const Tensor& b, const Tensor& c, long d, long e) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      auto ret = at::nll_loss_forward(vec.at(0), vec.at(1), vec.at(2), d, e);
+      return {std::get<0>(ret), std::get<1>(ret)};
+    };
+  auto ret = CheckpointTensorImpl::make("nll_loss_forward", rt, {a, b, c});
+  return {ret[0], ret[1]};
+}
+
+std::tuple<Tensor&, Tensor&> checkpoint_nll_loss_forward_out(Tensor& a, Tensor& b, const Tensor& c, const Tensor& d, const Tensor& e, long f, long g) {
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor a_ = vec.at(0);
+      Tensor b_ = vec.at(1);
+      at::nll_loss_forward_out(a_, b_, vec.at(2), vec.at(3), vec.at(4), f, g);
+    };
+  CheckpointTensorImpl::mutate("nll_loss_forward_out", mt, {a, b, c, d, e}, {0, 1});
+  return {a, b};
+}
+
+Tensor checkpoint_nll_loss_backward(const Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, long e, long f, const Tensor& g) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::nll_loss_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), e, f, vec.at(4))};
+    };
+  return CheckpointTensorImpl::make("nll_loss_backward", rt, {a, b, c, d, g})[0];
+}
+
+Tensor& checkpoint_nll_loss_backward_grad_input(Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, const Tensor& e, long f, long g, const Tensor& h) {
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor a_ = vec.at(0);
+      at::nll_loss_backward_out(a_, vec.at(1), vec.at(2), vec.at(3), vec.at(4), f, g, vec.at(5));
+    };
+  CheckpointTensorImpl::mutate("nll_loss_backward_grad_input", mt, {a, b, c, d, e, h}, {0});
+  return a;
+}
+
+Tensor checkpoint_mm(const Tensor& a, const Tensor& b) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::mm(vec.at(0), vec.at(1))};
+    };
+  return CheckpointTensorImpl::make("mm", rt, {a, b})[0];
+}
+
+Tensor& checkpoint_mm_out(Tensor& a, const Tensor& b, const Tensor& c) {
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor a_ = vec.at(0);
+      at::mm_out(a_, vec.at(1), vec.at(2));
+    };
+  CheckpointTensorImpl::mutate("mm_out", mt, {a, b, c}, {0});
+  return a;
+}
+
+Tensor checkpoint_sum(const Tensor& a, c10::optional<c10::ScalarType> b) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::sum(vec.at(0), b)};
+    };
+  return CheckpointTensorImpl::make("sum", rt, {a})[0];
+}
+
+Tensor checkpoint_sum_dim_IntList(const Tensor& a, c10::ArrayRef<long> b, bool c, c10::optional<c10::ScalarType> d) {
+  std::vector<long> b_ = b.vec();
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::sum(vec.at(0), b_, c, d)};
+    };
+  return CheckpointTensorImpl::make("sum_dim_IntList", rt, {a})[0];
+}
+
+Tensor checkpoint_threshold(const Tensor& a, c10::Scalar b, c10::Scalar c) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::threshold(vec.at(0), b, c)};
+    };
+  return CheckpointTensorImpl::make("threshold", rt, {a})[0];
+}
+
+Tensor& checkpoint_threshold_(Tensor& a, c10::Scalar b, c10::Scalar c) {
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor a_ = vec.at(0);
+      at::threshold_(a_, b, c);
+    };
+  CheckpointTensorImpl::mutate("threshold_", mt, {a}, {0});
+  return a;
+}
+
+Tensor& checkpoint_threshold_out(Tensor& a, const Tensor& b, c10::Scalar c, c10::Scalar d) {
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor a_ = vec.at(0);
+      // read the source tensor through vec so it participates in checkpointing
+      at::threshold_out(a_, vec.at(1), c, d);
+    };
+  CheckpointTensorImpl::mutate("threshold_out", mt, {a, b}, {0});
+  return a;
+}
+
+Tensor checkpoint_threshold_backward(const Tensor& a, const Tensor& b, c10::Scalar c) {
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      return {at::threshold_backward(vec.at(0), vec.at(1), c)};
+    };
+  return CheckpointTensorImpl::make("threshold_backward", rt, {a, b})[0];
+}
+
+Tensor checkpoint_select(const Tensor& a, long b, long c) {
+  rematerialize_function_t rt =
+    [=](const Tensors& 
vec) -> Tensors { + return {at::select(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("select", rt, {a})[0]; +} + +Tensor checkpoint_select_backward(const Tensor& a, c10::ArrayRef b, long c, long d) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::select_backward(vec.at(0), b_, c, d)}; + }; + return CheckpointTensorImpl::make("select_backward", rt, {a})[0]; +} + +Tensor checkpoint_slice(const Tensor& a, long b, long c, long d, long e) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::slice(vec.at(0), b, c, d, e)}; + }; + return CheckpointTensorImpl::make("slice", rt, {a})[0]; +} + +Tensor checkpoint_slice_backward(const Tensor& a, c10::ArrayRef b, long c, long d, long e, long f) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::slice_backward(vec.at(0), b_, c, d, e, f)}; + }; + return CheckpointTensorImpl::make("slice_backward", rt, {a})[0]; +} + +Tensor& checkpoint_zero_(Tensor& a) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).zero_(); + }; + CheckpointTensorImpl::mutate("zero_", mt, {a}, {0}); + return a; +} + +Tensor& checkpoint_squeeze_(at::Tensor& a, at::Dimname b) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).squeeze_(b); + }; + CheckpointTensorImpl::mutate("squeeze_", mt, {a}, {0}); + return a; +} + +Tensor& checkpoint_squeeze_(at::Tensor& a) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).squeeze_(); + }; + CheckpointTensorImpl::mutate("squeeze_", mt, {a}, {0}); + return a; +} + +Tensor& checkpoint_squeeze_(at::Tensor& a, long b) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).squeeze_(b); + }; + CheckpointTensorImpl::mutate("squeeze_", mt, {a}, {0}); + return a; +} + +Tensor checkpoint_sigmoid_backward(at::Tensor const& a, at::Tensor const& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sigmoid_backward(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("sigmoid_backward", rt, {a, b})[0]; +} + +Tensor& checkpoint_sigmoid_backward_out(at::Tensor& a, at::Tensor const& b, at::Tensor const& c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::sigmoid_backward_out(a_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("sigmoid_backward_out", mt, {a, b, c}, {0}); + return a; +} + +Tensor& checkpoint_sign_out(at::Tensor& a, at::Tensor const& b) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::sign_out(a_, vec.at(1)); + }; + CheckpointTensorImpl::mutate("sign_out", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_sign(const Tensor& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sign(vec.at(0))}; + }; + return CheckpointTensorImpl::make("sign", rt, {a})[0]; +} + +Tensor checkpoint_tanh(const Tensor& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::tanh(vec.at(0))}; + }; + return CheckpointTensorImpl::make("tanh", rt, {a})[0]; +} + +Tensor checkpoint_tanh_backward(at::Tensor const& a, at::Tensor const& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::tanh_backward(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("tanh_backward", rt, {a, b})[0]; +} + +Tensor& checkpoint_tanh_backward_out(at::Tensor& a, at::Tensor const& b, at::Tensor const& c) { + 
mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::tanh_backward_out(a_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("tanh_backward_out", mt, {a, b, c}, {0}); + return a; +} + +Tensor checkpoint_neg(at::Tensor const& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::neg(vec.at(0))}; + }; + return CheckpointTensorImpl::make("neg", rt, {a})[0]; +} + +Tensor checkpoint_sub(at::Tensor const& a, at::Tensor const& b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sub(vec.at(0), vec.at(1), c)}; + }; + return CheckpointTensorImpl::make("sub", rt, {a, b})[0]; +} + +Tensor& checkpoint_sub_(at::Tensor& a, at::Tensor const& b, c10::Scalar c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + self.sub_(vec.at(1), c); + }; + CheckpointTensorImpl::mutate("sub_", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_repeat(const at::Tensor& a, c10::ArrayRef b) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {vec.at(0).repeat(b_)}; + }; + return CheckpointTensorImpl::make("repeat", rt, {a})[0]; +} + +Tensor checkpoint_mean(const Tensor& self, c10::optional dtype) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::native::mean_cpu_gpu(vec[0], dtype)}; + }; + return CheckpointTensorImpl::make("mean", rt, {self})[0]; +} + +Tensor checkpoint_mean(const Tensor& self, IntArrayRef dim, bool keepdim, c10::optional dtype) { + std::vector dim_ = dim.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::native::mean_cpu_gpu(vec[0], dim_, keepdim, dtype)}; + }; + return CheckpointTensorImpl::make("mean.dim", rt, {self})[0]; +} + +Tensor checkpoint__cat(c10::ArrayRef a, long b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cat(vec, b)}; + }; + std::vector s; + for (const Tensor& t : a) { + s.push_back(t); + } + return CheckpointTensorImpl::make("_cat", rt, s)[0]; +} + +Tensor& checkpoint__cat_out(Tensor& a, c10::ArrayRef b, long c) { + std::vector args; + args.push_back(a); + for (const Tensor& t : b) { + args.push_back(t); + } + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor t = vec[0]; + at::cat_out(t, ArrayRef(vec.data() + 1, vec.size() - 1), c); + }; + CheckpointTensorImpl::mutate("_cat_out", mt, args, {0}); + return a; +} + +Tensor checkpoint_kl_div(at::Tensor const& a, at::Tensor const& b, long c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::kl_div(vec.at(0), vec.at(1), c)}; + }; + return CheckpointTensorImpl::make("kl_div", rt, {a, b})[0]; +} + +Tensor checkpoint_kl_div_backward(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, long d) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::kl_div_backward(vec.at(0), vec.at(1), vec.at(2), d)}; + }; + return CheckpointTensorImpl::make("kl_div_backward", rt, {a, b, c})[0]; +} + +Tensor checkpoint_upsample_bilinear2d(at::Tensor const& self, c10::ArrayRef output_size, bool align_corners, c10::optional scales_h, c10::optional scales_w) { + std::vector output_size_ = output_size.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::upsample_bilinear2d(vec.at(0), output_size_, align_corners, scales_h, scales_w)}; + }; + return CheckpointTensorImpl::make("upsample_bilinear2d", rt, 
{self})[0]; +} + +Tensor& checkpoint_upsample_bilinear2d_out(at::Tensor& out, const at::Tensor& self, c10::ArrayRef output_size, bool align_corners, c10::optional scales_h, c10::optional scales_w) { + std::vector output_size_ = output_size.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out = vec.at(0); + at::upsample_bilinear2d_out(out, vec.at(1), output_size_, align_corners, scales_h, scales_w); + }; + CheckpointTensorImpl::mutate("binary_cross_entropy_out", mt, {out, self}, {0}); + return out; +} + +Tensor& checkpoint_upsample_bilinear2d_backward_out(at::Tensor& grad_input, const at::Tensor& grad_output, c10::ArrayRef output_size, c10::ArrayRef input_size, bool align_corners, c10::optional scales_h, c10::optional scales_w) { + std::vector output_size_ = output_size.vec(); + std::vector input_size_ = input_size.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor grad_input = vec.at(0); + at::upsample_bilinear2d_backward_out(grad_input, vec.at(1), output_size_, input_size_, align_corners, scales_h, scales_w); + }; + CheckpointTensorImpl::mutate("upsample_bilinear2d_backward_out", mt, {grad_input, grad_output}, {0}); + return grad_input; +} + +Tensor checkpoint_upsample_bilinear2d_backward(at::Tensor const& grad_output, c10::ArrayRef output_size, c10::ArrayRef input_size, bool align_corners, c10::optional scales_h, c10::optional scales_w) { + std::vector output_size_ = output_size.vec(); + std::vector input_size_ = input_size.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::upsample_bilinear2d_backward(vec.at(0), output_size_, input_size_, align_corners, scales_h, scales_w)}; + }; + return CheckpointTensorImpl::make("upsample_bilinear2d_backward", rt, {grad_output})[0]; +} + +Tensor& checkpoint_clamp_min_(Tensor& a, Scalar min) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::clamp_min_(self, min); + }; + CheckpointTensorImpl::mutate("clamp_min_", mt, {a}, {0}); + return a; +} + +Tensor& checkpoint_clamp_min__out(Tensor& out, const Tensor& self, Scalar min) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out = vec.at(0); + at::clamp_min_out(out, vec.at(1), min); + }; + CheckpointTensorImpl::mutate("clamp_min__out", mt, {out, self}, {0}); + return out; +} + +Tensor checkpoint_binary_cross_entropy_with_logits(const Tensor& input, const Tensor& target, const Tensor& weight, const Tensor& pos_weight, int64_t reduction) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::binary_cross_entropy_with_logits(vec.at(0), vec.at(1), vec.at(2), vec.at(3), reduction)}; + }; + return CheckpointTensorImpl::make("binary_cross_entropy_with_logits", rt, {input, target, weight, pos_weight})[0]; +} + +Tensor checkpoint_binary_cross_entropy_with_logits_backward(const Tensor& grad, const Tensor& input, const Tensor& target, const Tensor& weight, const Tensor& pos_weight, int64_t reduction) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::binary_cross_entropy_with_logits_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), reduction)}; + }; + return CheckpointTensorImpl::make("binary_cross_entropy_with_logits_backward", rt, {grad, input, target, weight, pos_weight})[0]; +} + +std::tuple checkpoint__fused_dropout(const Tensor & self, double p, Generator* g) { + // TODO: Figure out how to properly duplicate the generator; + // note that the commented-out code below results in a segfault! 
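+  // As written we drop the explicit generator and let at::_fused_dropout fall back to
+  // the default CUDA generator, so a rematerialized dropout mask is not guaranteed to
+  // match the one produced by the original call.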
+ // Ref> gen; + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + // Generator* cur = gen.t ? gen.t.get() : g; + // auto newG = cur->clone(); + // auto res = at::_fused_dropout(vec.at(0), p, cur); + // gen.t = newG; + auto res = at::_fused_dropout(vec.at(0), p); + return {std::get<0>(res), std::get<1>(res)}; + }; + auto res = CheckpointTensorImpl::make("_fused_droupout_", rt, {self}); + return {res[0], res[1]}; +} + +std::tuple checkpoint__thnn_fused_lstm_cell(const Tensor& input_gates, const Tensor& hidden_gates, const Tensor& cx, const Tensor& input_bias, const Tensor& hidden_bias) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto res = at::_thnn_fused_lstm_cell(vec.at(0), vec.at(1), vec.at(2), + vec.at(3), vec.at(4)); + return {std::get<0>(res), std::get<1>(res), std::get<2>(res)}; + }; + auto res = CheckpointTensorImpl::make("_thnn_fused_lstm_cell", rt, + {input_gates, hidden_gates, cx, input_bias, hidden_bias}); + return {res[0], res[1], res[2]}; +} + +std::tuple checkpoint__thnn_fused_lstm_cell_backward(const Tensor& grad_hy, const Tensor& grad_cy, const Tensor& cx, const Tensor& cy, const Tensor& workspace, bool has_bias) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto res = at::_thnn_fused_lstm_cell_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), has_bias); + return {std::get<0>(res), std::get<1>(res), + std::get<2>(res), std::get<3>(res), std::get<4>(res)}; + }; + auto res = CheckpointTensorImpl::make("_thnn_fused_lstm_cell_backward", rt, + {grad_hy, grad_cy, cx, cy, workspace}); + return {res[0], res[1], res[2], res[3], res[4]}; +} + +std::tuple checkpoint__thnn_fused_gru_cell(const Tensor& input_gates, const Tensor& hidden_gates, const Tensor& hx, const Tensor& input_bias, const Tensor& hidden_bias) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto res = at::_thnn_fused_gru_cell(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4)); + return {std::get<0>(res), std::get<1>(res)}; + }; + auto res = CheckpointTensorImpl::make("_thnn_fused_gru_cell", rt, + {input_gates, hidden_gates, hx, input_bias, hidden_bias}); + return {res[0], res[1]}; +} + +std::tuple checkpoint__thnn_fused_gru_cell_backward(const Tensor& grad_hy, const Tensor& workspace, bool has_bias) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto res = at::_thnn_fused_gru_cell_backward(vec.at(0), vec.at(1), has_bias); + return {std::get<0>(res), std::get<1>(res), + std::get<2>(res), std::get<3>(res), std::get<4>(res)}; + }; + auto res = CheckpointTensorImpl::make("_thnn_fused_gru_cell_backward", rt, + {grad_hy, workspace}); + return {res[0], res[1], res[2], res[3], res[4]}; +} + +Tensor& checkpoint_bitwise_and_out(Tensor& self, const Tensor& other, const Tensor& out) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::bitwise_and_out(self, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("bitwise_and_out", mt, {self, other, out}, {0}); + return self; +} + +Tensor& checkpoint_bitwise_and_out(Tensor& self, const Tensor& out, Scalar other) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::bitwise_and_out(self, vec.at(1), other); + }; + CheckpointTensorImpl::mutate("bitwise_and_out", mt, {self, out}, {0}); + return self; +} + +Tensor& checkpoint_fill_(Tensor& self, const Tensor& value) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::fill_(self, 
vec.at(1)); + }; + CheckpointTensorImpl::mutate("fill_", mt, {self, value}, {0}); + return self; +} + +Tensor& checkpoint_fill_(Tensor& self, Scalar value) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::fill_(self, value); + }; + CheckpointTensorImpl::mutate("fill_", mt, {self}, {0}); + return self; +} + +Tensor& checkpoint_masked_select_out(Tensor& self, const Tensor& mask, const Tensor& out) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::masked_select_out(self, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("masked_select_out", mt, {self, mask, out}, {0}); + return self; +} + +Tensor checkpoint_masked_select(const Tensor& self, const Tensor& mask) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::masked_select(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("masked_select", rt, {self, mask})[0]; +} + +Tensor checkpoint_index(const Tensor& self, ArrayRef indices) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto self = vec.at(0); + auto indices = std::vector(vec.begin() + 1, vec.end()); + return {at::index(self, indices)}; + }; + + std::vector s = {self}; + for (const Tensor& t: indices) { + s.push_back(t); + } + return CheckpointTensorImpl::make("index", rt, s)[0]; +} + +Tensor& checkpoint_index_put_(Tensor& self, ArrayRef indices, const Tensor& values, const bool accumulate) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + auto values = vec.at(1); + auto indices = std::vector(vec.begin() + 2, vec.end()); + at::index_put_(self, indices, values, accumulate); + }; + std::vector s = {self, values}; + for (const Tensor& t: indices) { + s.push_back(t); + } + CheckpointTensorImpl::mutate("index_put_", mt, s, {0}); + return self; +} + +Tensor checkpoint_bmm(const Tensor& self, const Tensor& mat2) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::bmm(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("bmm", rt, {self, mat2})[0]; +} + +Tensor checkpoint__softmax(const Tensor& self, long dim, bool half_to_float) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::_softmax(vec.at(0), dim, half_to_float)}; + }; + return CheckpointTensorImpl::make("_softmax", rt, {self})[0]; +} + +Tensor checkpoint__softmax_backward_data(const Tensor& grad_output, const Tensor& output, long dim, const Tensor& self) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::_softmax_backward_data(vec.at(0), vec.at(1), dim, vec.at(2))}; + }; + return CheckpointTensorImpl::make("_softmax_backward_data", rt, {grad_output, output, self})[0]; +} + +std::tuple +checkpoint_layer_norm(const Tensor& input, const Tensor& weight, const Tensor& bias, long M, long N, double eps) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::native_layer_norm(vec.at(0), vec.at(1), vec.at(2), M, N, eps); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("native_layer_norm", rt, {input, weight, bias}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple +checkpoint_layer_norm_backward(const Tensor& grad_out, const Tensor& input, const Tensor& mean, const Tensor& rstd, const Tensor& weight, long M, long N, std::array output_mask) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = 
at::native_layer_norm_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), M, N, output_mask); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("native_layer_norm_backward", rt, {grad_out, input, mean, rstd, weight}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple +checkpoint_topk(const Tensor& self, long k, long dim, bool largest, bool sorted) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::topk(vec.at(0), k, dim, largest, sorted); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("topk", rt, {self}); + return {ret[0], ret[1]}; +} + +std::tuple +checkpoint_topk_values(Tensor& values, Tensor& indices, const Tensor& self, long k, long dim, bool largest, bool sorted) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor values_ = vec.at(0); + Tensor indices_ = vec.at(1); + at::topk_out(values_, indices_, vec.at(2), k, dim, largest, sorted); + }; + CheckpointTensorImpl::mutate("topk_values", mt, {values, indices, self}, {0, 1}); + return {values, indices}; +} + +Tensor& checkpoint_masked_fill_(Tensor& self, const Tensor& mask, Scalar value) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self_ = vec.at(0); + self_.masked_fill_(vec.at(1), value); + }; + CheckpointTensorImpl::mutate("masked_fill_Scalar", mt, {self, mask}, {0}); + return {self}; +} + +Tensor& checkpoint_masked_fill_(Tensor& self, const Tensor& mask, const Tensor& value) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self_ = vec.at(0); + self_.masked_fill_(vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("masked_fill_Tensor", mt, {self, mask, value}, {0}); + return {self}; +} + +Tensor checkpoint_clamp(const Tensor& self, c10::optional min, c10::optional max) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::clamp(vec.at(0), min, max)}; + }; + return CheckpointTensorImpl::make("clamp", rt, {self})[0]; +} + +Tensor& checkpoint_clamp_(Tensor& self, c10::optional min, c10::optional max) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self_ = vec.at(0); + at::clamp_(self_, min, max); + }; + CheckpointTensorImpl::mutate("clamp_", mt, {self}, {0}); + return {self}; +} + +Tensor& checkpoint_clamp_out(Tensor& out, const Tensor& self, c10::optional min, c10::optional max) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::clamp_out(out_, vec.at(1), min, max); + }; + CheckpointTensorImpl::mutate("clamp_out", mt, {out, self}, {0}); + return {out}; +} + +std::tuple checkpoint_thnn_conv2d_forward_out(Tensor& output, Tensor& finput, Tensor& fgrad_input, const Tensor& self, const Tensor& weight, c10::ArrayRef kernel_size, const Tensor& bias, c10::ArrayRef stride, c10::ArrayRef padding) { + auto kernel_size_ = kernel_size.vec(); + auto stride_ = stride.vec(); + auto padding_ = padding.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor output_ = vec.at(0); + Tensor finput_ = vec.at(1); + Tensor fgrad_input_ = vec.at(2); + at::thnn_conv2d_forward_out(output_, finput_, fgrad_input_, vec.at(3), vec.at(4), kernel_size_, vec.at(5), stride_, padding_); + }; + CheckpointTensorImpl::mutate("thnn_conv2d_forward_out", mt, {output, finput, fgrad_input, self, weight, bias}, {0, 1, 2}); + return {output, finput, fgrad_input}; +} + +std::tuple checkpoint_thnn_conv2d_forward(const Tensor& self, const Tensor& weight, c10::ArrayRef kernel_size, 
const Tensor& bias, c10::ArrayRef stride, c10::ArrayRef padding) { + auto kernel_size_ = kernel_size.vec(); + auto stride_ = stride.vec(); + auto padding_ = padding.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::thnn_conv2d_forward(vec.at(0), vec.at(1), kernel_size_, vec.at(2), stride_, padding_); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("thnn_conv2d_forward", rt, {self, weight, bias}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple checkpoint_thnn_conv2d_backward_out(Tensor& grad_input, Tensor& grad_weight, Tensor& grad_bias, const Tensor& grad_output, const Tensor& self, const Tensor& weight, c10::ArrayRef kernel_size, c10::ArrayRef stride, c10::ArrayRef padding, const Tensor& finput, const Tensor& fgrad_input) { + auto kernel_size_ = kernel_size.vec(); + auto stride_ = stride.vec(); + auto padding_ = padding.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor grad_input_ = vec.at(0); + Tensor grad_weight_ = vec.at(1); + Tensor grad_bias_ = vec.at(2); + at::thnn_conv2d_backward_out(grad_input_, grad_weight_, grad_bias_, vec.at(3), vec.at(4), vec.at(5), kernel_size_, stride_, padding_, vec.at(6), vec.at(7)); + }; + CheckpointTensorImpl::mutate("thnn_conv2d_backward_out", mt, {grad_input, grad_weight, grad_bias, grad_output, self, weight, finput, fgrad_input}, {0, 1, 2}); + return {grad_input, grad_weight, grad_bias}; +} + +std::tuple checkpoint_thnn_conv2d_backward(const Tensor& grad_output, const Tensor& self, const Tensor& weight, c10::ArrayRef kernel_size, c10::ArrayRef stride, c10::ArrayRef padding, const Tensor& finput, const Tensor& fgrad_input, std::array output_mask) { + auto kernel_size_ = kernel_size.vec(); + auto stride_ = stride.vec(); + auto padding_ = padding.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::thnn_conv2d_backward(vec.at(0), vec.at(1), vec.at(2), kernel_size_, stride_, padding_, vec.at(3), vec.at(4), output_mask); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("thnn_conv2d_backward", rt, {grad_output, self, weight, finput, fgrad_input}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple checkpoint_native_batch_norm(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double momentum, double eps) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::native_batch_norm(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), training, momentum, eps); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("native_batch_norm", rt, {input, weight, bias, running_mean, running_var}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple checkpoint_native_batch_norm_out(Tensor& out, Tensor& save_mean, Tensor& save_invstd, const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double momentum, double eps) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + Tensor save_mean_ = vec.at(1); + Tensor save_invstd_ = vec.at(2); + at::native_batch_norm_out(out_, save_mean_, save_invstd_, vec.at(3), vec.at(4), vec.at(5), vec.at(6), vec.at(7), training, momentum, eps); + }; + CheckpointTensorImpl::mutate("native_batch_norm_out", mt, {out, save_mean, save_invstd, 
input, weight, bias, running_mean, running_var}, {0, 1, 2}); + return {out, save_mean, save_invstd}; +} + +std::tuple checkpoint_native_batch_norm_backward(const Tensor& grad_out, const Tensor& input, const Tensor& weight, const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd, bool train, double eps, std::array output_mask) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::native_batch_norm_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), vec.at(5), vec.at(6), train, eps, output_mask); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("native_batch_norm_backward", rt, {grad_out, input, weight, running_mean, running_var, save_mean, save_invstd}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple checkpoint__cudnn_ctc_loss(const Tensor& log_probs, const Tensor& targets, ArrayRef input_lengths, ArrayRef target_lengths, long blank, bool deterministic, bool zero_infinity) { + auto input_lengths_ = input_lengths.vec(); + auto target_lengths_ = target_lengths.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::_cudnn_ctc_loss(vec.at(0), vec.at(1), input_lengths_, target_lengths_, blank, deterministic, zero_infinity); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("_cudnn_ctc_loss", rt, {log_probs, targets}); + return {ret[0], ret[1]}; +} + +std::tuple checkpoint__ctc_loss(const Tensor& log_probs, const Tensor& targets, ArrayRef input_lengths, ArrayRef target_lengths, long blank, bool zero_infinity) { + auto input_lengths_ = input_lengths.vec(); + auto target_lengths_ = target_lengths.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::_ctc_loss(vec.at(0), vec.at(1), input_lengths_, target_lengths_, blank, zero_infinity); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("_ctc_loss", rt, {log_probs, targets}); + return {ret[0], ret[1]}; +} + +Tensor checkpoint__ctc_loss_backward(const Tensor& grad, const Tensor& log_probs, const Tensor& targets, ArrayRef input_lengths, ArrayRef target_lengths, const Tensor& neg_log_likelihood, const Tensor& log_alpha, long blank, bool zero_infinity) { + auto input_lengths_ = input_lengths.vec(); + auto target_lengths_ = target_lengths.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::_ctc_loss_backward(vec.at(0), vec.at(1), vec.at(2), input_lengths_, target_lengths_, vec.at(3), vec.at(4), blank, zero_infinity)}; + }; + return CheckpointTensorImpl::make("_ctc_loss_backward", rt, {grad, log_probs, targets, neg_log_likelihood, log_alpha})[0]; +} + +Tensor& checkpoint_hardtanh_backward_out(Tensor& grad_input, const Tensor& grad_output, const Tensor& self, Scalar min_val, Scalar max_val) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor grad_input_ = vec.at(0); + at::hardtanh_backward_out(grad_input_, vec.at(1), vec.at(2), min_val, max_val); + }; + CheckpointTensorImpl::mutate("hardtanh_backward_out", mt, {grad_input, grad_output, self}, {0}); + return {grad_input}; +} + +Tensor checkpoint_hardtanh_backward(const Tensor& grad_output, const Tensor& self, Scalar min_val, Scalar max_val) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::hardtanh_backward(vec.at(0), vec.at(1), min_val, max_val)}; + }; + return 
CheckpointTensorImpl::make("hardtanh_backward", rt, {grad_output, self})[0]; +} + +Tensor checkpoint_nonzero(const Tensor& self) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::nonzero(vec.at(0))}; + }; + return CheckpointTensorImpl::make("nonzero", rt, {self})[0]; +} + +Tensor& checkpoint_nonzero_out(Tensor& out, const Tensor& self) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::nonzero_out(out_, vec.at(1)); + }; + CheckpointTensorImpl::mutate("nonzero_out", mt, {out, self}, {0}); + return {out}; +} + +Tensor checkpoint_lt(const Tensor& self, Scalar other) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::lt(vec.at(0), other)}; + }; + return CheckpointTensorImpl::make("lt_Scalar", rt, {self})[0]; +} + +Tensor& checkpoint_lt_out(Tensor& out, const Tensor& self, Scalar other) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::lt_out(out_, vec.at(1), other); + }; + CheckpointTensorImpl::mutate("lt_Scalar_out", mt, {out, self}, {0}); + return {out}; +} + +Tensor checkpoint_lt(const Tensor& self, const Tensor& other) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::lt(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("lt_Tensor", rt, {self, other})[0]; +} + +Tensor& checkpoint_lt_out(Tensor& out, const Tensor& self, const Tensor& other) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::lt_out(out_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("lt_Tensor_out", mt, {out, self, other}, {0}); + return {out}; +} + +Tensor checkpoint_any(const Tensor& self) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::any(vec.at(0))}; + }; + return CheckpointTensorImpl::make("any", rt, {self})[0]; +} + +bool checkpoint__use_cudnn_ctc_loss(const Tensor& log_probs, const Tensor& targets, ArrayRef input_lengths, ArrayRef target_lengths, long blank) { + return at::_use_cudnn_ctc_loss(decheckpoint(log_probs), decheckpoint(targets), input_lengths, target_lengths, blank); +} + +bool checkpoint_equal(const Tensor& self, const Tensor& other) { + // there can't possibly be a reason to rematerialize + // a single bool so we'll just compute it now + return at::equal(decheckpoint(self), decheckpoint(other)); +} + +Scalar checkpoint__local_scalar_dense(at::Tensor const& a) { + return at::_local_scalar_dense(decheckpoint(a)); +} + +Tensor checkpoint_split_with_sizes_backward(c10::ArrayRef a, c10::ArrayRef b, long c, c10::ArrayRef d, c10::TensorOptions const& e) { + std::vector a_ = a.vec(); + std::vector d_ = d.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::split_with_sizes_backward(vec, b, c, d_, e)}; + }; + return CheckpointTensorImpl::make("split_with_sizes_backward", rt, a_)[0]; +} + +std::vector checkpoint_split_with_sizes(at::Tensor const& a, c10::ArrayRef b, long c) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return at::split_with_sizes(vec.at(0), b_, c); + }; + return CheckpointTensorImpl::make("split_with_sizes", rt, {a}); +} + +Tensor checkpoint_split_backward(c10::ArrayRef a, long b, long c, c10::ArrayRef d, const c10::TensorOptions& e) { + std::vector a_ = a.vec(); + std::vector d_ = d.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::split_backward(vec, b, c, d_, e)}; + }; + 
return CheckpointTensorImpl::make("split_backward", rt, a_)[0]; +} + +std::vector checkpoint_split(const at::Tensor& a, long b, long c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return at::split(vec.at(0), b, c); + }; + return CheckpointTensorImpl::make("split", rt, {a}); +} + +Tensor checkpoint_expand(at::Tensor const& a, c10::ArrayRef b, bool c) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {vec.at(0).expand(b_, c)}; + }; + return CheckpointTensorImpl::make("expand", rt, {a})[0]; +} + +Tensor checkpoint_diag(at::Tensor const& self, long diagonal) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::diag(vec.at(0), diagonal)}; + }; + return CheckpointTensorImpl::make("diag", rt, {self})[0]; +} + +Tensor& checkpoint_diag_out(at::Tensor& out, const Tensor& self, long diagonal) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::diag_out(out_, vec.at(1), diagonal); + }; + CheckpointTensorImpl::mutate("diag_out", mt, {out, self}, {0}); + return {out}; +} + +Tensor checkpoint_mv(at::Tensor const& self, at::Tensor const& vec) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::mv(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("mv", rt, {self, vec})[0]; +} + +Tensor& checkpoint_mv_out(at::Tensor& out, const Tensor& self, const Tensor& vec) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::mv_out(out_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("mv_out", mt, {out, self, vec}, {0}); + return {out}; +} + +}} \ No newline at end of file diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 14e3a877634..49102d33233 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -888,52 +888,6 @@ Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_ return result; } -std::vector split(const Tensor& self, int64_t split_size, int64_t dim) { - TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor"); - TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size); - int64_t dim_size = self.size(dim); - TORCH_CHECK(split_size > 0 || self.size(dim) == 0, - "split_size can only be 0 if dimension size is 0, " - "but got dimension size of ", dim_size); - // if split_size is 0 and dimension size is 0, there is 1 split. - int64_t num_splits = 1; - if (split_size != 0) { - // ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size - // (returns a single split). We might want to error here, but keep it for BC. - num_splits = std::max((dim_size + split_size - 1) / split_size, 1); - } - std::vector splits(num_splits); - int64_t last_split_size = split_size - (split_size * num_splits - dim_size); - - for (int64_t i = 0; i < num_splits; ++i) { - auto length = i < num_splits - 1 ? 
split_size : last_split_size; - splits[i] = self.narrow(dim, i * split_size, length); - } - return splits; -} - -std::vector split_with_sizes(const Tensor& self, IntArrayRef split_sizes, int64_t dim) { - TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor"); - int64_t dim_size = self.size(dim); - int64_t num_splits = split_sizes.size(); - std::vector splits(num_splits); - int64_t start_idx = 0; - int64_t i; - - for (i = 0; i < num_splits; ++i) { - auto length = split_sizes[i]; - TORCH_CHECK(length >= 0, - "split_with_sizes expects split_sizes have only non-negative ", - "entries, but got split_sizes=", split_sizes); - splits[i] = self.narrow(dim, start_idx, length); - start_idx += length; - } - TORCH_CHECK(start_idx == dim_size, - "split_with_sizes expects split_sizes to sum exactly to ", dim_size, - " (input tensor's size at dimension ", dim, "), ", "but got split_sizes=", split_sizes); - return splits; -} - // Precondition: tensors is non-empty static inline std::vector get_stack_inputs(TensorList tensors, int64_t dim) { std::vector inputs(tensors.size()); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a9e02ea8ec9..59459cafef5 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1,6 +1,53 @@ # See README.md in this directory for more guidance +- func: checkpoint(Tensor self) -> Tensor + variants: method + +# convert a tensor to a checkpoint tensor. +# if the input is already a checkpoint tensor, +# checkpoint will fail while try_checkpoint will +# simply return the input. +- func: try_checkpoint(Tensor self) -> Tensor + variants: method + +- func: is_checkpoint(Tensor self) -> bool + variants: method + +# convert checkpointed tensor into normal tensor. +# uncheckpoint assume the input is checkpointed and will fail otherwise. +# decheckpoint return the input if it is not checkpointed. +- func: uncheckpoint(Tensor self) -> Tensor + variants: method + +- func: decheckpoint(Tensor self) -> Tensor + variants: method + +- func: pin(Tensor(a!) self) -> () + variants: method + +- func: new_log(str logname) -> () + variants: function + +- func: annotate_log(str logname) -> () + variants: function + +- func: toggle_log(bool use_log) -> () + variants: function + +- func: clear_checkpointpool() -> () + variants: function + +- func: set_memory_budget(int budget) -> () + variants: function + +- func: unset_memory_budget() -> () + variants: function +- func: reset_compute_time() -> () + variants: function + +- func: compute_time() -> int + variants: function # Temporary type cast operators. These are needed to trace type-casts now since # Type's are not supported in the IR. Instead, we call down to these # specialized operators for each datatype. 
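The declarations above only add schema. As a rough illustration of how they are meant to be exercised once the usual method and function variants are generated, a minimal C++ sketch might look like the following (shapes and the budget value are illustrative, and `at::kCUDA` assumes a CUDA build of this branch):

```cpp
#include <ATen/ATen.h>

int main() {
  // Cap checkpointed memory; evictions happen automatically once the budget is exceeded.
  at::set_memory_budget(4LL * 1024 * 1024 * 1024);
  at::Tensor x = at::ones({1024, 1024}, at::kCUDA).checkpoint();
  at::Tensor y = at::relu(x);            // dispatched through the Checkpoint key
  TORCH_CHECK(y.is_checkpoint());
  at::Tensor y_plain = y.decheckpoint(); // materialize back into an ordinary CUDA tensor
  at::unset_memory_budget();
  return 0;
}
```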
@@ -135,10 +182,12 @@ - func: _use_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank) -> bool dispatch: CUDA: _use_cudnn_ctc_loss + Checkpoint: checkpoint__use_cudnn_ctc_loss - func: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) dispatch: CUDA: _cudnn_ctc_loss + Checkpoint: checkpoint__cudnn_ctc_loss - func: _cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, int input_size, int mode, int hidden_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor dispatch: @@ -164,6 +213,7 @@ variants: function dispatch: CUDA: fused_dropout_cuda + Checkpoint: checkpoint__fused_dropout supports_named_tensor: True - func: _masked_scale(Tensor self, Tensor mask, float scale) -> Tensor @@ -171,6 +221,7 @@ variants: function dispatch: CUDA: masked_scale_cuda + Checkpoint: checkpoint__masked_scale - func: _sobol_engine_draw(Tensor quasi, int n, Tensor sobolstate, int dimension, int num_generated, ScalarType? dtype) -> (Tensor, Tensor) @@ -217,6 +268,10 @@ use_c10_dispatcher: full variants: function, method supports_named_tensor: True + dispatch: + CUDA: abs + CPU: abs + Checkpoint: checkpoint_abs - func: abs_(Tensor(a!) self) -> Tensor(a!) variants: function, method @@ -285,6 +340,7 @@ SparseCPU: add_sparse SparseCUDA: add_sparse MkldnnCPU: mkldnn_add + Checkpoint: checkpoint_add supports_named_tensor: True - func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) @@ -295,6 +351,7 @@ SparseCPU: add_sparse_ SparseCUDA: add_sparse_ MkldnnCPU: mkldnn_add_ + Checkpoint: checkpoint_add_ supports_named_tensor: True - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) @@ -311,6 +368,10 @@ use_c10_dispatcher: full variants: function, method supports_named_tensor: True + dispatch: + CPU: add + CUDA: add + Checkpoint: checkpoint_add - func: add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) variants: method @@ -417,6 +478,7 @@ CPU: as_strided_tensorimpl CUDA: as_strided_tensorimpl QuantizedCPU: as_strided_qtensorimpl + Checkpoint: checkpoint_as_strided device_guard: False supports_named_tensor: True @@ -528,6 +590,7 @@ dispatch: CPU: binary_cross_entropy_cpu CUDA: binary_cross_entropy_cuda + Checkpoint: checkpoint_binary_cross_entropy - func: binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -535,6 +598,7 @@ dispatch: CPU: binary_cross_entropy_out_cpu CUDA: binary_cross_entropy_out_cuda + Checkpoint: checkpoint_binary_cross_entropy_out - func: binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor python_module: nn @@ -542,6 +606,7 @@ dispatch: CPU: binary_cross_entropy_backward_cpu CUDA: binary_cross_entropy_backward_cuda + Checkpoint: checkpoint_binary_cross_entropy_backward - func: binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn @@ -549,12 +614,21 @@ dispatch: CPU: binary_cross_entropy_backward_out_cpu CUDA: binary_cross_entropy_backward_out_cuda + Checkpoint: checkpoint_binary_cross_entropy_backward_out - func: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? 
pos_weight=None, int reduction=Mean) -> Tensor variants: function + dispatch: + CPU: binary_cross_entropy_with_logits + CUDA: binary_cross_entropy_with_logits + Checkpoint: checkpoint_binary_cross_entropy_with_logits - func: binary_cross_entropy_with_logits_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor variants: function + dispatch: + CPU: binary_cross_entropy_with_logits_backward + CUDA: binary_cross_entropy_with_logits_backward + Checkpoint: checkpoint_binary_cross_entropy_with_logits_backward - func: bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor variants: function, method @@ -643,6 +717,7 @@ dispatch: CPU: bmm_cpu CUDA: bmm_cuda + Checkpoint: checkpoint_bmm supports_named_tensor: True - func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) @@ -698,6 +773,7 @@ CPU: clamp CUDA: clamp QuantizedCPU: quantized_clamp + Checkpoint: checkpoint_clamp - func: clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) supports_named_tensor: True @@ -705,12 +781,14 @@ dispatch: CPU: _clamp__cpu CUDA: _clamp__cuda + Checkpoint: checkpoint_clamp_ - func: clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True dispatch: CPU: _clamp_out_cpu CUDA: _clamp_out_cuda + Checkpoint: checkpoint_clamp_out - func: clamp_max(Tensor self, Scalar max) -> Tensor use_c10_dispatcher: full @@ -741,12 +819,14 @@ dispatch: CPU: _clamp_min__cpu CUDA: _clamp_min__cuda + Checkpoint: checkpoint_clamp_min_ - func: clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True dispatch: CPU: _clamp_min_out_cpu CUDA: _clamp_min_out_cuda + Checkpoint: checkpoint_clamp_min__out - func: cudnn_is_acceptable(Tensor self) -> bool use_c10_dispatcher: full @@ -754,6 +834,10 @@ - func: constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> Tensor variants: function + dispatch: + CPU: constant_pad_nd + CUDA: constant_pad_nd + Checkpoint: checkpoint_constant_pad_nd - func: contiguous(Tensor self, *, MemoryFormat memory_format=contiguous_format) -> Tensor variants: method @@ -856,11 +940,13 @@ - func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor) dispatch: CUDA: cudnn_batch_norm + Checkpoint: checkpoint_cudnn_batch_norm # NB: You can only use this if you used cudnn_batch_norm training=True - func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor) dispatch: CUDA: cudnn_batch_norm_backward + Checkpoint: checkpoint_cudnn_batch_norm_backward - func: cudnn_convolution.deprecated(Tensor self, Tensor weight, Tensor? 
bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: @@ -869,18 +955,22 @@ - func: cudnn_convolution(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: CUDA: cudnn_convolution + Checkpoint: checkpoint_cudnn_convolution - func: cudnn_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: CUDA: cudnn_convolution_backward_input + Checkpoint: checkpoint_cudnn_convolution_backward_input - func: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[2] output_mask) -> (Tensor, Tensor) dispatch: CUDA: cudnn_convolution_backward + Checkpoint: checkpoint_cudnn_convolution_backward - func: cudnn_convolution_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: CUDA: cudnn_convolution_backward_weight + Checkpoint: checkpoint_cudnn_convolution_backward_weight - func: cudnn_convolution_transpose.deprecated(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: @@ -889,20 +979,24 @@ - func: cudnn_convolution_transpose(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: CUDA: cudnn_convolution_transpose + Checkpoint: checkpoint_cudnn_convolution_transpose # NB: output_padding not strictly needed here, but it's helpful for the float # backwards - func: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[2] output_mask) -> (Tensor, Tensor) dispatch: CUDA: cudnn_convolution_transpose_backward + Checkpoint: checkpoint_cudnn_convolution_transpose_backward - func: cudnn_convolution_transpose_backward_input(Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: CUDA: cudnn_convolution_transpose_backward_input + Checkpoint: checkpoint_cudnn_convolution_transpose_backward_input - func: cudnn_convolution_transpose_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: CUDA: cudnn_convolution_transpose_backward_weight + Checkpoint: checkpoint_cudnn_convolution_transpose_backward_weight # NB: input is special cased in a way I don't quite understand - func: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output @@ -992,11 +1086,13 @@ dispatch: CPU: ctc_loss_cpu CUDA: ctc_loss_gpu + Checkpoint: checkpoint__ctc_loss - func: _ctc_loss_backward(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor dispatch: CPU: ctc_loss_backward_cpu CUDA: ctc_loss_backward_gpu + Checkpoint: checkpoint__ctc_loss_backward - func: det(Tensor self) -> Tensor use_c10_dispatcher: full @@ 
-1029,6 +1125,7 @@ CUDA: div SparseCPU: div_sparse SparseCUDA: div_sparse + Checkpoint: checkpoint_div supports_named_tensor: True - func: div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) @@ -1038,6 +1135,7 @@ CUDA: div_ SparseCPU: div_sparse_ SparseCUDA: div_sparse_ + Checkpoint: checkpoint_div_ supports_named_tensor: True - func: div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) @@ -1073,9 +1171,17 @@ - func: embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor use_c10_dispatcher: full + dispatch: + CPU: embedding + CUDA: embedding + Checkpoint: checkpoint_embedding - func: embedding_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor use_c10_dispatcher: full + dispatch: + CPU: embedding_backward + CUDA: embedding_backward + Checkpoint: checkpoint_embedding_backward - func: embedding_dense_backward(Tensor grad_output, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor use_c10_dispatcher: full @@ -1247,6 +1353,10 @@ variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. device_guard: False supports_named_tensor: True + dispatch: + CPU: expand + CUDA: expand + Checkpoint: checkpoint_expand - func: expand_as(Tensor self, Tensor other) -> Tensor use_c10_dispatcher: full @@ -1287,10 +1397,18 @@ - func: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) supports_named_tensor: True variants: function, method + dispatch: + CPU: fill_ + CUDA: fill_ + Checkpoint: checkpoint_fill_ - func: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) supports_named_tensor: True variants: function, method + dispatch: + CPU: fill_ + CUDA: fill_ + Checkpoint: checkpoint_fill_ - func: floor(Tensor self) -> Tensor use_c10_dispatcher: full @@ -1435,6 +1553,10 @@ - func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor variants: function, method + dispatch: + CPU: index + CUDA: index + Checkpoint: checkpoint_index # NB: This function is special-cased in tools/autograd/gen_variable_type.py # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp: # - Tensor Tensor::index(ArrayRef indices) @@ -1455,6 +1577,10 @@ - func: index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!) 
variants: function, method + dispatch: + CPU: index_put_ + CUDA: index_put_ + Checkpoint: checkpoint_index_put_ # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp: # - Tensor & Tensor::index_put_(ArrayRef indices, Tensor const & rhs) # - Tensor & Tensor::index_put_(ArrayRef indices, Scalar v) @@ -1535,12 +1661,17 @@ - func: kl_div(Tensor self, Tensor target, int reduction=Mean) -> Tensor use_c10_dispatcher: full + dispatch: + CPU: kl_div + CUDA: kl_div + Checkpoint: checkpoint_kl_div - func: kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int reduction=Mean) -> Tensor use_c10_dispatcher: full dispatch: CPU: kl_div_backward_cpu CUDA: kl_div_backward_cuda + Checkpoint: checkpoint_kl_div_backward - func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) supports_named_tensor: True @@ -1565,11 +1696,13 @@ dispatch: CPU: layer_norm_cpu CUDA: layer_norm_cuda + Checkpoint: checkpoint_layer_norm - func: native_layer_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, int M, int N, bool[3] output_mask) -> (Tensor, Tensor, Tensor) dispatch: CPU: layer_norm_backward_cpu CUDA: layer_norm_backward_cuda + Checkpoint: checkpoint_layer_norm_backward - func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor python_module: nn @@ -1615,6 +1748,10 @@ use_c10_dispatcher: full supports_named_tensor: True variants: function, method + dispatch: + CPU: log + CUDA: log + Checkpoint: checkpoint_log - func: log_(Tensor(a!) self) -> Tensor(a!) supports_named_tensor: True @@ -1625,6 +1762,7 @@ dispatch: CPU: log_out CUDA: log_out + Checkpoint: checkpoint_log_out - func: log10(Tensor self) -> Tensor use_c10_dispatcher: full @@ -1703,12 +1841,14 @@ dispatch: CPU: log_softmax_cpu CUDA: log_softmax_cuda + Checkpoint: checkpoint__log_softmax - func: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor use_c10_dispatcher: full dispatch: CPU: log_softmax_backward_cpu CUDA: log_softmax_backward_cuda + Checkpoint: checkpoint__log_softmax_backward_data - func: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor supports_named_tensor: True @@ -1797,6 +1937,7 @@ CPU: mean_cpu_gpu CUDA: mean_cpu_gpu QuantizedCPU: quantized_mean_cpu + Checkpoint: checkpoint_mean - func: mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method @@ -1805,6 +1946,7 @@ CPU: mean_cpu_gpu CUDA: mean_cpu_gpu QuantizedCPU: quantized_mean_cpu + Checkpoint: checkpoint_mean - func: mean.out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True @@ -1941,6 +2083,7 @@ CUDA: legacy::cuda::_th_mm SparseCPU: _sparse_mm SparseCUDA: _sparse_mm + Checkpoint: checkpoint_mm supports_named_tensor: True - func: mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) @@ -1949,6 +2092,7 @@ CUDA: legacy::cuda::_th_mm_out SparseCPU: _sparse_mm_out SparseCUDA: _sparse_mm_out + Checkpoint: checkpoint_mm_out supports_named_tensor: True - func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor @@ -1977,6 +2121,7 @@ SparseCPU: mul_sparse SparseCUDA: mul_sparse MkldnnCPU: mkldnn_mul + Checkpoint: checkpoint_mul supports_named_tensor: True - func: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
@@ -1987,6 +2132,7 @@ SparseCPU: mul_sparse_ SparseCUDA: mul_sparse_ MkldnnCPU: mkldnn_mul_ + Checkpoint: checkpoint_mul_ supports_named_tensor: True - func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) @@ -2002,9 +2148,20 @@ - func: mul.Scalar(Tensor self, Scalar other) -> Tensor use_c10_dispatcher: full variants: function, method + dispatch: + CPU: mul + CUDA: mul + SparseCPU: mul + SparseCUDA: mul + MkldnnCPU: mul + Checkpoint: checkpoint_mul - func: mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) variants: method + dispatch: + CPU: mul_ + CUDA: mul_ + Checkpoint: checkpoint_mul_ - func: mv(Tensor self, Tensor vec) -> Tensor use_c10_dispatcher: full @@ -2012,12 +2169,14 @@ dispatch: CPU: mv_cpu CUDA: legacy::cuda::_th_mv + Checkpoint: checkpoint_mv supports_named_tensor: True - func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: mv_cpu_out CUDA: legacy::cuda::_th_mv_out + Checkpoint: checkpoint_mv_out supports_named_tensor: True - func: mvlgamma(Tensor self, int p) -> Tensor @@ -2046,10 +2205,12 @@ CPU: batch_norm_cpu CUDA: batch_norm_cuda MkldnnCPU: mkldnn_batch_norm + Checkpoint: checkpoint_native_batch_norm - func: native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) dispatch: CUDA: batch_norm_cuda_out + Checkpoint: checkpoint_native_batch_norm_out - func: batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor) dispatch: @@ -2076,6 +2237,7 @@ dispatch: CPU: batch_norm_backward_cpu CUDA: batch_norm_backward_cuda + Checkpoint: checkpoint_native_batch_norm_backward - func: batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor) dispatch: @@ -2114,6 +2276,10 @@ - func: ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor supports_named_tensor: True + dispatch: + CPU: ones_like + CUDA: ones_like + Checkpoint: checkpoint_ones_like - func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor use_c10_dispatcher: full @@ -2266,6 +2432,10 @@ use_c10_dispatcher: full supports_named_tensor: True variants: function, method + dispatch: + CPU: neg + CUDA: neg + Checkpoint: checkpoint_neg - func: neg_(Tensor(a!) self) -> Tensor(a!) supports_named_tensor: True @@ -2279,6 +2449,10 @@ - func: repeat(Tensor self, int[] repeats) -> Tensor variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. + dispatch: + CPU: repeat + CUDA: repeat + Checkpoint: checkpoint_repeat - func: repeat_interleave.Tensor(Tensor repeats) -> Tensor use_c10_dispatcher: full @@ -2338,6 +2512,7 @@ CUDA: relu MkldnnCPU: mkldnn_relu QuantizedCPU: quantized_relu + Checkpoint: checkpoint_relu supports_named_tensor: True - func: relu_(Tensor(a!) self) -> Tensor(a!) 
@@ -2348,6 +2523,7 @@ CUDA: relu_ MkldnnCPU: mkldnn_relu_ QuantizedCPU: quantized_relu_ + Checkpoint: checkpoint_relu_ - func: prelu(Tensor self, Tensor weight) -> Tensor use_c10_dispatcher: full @@ -2408,6 +2584,18 @@ variants: function, method device_guard: False supports_named_tensor: True + dispatch: + CPU: select + CUDA: select + Checkpoint: checkpoint_select + +- func: select_backward(Tensor grad, int[] sizes, int dim, int index) -> Tensor + variants: function + device_guard: False + dispatch: + CPU: select_backward + CUDA: select_backward + Checkpoint: checkpoint_select_backward - func: selu(Tensor self) -> Tensor use_c10_dispatcher: full @@ -2429,6 +2617,7 @@ CUDA: sigmoid QuantizedCPU: quantized_sigmoid MkldnnCPU: mkldnn_sigmoid + Checkpoint: checkpoint_sigmoid - func: sigmoid_(Tensor(a!) self) -> Tensor(a!) supports_named_tensor: True @@ -2437,6 +2626,7 @@ CPU: sigmoid_ CUDA: sigmoid_ MkldnnCPU: mkldnn_sigmoid_ + Checkpoint: checkpoint_sigmoid_ - func: sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True @@ -2508,6 +2698,18 @@ variants: function, method device_guard: False supports_named_tensor: True + dispatch: + CPU: slice + CUDA: slice + Checkpoint: checkpoint_slice + +- func: slice_backward(Tensor grad, int[] input_sizes, int dim, int start, int end, int step) -> Tensor + variants: function, method + device_guard: False + dispatch: + CPU: slice_backward + CUDA: slice_backward + Checkpoint: checkpoint_slice_backward - func: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) variants: function, method @@ -2531,22 +2733,46 @@ CPU: softmax_cpu CUDA: softmax_cuda MkldnnCPU: mkldnn_softmax + Checkpoint: checkpoint__softmax - func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor use_c10_dispatcher: full dispatch: CPU: softmax_backward_cpu CUDA: softmax_backward_cuda + Checkpoint: checkpoint__softmax_backward_data - func: split.Tensor(Tensor(a) self, int split_size, int dim=0) -> Tensor(a)[] variants: function, method device_guard: False supports_named_tensor: True + dispatch: + CPU: split + CUDA: split + Checkpoint: checkpoint_split + +- func: split_backward(Tensor[] grads, int split_size, int dim, int[] sizes, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor + variants: function + dispatch: + CPU: split_backward + CUDA: split_backward + Checkpoint: checkpoint_split_backward - func: split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[] variants: function, method device_guard: False supports_named_tensor: True + dispatch: + CPU: split_with_sizes + CUDA: split_with_sizes + Checkpoint: checkpoint_split_with_sizes + +- func: split_with_sizes_backward(Tensor[] grads, int[] split_sizes, int dim, int[] sizes, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor + variants: function + dispatch: + CPU: split_with_sizes_backward + CUDA: split_with_sizes_backward + Checkpoint: checkpoint_split_with_sizes_backward - func: squeeze(Tensor(a) self) -> Tensor(a) supports_named_tensor: True @@ -2566,14 +2792,26 @@ - func: squeeze_(Tensor(a!) self) -> Tensor(a!) variants: method device_guard: False + dispatch: + CPU: squeeze_ + CUDA: squeeze_ + Checkpoint: checkpoint_squeeze_ - func: squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!) variants: method device_guard: False + dispatch: + CPU: squeeze_ + CUDA: squeeze_ + Checkpoint: checkpoint_squeeze_ - func: squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!) 
variants: method device_guard: False + dispatch: + CPU: squeeze_ + CUDA: squeeze_ + Checkpoint: checkpoint_squeeze_ - func: sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor use_c10_dispatcher: full @@ -2611,10 +2849,18 @@ - func: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor variants: function, method supports_named_tensor: True + dispatch: + CPU: sum + CUDA: sum + Checkpoint: checkpoint_sum - func: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method supports_named_tensor: True + dispatch: + CPU: sum + CUDA: sum + Checkpoint: checkpoint_sum_dim_IntList - func: sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method @@ -2634,6 +2880,10 @@ use_c10_dispatcher: full supports_named_tensor: True variants: function, method + dispatch: + CPU: sqrt + CUDA: sqrt + Checkpoint: checkpoint_sqrt - func: sqrt_(Tensor(a!) self) -> Tensor(a!) supports_named_tensor: True @@ -2705,6 +2955,10 @@ device_guard: False variants: function, method supports_named_tensor: True + dispatch: + CPU: t + CUDA: t + Checkpoint: checkpoint_t - func: t_(Tensor(a!) self) -> Tensor(a!) device_guard: False @@ -2736,6 +2990,7 @@ CPU: tanh CUDA: tanh QuantizedCPU: quantized_tanh + Checkpoint: checkpoint_tanh - func: tanh_(Tensor(a!) self) -> Tensor(a!) supports_named_tensor: True @@ -2761,6 +3016,7 @@ dispatch: CPU: threshold CUDA: threshold_cuda + Checkpoint: checkpoint_threshold - func: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!) variants: function @@ -2768,12 +3024,14 @@ dispatch: CPU: threshold_ CUDA: threshold__cuda + Checkpoint: checkpoint_threshold_ - func: threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True dispatch: CPU: threshold_out CUDA: threshold_out_cuda + Checkpoint: checkpoint_threshold_out - func: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor use_c10_dispatcher: full @@ -2781,6 +3039,7 @@ dispatch: CPU: threshold_backward CUDA: threshold_backward_cuda + Checkpoint: checkpoint_threshold_backward - func: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) variants: function, method @@ -2951,6 +3210,10 @@ - func: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor use_c10_dispatcher: full variants: function, method + dispatch: + CPU: where + CUDA: where + Checkpoint: checkpoint_where - func: where(Tensor condition) -> Tensor[] variants: function @@ -2990,6 +3253,10 @@ - func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor supports_named_tensor: True + dispatch: + CPU: zeros_like + CUDA: zeros_like + Checkpoint: checkpoint_zeros_like - func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor use_c10_dispatcher: full @@ -3104,6 +3371,7 @@ SparseCUDA: clone_sparse MkldnnCPU: mkldnn_clone QuantizedCPU: quantized_clone + Checkpoint: checkpoint_clone supports_named_tensor: True - func: resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!) @@ -3138,6 +3406,7 @@ SparseCPU: zero_sparse_ SparseCUDA: zero_sparse_ MkldnnCPU: mkldnn_zero_ + Checkpoint: checkpoint_zero_ - func: sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
dispatch: @@ -3155,6 +3424,7 @@ CUDA: sub SparseCPU: sub_sparse SparseCUDA: sub_sparse + Checkpoint: checkpoint_sub supports_named_tensor: True - func: sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) @@ -3162,6 +3432,7 @@ dispatch: CPU: sub_ CUDA: sub_ + Checkpoint: checkpoint_sub_ SparseCPU: sub_sparse_ SparseCUDA: sub_sparse_ supports_named_tensor: True @@ -3180,12 +3451,24 @@ use_c10_dispatcher: full variants: function supports_named_tensor: True + dispatch: + CPU: rsub + CUDA: rsub + SparseCPU: rsub + SparseCUDA: rsub + Checkpoint: checkpoint_rsub # For C++ only, until we have conversion from C++ numbers to Tensor - func: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor use_c10_dispatcher: full variants: function supports_named_tensor: True + dispatch: + CPU: rsub + CUDA: rsub + SparseCPU: rsub + SparseCUDA: rsub + Checkpoint: checkpoint_rsub # Functionally the same as addmm, but we give it a different derivative formula # that doesn't propagate gradients to non-present entries on sparse. @@ -3199,6 +3482,7 @@ CUDA: legacy::cuda::_th_addmm_out SparseCPU: addmm_out_sparse_dense_cpu SparseCUDA: addmm_out_sparse_dense_cuda + Checkpoint: checkpoint_addmm_out supports_named_tensor: True - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor @@ -3209,6 +3493,7 @@ CUDA: legacy::cuda::_th_addmm SparseCPU: addmm_sparse_dense_cpu SparseCUDA: addmm_sparse_dense_cuda + Checkpoint: checkpoint_addmm supports_named_tensor: True - func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) @@ -3220,6 +3505,7 @@ # broadcasting SparseCPU: s_addmm_sparse_dense_cpu_ SparseCUDA: s_addmm_sparse_dense_cuda_ + Checkpoint: checkpoint_addmm_ supports_named_tensor: True @@ -3655,6 +3941,10 @@ variants: method device_guard: False supports_named_tensor: True + dispatch: + CPU: to + CUDA: to + Checkpoint: checkpoint_to - func: to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor variants: method @@ -3707,6 +3997,7 @@ dispatch: CPU: _local_scalar_dense_cpu CUDA: _local_scalar_dense_cuda + Checkpoint: checkpoint__local_scalar_dense variants: function supports_named_tensor: True @@ -3714,20 +4005,24 @@ - func: _thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor) dispatch: CUDA: _thnn_fused_lstm_cell_cuda + Checkpoint: checkpoint__thnn_fused_lstm_cell - func: _thnn_fused_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) dispatch: CUDA: _thnn_fused_lstm_cell_backward_cuda + Checkpoint: checkpoint__thnn_fused_lstm_cell_backward - func: _thnn_differentiable_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor input_gates, Tensor hidden_gates, Tensor? input_bias, Tensor? hidden_bias, Tensor cx, Tensor cy) -> (Tensor, Tensor, Tensor, Tensor, Tensor) - func: _thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? 
hidden_bias=None) -> (Tensor, Tensor) dispatch: CUDA: _thnn_fused_gru_cell_cuda + Checkpoint: checkpoint__thnn_fused_gru_cell - func: _thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) dispatch: CUDA: _thnn_fused_gru_cell_backward_cuda + Checkpoint: checkpoint__thnn_fused_gru_cell_backward - func: _thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) @@ -3834,6 +4129,7 @@ dispatch: CPU: masked_fill__cpu CUDA: masked_fill__cuda + Checkpoint: checkpoint_masked_fill_ supports_named_tensor: True - func: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor @@ -3846,6 +4142,7 @@ dispatch: CPU: masked_fill__cpu CUDA: masked_fill__cuda + Checkpoint: checkpoint_masked_fill_ supports_named_tensor: True - func: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor @@ -3871,6 +4168,7 @@ CUDA: view MkldnnCPU: mkldnn_view QuantizedCPU: view + Checkpoint: checkpoint_view - func: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) variants: method @@ -4011,12 +4309,14 @@ dispatch: CPU: bitwise_and_out CUDA: bitwise_and_out + Checkpoint: checkpoint_bitwise_and_out - func: bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: CPU: bitwise_and_out CUDA: bitwise_and_out + Checkpoint: checkpoint_bitwise_and_out - func: bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor variants: method, function @@ -4280,6 +4580,10 @@ - func: addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) variants: method supports_named_tensor: True + dispatch: + CPU: addcdiv_ + CUDA: addcdiv_ + Checkpoint: checkpoint_addcdiv_ - func: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!) variants: method @@ -4322,6 +4626,7 @@ dispatch: CPU: legacy::cpu::_th_diag_out CUDA: legacy::cuda::_th_diag_out + Checkpoint: checkpoint_diag_out - func: diag(Tensor self, int diagonal=0) -> Tensor use_c10_dispatcher: full @@ -4329,6 +4634,7 @@ dispatch: CPU: legacy::cpu::_th_diag CUDA: legacy::cuda::_th_diag + Checkpoint: checkpoint_diag - func: cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) @@ -4377,6 +4683,7 @@ CPU: ne_out CUDA: ne_out QuantizedCPU: ne_out_quantized_cpu + Checkpoint: checkpoint_ne_Scalar_out - func: ne.Scalar(Tensor self, Scalar other) -> Tensor supports_named_tensor: True @@ -4386,6 +4693,7 @@ CPU: ne CUDA: ne QuantizedCPU: ne_quantized_cpu + Checkpoint: checkpoint_ne_Scalar - func: ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True @@ -4393,6 +4701,7 @@ CPU: ne_out CUDA: ne_out QuantizedCPU: ne_out_quantized_cpu + Checkpoint: checkpoint_ne_Tensor_out - func: ne.Tensor(Tensor self, Tensor other) -> Tensor supports_named_tensor: True @@ -4402,6 +4711,7 @@ CPU: ne CUDA: ne QuantizedCPU: ne_quantized_cpu + Checkpoint: checkpoint_ne_Tensor - func: eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
supports_named_tensor: True @@ -4409,6 +4719,7 @@ CPU: eq_out CUDA: eq_out QuantizedCPU: eq_out_quantized_cpu + Checkpoint: checkpoint_eq_Scalar_out - func: eq.Scalar(Tensor self, Scalar other) -> Tensor supports_named_tensor: True @@ -4418,6 +4729,7 @@ CPU: eq CUDA: eq QuantizedCPU: eq_quantized_cpu + Checkpoint: checkpoint_eq_Scalar - func: eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True @@ -4425,6 +4737,7 @@ CPU: eq_out CUDA: eq_out QuantizedCPU: eq_out_quantized_cpu + Checkpoint: checkpoint_eq_Tensor_out - func: eq.Tensor(Tensor self, Tensor other) -> Tensor supports_named_tensor: True @@ -4434,6 +4747,7 @@ CPU: eq CUDA: eq QuantizedCPU: eq_quantized_cpu + Checkpoint: checkpoint_eq_Tensor - func: ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True @@ -4537,6 +4851,7 @@ CPU: lt_out CUDA: lt_out QuantizedCPU: lt_out_quantized_cpu + Checkpoint: checkpoint_lt_out - func: lt.Scalar(Tensor self, Scalar other) -> Tensor supports_named_tensor: True @@ -4546,6 +4861,7 @@ CPU: lt CUDA: lt QuantizedCPU: lt_quantized_cpu + Checkpoint: checkpoint_lt - func: lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True @@ -4553,6 +4869,7 @@ CPU: lt_out CUDA: lt_out QuantizedCPU: lt_out_quantized_cpu + Checkpoint: checkpoint_lt_out - func: lt.Tensor(Tensor self, Tensor other) -> Tensor supports_named_tensor: True @@ -4562,6 +4879,7 @@ CPU: lt CUDA: lt QuantizedCPU: lt_quantized_cpu + Checkpoint: checkpoint_lt - func: take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -4598,6 +4916,7 @@ dispatch: CPU: masked_select_out_cpu CUDA: masked_select_out_cuda + Checkpoint: checkpoint_masked_select_out supports_named_tensor: True - func: masked_select(Tensor self, Tensor mask) -> Tensor @@ -4606,12 +4925,14 @@ dispatch: CPU: masked_select_cpu CUDA: masked_select_cuda + Checkpoint: checkpoint_masked_select supports_named_tensor: True - func: nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: legacy::cpu::_th_nonzero_out CUDA: legacy::cuda::_th_nonzero_out + Checkpoint: checkpoint_nonzero_out - func: nonzero(Tensor self) -> Tensor use_c10_dispatcher: full @@ -4619,6 +4940,7 @@ dispatch: CPU: legacy::cpu::_th_nonzero CUDA: legacy::cuda::_th_nonzero + Checkpoint: checkpoint_nonzero - func: nonzero_numpy(Tensor self) -> Tensor[] variants: method, function @@ -4650,10 +4972,18 @@ use_c10_dispatcher: full variants: method, function supports_named_tensor: True + dispatch: + CPU: addcmul + CUDA: addcmul + Checkpoint: checkpoint_addcmul - func: addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) variants: method supports_named_tensor: True + dispatch: + CPU: addcmul_ + CUDA: addcmul_ + Checkpoint: checkpoint_addcmul_ - func: addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True @@ -4662,6 +4992,10 @@ use_c10_dispatcher: full variants: method, function supports_named_tensor: True + dispatch: + CPU: addcdiv + CUDA: addcdiv + Checkpoint: checkpoint_addcdiv - func: lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR) dispatch: @@ -4906,6 +5240,10 @@ - func: sign(Tensor self) -> Tensor variants: function, method supports_named_tensor: True + dispatch: + CPU: sign + CUDA: sign + Checkpoint: checkpoint_sign - func: sign_(Tensor(a!) self) -> Tensor(a!) 
variants: method @@ -4916,6 +5254,7 @@ dispatch: CPU: sign_out CUDA: sign_out + Checkpoint: checkpoint_sign_out - func: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor use_c10_dispatcher: full @@ -5079,6 +5418,7 @@ dispatch: CPU: topk_out_cpu CUDA: legacy::cuda::_th_topk_out + Checkpoint: checkpoint_topk_values - func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) variants: method, function @@ -5086,6 +5426,7 @@ CPU: topk CUDA: topk QuantizedCPU: quantized_topk_cpu + Checkpoint: checkpoint_topk - func: all(Tensor self) -> Tensor use_c10_dispatcher: full @@ -5101,6 +5442,7 @@ CUDA: any SparseCPU: any_sparse SparseCUDA: any_sparse + Checkpoint: checkpoint_any - func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -5128,6 +5470,7 @@ CPU: legacy::cpu::_th_equal CUDA: legacy::cuda::_th_equal QuantizedCPU: quantized_equal + Checkpoint: checkpoint_equal supports_named_tensor: True - func: pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) @@ -5274,12 +5617,14 @@ CPU: _cat_cpu CUDA: cat_cuda QuantizedCPU: quantized_cat + Checkpoint: checkpoint__cat - func: _cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: _cat_out_cpu CUDA: cat_out_cuda QuantizedCPU: quantized_cat_out + Checkpoint: checkpoint__cat_out - func: _mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor, Tensor) dispatch: @@ -5417,24 +5762,28 @@ dispatch: CPU: nll_loss_forward_out_cpu CUDA: legacy::cuda::_thnn_nll_loss_forward_out + Checkpoint: checkpoint_nll_loss_forward_out - func: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight) python_module: nn dispatch: CPU: nll_loss_forward_cpu CUDA: legacy::cuda::_thnn_nll_loss_forward + Checkpoint: checkpoint_nll_loss_forward - func: nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: CPU: nll_loss_backward_out_cpu CUDA: legacy::cuda::_thnn_nll_loss_backward_out + Checkpoint: checkpoint_nll_loss_backward_grad_input - func: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight) -> Tensor python_module: nn dispatch: CPU: nll_loss_backward_cpu CUDA: legacy::cuda::_thnn_nll_loss_backward + Checkpoint: checkpoint_nll_loss_backward - func: nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, int ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -5558,10 +5907,15 @@ dispatch: CPU: hardtanh_backward_out CUDA: hardtanh_backward_out + Checkpoint: checkpoint_hardtanh_backward_out - func: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor use_c10_dispatcher: full python_module: nn + dispatch: + CPU: hardtanh_backward + CUDA: hardtanh_backward + Checkpoint: checkpoint_hardtanh_backward - func: hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> Tensor(a!) python_module: nn @@ -5793,6 +6147,7 @@ CPU: avg_pool2d_out_cpu CUDA: avg_pool2d_out_cuda MkldnnCPU: mkldnn_avg_pool2d_out + Checkpoint: checkpoint_avg_pool2d_out - func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? 
divisor_override=None) -> Tensor python_module: nn @@ -5801,18 +6156,21 @@ CUDA: avg_pool2d_cuda MkldnnCPU: mkldnn_avg_pool2d QuantizedCPU: quantized_avg_pool2d + Checkpoint: checkpoint_avg_pool2d - func: avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: CPU: avg_pool2d_backward_out_cpu CUDA: avg_pool2d_backward_out_cuda + Checkpoint: checkpoint_avg_pool2d_backward_grad_input - func: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor python_module: nn dispatch: CPU: avg_pool2d_backward_cpu CUDA: avg_pool2d_backward_cuda + Checkpoint: checkpoint_avg_pool2d_backward - func: avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -5896,6 +6254,7 @@ dispatch: CPU: max_pool2d_with_indices_out_cpu CUDA: max_pool2d_with_indices_out_cuda + Checkpoint: checkpoint_max_pool2d_with_indices_out # Return: (Tensor output, Tensor indices) - func: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) @@ -5903,6 +6262,7 @@ dispatch: CPU: max_pool2d_with_indices_cpu CUDA: max_pool2d_with_indices_cuda + Checkpoint: checkpoint_max_pool2d_with_indices supports_named_tensor: True - func: max_pool2d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) @@ -5910,12 +6270,14 @@ dispatch: CPU: max_pool2d_with_indices_backward_out_cpu CUDA: max_pool2d_with_indices_backward_out_cuda + Checkpoint: checkpoint_max_pool2d_with_indices_backward_grad_input - func: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor python_module: nn dispatch: CPU: max_pool2d_with_indices_backward_cpu CUDA: max_pool2d_with_indices_backward_cuda + Checkpoint: checkpoint_max_pool2d_with_indices_backward # Return: (Tensor output, Tensor indices) - func: max_pool3d_with_indices.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) @@ -6141,6 +6503,7 @@ dispatch: CPU: upsample_bilinear2d_out_cpu CUDA: upsample_bilinear2d_out_cuda + Checkpoint: checkpoint_upsample_bilinear2d_out - func: upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn @@ -6148,18 +6511,21 @@ CPU: upsample_bilinear2d_cpu CUDA: upsample_bilinear2d_cuda QuantizedCPU: quantized_upsample_bilinear2d_cpu + Checkpoint: checkpoint_upsample_bilinear2d - func: upsample_bilinear2d_backward.grad_input(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
python_module: nn dispatch: CPU: upsample_bilinear2d_backward_out_cpu CUDA: upsample_bilinear2d_backward_out_cuda + Checkpoint: checkpoint_upsample_bilinear2d_backward_out - func: upsample_bilinear2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn dispatch: CPU: upsample_bilinear2d_backward_cpu CUDA: upsample_bilinear2d_backward_cuda + Checkpoint: checkpoint_upsample_bilinear2d_backward - func: upsample_bicubic2d.out(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -6287,20 +6653,30 @@ dispatch: CPU: sigmoid_backward_out CUDA: sigmoid_backward_out + Checkpoint: checkpoint_sigmoid_backward_out - func: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor use_c10_dispatcher: full python_module: nn + dispatch: + CPU: sigmoid_backward + CUDA: sigmoid_backward + Checkpoint: checkpoint_sigmoid_backward - func: tanh_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: CPU: tanh_backward_out CUDA: tanh_backward_out + Checkpoint: checkpoint_tanh_backward_out - func: tanh_backward(Tensor grad_output, Tensor output) -> Tensor use_c10_dispatcher: full python_module: nn + dispatch: + CPU: tanh_backward + CUDA: tanh_backward + Checkpoint: checkpoint_tanh_backward # What's a thnn_conv_ versus a slow_conv_? # @@ -6379,24 +6755,28 @@ dispatch: CPU: slow_conv2d_forward_out_cpu CUDA: legacy::cuda::_thnn_conv2d_forward_out + Checkpoint: checkpoint_thnn_conv2d_forward_out - func: thnn_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> (Tensor output, Tensor finput, Tensor fgrad_input) python_module: nn dispatch: CPU: slow_conv2d_forward_cpu CUDA: legacy::cuda::_thnn_conv2d_forward + Checkpoint: checkpoint_thnn_conv2d_forward - func: thnn_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight, Tensor(c!)? grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) python_module: nn dispatch: CPU: slow_conv2d_backward_out_cpu CUDA: legacy::cuda::_thnn_conv2d_backward_out + Checkpoint: checkpoint_thnn_conv2d_backward_out - func: thnn_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) python_module: nn dispatch: CPU: slow_conv2d_backward_cpu CUDA: legacy::cuda::_thnn_conv2d_backward + Checkpoint: checkpoint_thnn_conv2d_backward - func: thnn_conv_depthwise2d.out(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) 
python_module: nn diff --git a/aten/src/ATen/preprocess_declarations.py b/aten/src/ATen/preprocess_declarations.py index b7d56125f4e..35925c17342 100644 --- a/aten/src/ATen/preprocess_declarations.py +++ b/aten/src/ATen/preprocess_declarations.py @@ -28,7 +28,7 @@ all_types = type_map['floating_point'] + type_map['integral'] + type_map['quantized'] type_map['all'] = all_types -all_backends = ['CPU', 'CUDA', 'SparseCPU', 'SparseCUDA', 'MkldnnCPU', 'QuantizedCPU'] +all_backends = ['CPU', 'CUDA', 'SparseCPU', 'SparseCUDA', 'MkldnnCPU', 'QuantizedCPU', 'Checkpoint'] default_backends = ['CPU', 'CUDA'] diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index de230612e19..cd31abf9588 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -287,6 +287,9 @@ class CAFFE2_API Tensor { /// Returns a `Tensor`'s device. Device device() const; + /// Returns a `Tensor`'s device, if it has one. + c10::optional optional_device() const; + /// Returns a `Tensor`'s device index. int64_t get_device() const; diff --git a/aten/src/ATen/templates/TensorMethods.h b/aten/src/ATen/templates/TensorMethods.h index 33983ec6175..80c8b4767a0 100644 --- a/aten/src/ATen/templates/TensorMethods.h +++ b/aten/src/ATen/templates/TensorMethods.h @@ -68,6 +68,10 @@ inline Device Tensor::device() const { return impl_->device(); } +inline c10::optional Tensor::optional_device() const { + return impl_->optional_device(); +} + inline int64_t Tensor::get_device() const { // NB: this is not a native function to avoid dispatching overhead. return impl_->get_device(); diff --git a/c10/core/Backend.h b/c10/core/Backend.h index 5f3d8c7733c..f32dace2d04 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -25,7 +25,51 @@ namespace c10 { * or "SparseCUDA"; backend in torch.backends is something like "MKL" or * "CUDNN".
*/ -enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, XLA, QuantizedCPU, Undefined, MkldnnCPU, NumOptions }; +enum class Backend { + CPU, + CUDA, + HIP, + SparseCPU, + SparseCUDA, + SparseHIP, + MSNPU, + XLA, + QuantizedCPU, + Undefined, + MkldnnCPU, + Checkpoint, + NumOptions +}; + +// TODO: This probably shouldn't actually be static inline +static inline const char* toString(Backend b) { + switch (b) { + case Backend::CPU: + return "CPU"; + case Backend::CUDA: + return "CUDA"; + case Backend::HIP: + return "HIP"; + case Backend::MSNPU: + return "MSNPU"; + case Backend::XLA: + return "XLA"; + case Backend::SparseCPU: + return "SparseCPU"; + case Backend::SparseCUDA: + return "SparseCUDA"; + case Backend::SparseHIP: + return "SparseHIP"; + case Backend::MkldnnCPU: + return "MkldnnCPU"; + case Backend::QuantizedCPU: + return "QuantizedCPU"; + case Backend::Checkpoint: + return "Checkpoint"; + default: + return "UNKNOWN_BACKEND"; + } +} static inline Backend toSparse(Backend b) { switch (b) { @@ -42,7 +86,7 @@ static inline Backend toSparse(Backend b) { case Backend::SparseHIP: return Backend::SparseHIP; default: - throw std::runtime_error("Unknown backend"); + throw std::runtime_error(std::string("Unknown backend: ") + toString(b)); } } @@ -67,7 +111,7 @@ static inline Backend toDense(Backend b) { case Backend::QuantizedCPU: return Backend::QuantizedCPU; default: - throw std::runtime_error("Unknown backend"); + throw std::runtime_error(std::string("Unknown backend: ") + toString(b)); } } @@ -94,6 +138,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) { return Backend::QuantizedCPU; } else if (t == DispatchKey::Undefined) { return Backend::Undefined; + } else if (t == DispatchKey::CheckpointTensorId) { + return Backend::Checkpoint; } else { AT_ERROR("Unrecognized tensor type ID: ", t); } @@ -121,10 +167,12 @@ static inline DispatchKey backendToDispatchKey(Backend b) { return DispatchKey::MkldnnCPUTensorId; case Backend::QuantizedCPU: return DispatchKey::QuantizedCPUTensorId; + case Backend::Checkpoint: + return DispatchKey::CheckpointTensorId; case Backend::Undefined: return DispatchKey::Undefined; default: - throw std::runtime_error("Unknown backend"); + throw std::runtime_error(std::string("Unknown backend: ") + toString(b)); } } @@ -152,7 +200,7 @@ static inline DeviceType backendToDeviceType(Backend b) { case Backend::Undefined: AT_ERROR("Undefined backend is not a valid device type"); default: - AT_ERROR("Unknown backend"); + AT_ERROR(std::string("Unknown backend: ") + toString(b)); } } @@ -180,7 +228,7 @@ static inline Backend backendToCPU(Backend b) { case Backend::Undefined: return Backend::Undefined; default: - AT_ERROR("Unknown backend"); + AT_ERROR(std::string("Unknown backend: ") + toString(b)); } } @@ -199,7 +247,7 @@ static inline Backend backendToCUDA(Backend b) { case Backend::Undefined: return Backend::Undefined; default: - AT_ERROR("Unknown backend"); + AT_ERROR(std::string("Unknown backend: ") + toString(b)); } } @@ -218,35 +266,7 @@ static inline Backend backendToHIP(Backend b) { case Backend::Undefined: return Backend::Undefined; default: - AT_ERROR("Unknown backend"); - } -} - -// TODO: This probably shouldn't actually be static inline -static inline const char* toString(Backend b) { - switch (b) { - case Backend::CPU: - return "CPU"; - case Backend::CUDA: - return "CUDA"; - case Backend::HIP: - return "HIP"; - case Backend::MSNPU: - return "MSNPU"; - case Backend::XLA: - return "XLA"; - case Backend::SparseCPU: - return 
"SparseCPU"; - case Backend::SparseCUDA: - return "SparseCUDA"; - case Backend::SparseHIP: - return "SparseHIP"; - case Backend::MkldnnCPU: - return "MkldnnCPU"; - case Backend::QuantizedCPU: - return "QuantizedCPU"; - default: - return "UNKNOWN_BACKEND"; + AT_ERROR(std::string("Unknown backend: ") + toString(b)); } } diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index cf20e515c25..d5696184422 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -34,6 +34,8 @@ const char* toString(DispatchKey t) { return "MkldnnCPUTensorId"; case DispatchKey::QuantizedCPUTensorId: return "QuantizedCPUTensorId"; + case DispatchKey::CheckpointTensorId: + return "CheckpointTensorId"; case DispatchKey::VariableTensorId: return "VariableTensorId"; case DispatchKey::BackendSelect: diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index da7c3c564e1..87d28769e3a 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -112,6 +112,13 @@ enum class DispatchKey : uint8_t { // constructed by the output, and otherwise defers to the backend to // actually do the numeric computation. VariableTensorId contains // the bulk of this logic. + + // WARNING! If you add more "wrapper" style tensor ids (tensor + // ids which don't get kernels directly defined in native_functions.yaml; + // examples are tracing or profiling) here, you need to also adjust + // legacyExtractDispatchKey in c10/core/DispatchKeySet.h to mask them out. + CheckpointTensorId, + VariableTensorId, // Pre-autograd dispatch keys allow backends to override the autograd behavior diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index 50e3f024adb..a289c94cde8 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -130,7 +130,7 @@ static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) { // VariableTensorId is being excluded from a DispatchKeySet right after dispatching // (See variable_excluded_from_dispatch in TensorBody.h) // Now we are getting rid of BackendSelect. - return s.remove(DispatchKey::BackendSelect).highestPriorityTypeId(); + return s.remove(DispatchKey::BackendSelect).remove(DispatchKey::CheckpointTensorId).highestPriorityTypeId(); } } diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 94ebaa3bfe7..58e6263d861 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -465,6 +465,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return *device_opt_; } + c10::optional optional_device() const { + return device_opt_; + } + Layout layout() const { // NB: This method is not virtual and avoid dispatches for perf. if (is_sparse()) { @@ -858,11 +862,18 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { /** * One TensorImpl can be copied to another TensorImpl if they have the same - * DispatchKeySet. The only two special cases (for legacy reason) are: + * DispatchKeySet. + * Special cases (for legacy reason) are: * CPUTensorId is compatible with CUDATensorId and SparseCPUTensorId is * compatible with SparseCUDATensorId. 
*/ inline bool has_compatible_shallow_copy_type(DispatchKeySet from) { + if (key_set_ == from) { + return true; + } + if (key_set_.has(DispatchKey::CheckpointTensorId) || from.has(DispatchKey::CheckpointTensorId)) { + return false; + } auto is_dense = [](DispatchKeySet ts) { return ts.has(DispatchKey::CPUTensorId) || ts.has(DispatchKey::CUDATensorId) || @@ -873,7 +884,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { ts.has(DispatchKey::SparseCUDATensorId) || ts.has(DispatchKey::SparseHIPTensorId); }; - return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) || (is_sparse(key_set_) && is_sparse(from)); + return (is_dense(key_set_) && is_dense(from)) || (is_sparse(key_set_) && is_sparse(from)); } /** diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 8c272c2ce5f..d309288291a 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -204,7 +204,6 @@ class THCCachingAllocator { void malloc(void** devPtr, size_t size, cudaStream_t stream) { std::lock_guard lock(mutex); - int device; C10_CUDA_CHECK(cudaGetDevice(&device)); diff --git a/test.py b/test.py new file mode 100644 index 00000000000..9a6ee104db0 --- /dev/null +++ b/test.py @@ -0,0 +1,4 @@ +import torch +x = torch.Tensor([1]).checkpoint() +y = x +z = y diff --git a/third_party/json b/third_party/json new file mode 160000 index 00000000000..19843b038ca --- /dev/null +++ b/third_party/json @@ -0,0 +1 @@ +Subproject commit 19843b038caa463164d6f89ea1b2765fae7552e9 diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp index 70278a08190..f8a3a8351b3 100644 --- a/tools/autograd/templates/Functions.cpp +++ b/tools/autograd/templates/Functions.cpp @@ -656,18 +656,6 @@ Tensor index_select_backward(Tensor grad, int64_t dim, Tensor indices, IntArrayR return at::zeros(sizes, grad.options()).scatter_(dim, indices, grad); } -Tensor slice_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { - auto grad_input = at::zeros(input_sizes, grad.options()); - grad_input.slice(dim, start, end, step).copy_(grad); - return grad_input; -} - -Tensor select_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { - auto grad_input = at::zeros(input_sizes, grad.options()); - grad_input.select(dim, index).copy_(grad); - return grad_input; -} - Tensor trace_backward(const Tensor & grad, IntArrayRef sizes) { if (sizes.size() != 2) { throw std::runtime_error("expected matrix input"); @@ -802,38 +790,6 @@ Tensor cholesky_inverse_backward(Tensor grad, Tensor L, bool upper, Tensor inver return grad_L; } -Tensor split_with_sizes_backward(const std::vector &grads, - IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { - dim = at::maybe_wrap_dim(dim, sizes.size()); - - // it's possible some of the grads are not defined (represents tensors of all 0s). 
- // Since at::cat can't handle those, let's define them - std::vector grads_all_defined(grads.size()); - for (size_t j = 0; j < grads.size(); ++j) { - if (grads[j].defined()) { - grads_all_defined[j] = grads[j]; - } else { - auto length = split_sizes[j]; - auto grad_size = sizes.vec(); - grad_size[dim] = length; - grads_all_defined[j] = at::zeros(grad_size, options); - } - } - - auto ret = at::cat(grads_all_defined, dim); - return ret; -} - -Tensor split_backward(const std::vector &grads, - int64_t split_size, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { - dim = at::maybe_wrap_dim(dim, sizes.size()); - int64_t dim_size = sizes[dim]; - int64_t num_splits = grads.size(); - std::vector split_sizes(num_splits, split_size); - split_sizes[num_splits - 1] = split_size - (split_size * num_splits - dim_size); - return split_with_sizes_backward(grads, split_sizes, dim, sizes, options); -} - Tensor max_pool_double_backward(const Tensor & grad, const Tensor & indices, int dim) { AT_ASSERT(indices.dim() >= dim); auto size = indices.sizes().slice(0, indices.dim() - dim).vec(); diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index 092d0a37da4..7e56c232786 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -73,7 +73,8 @@ static std::vector seq_to_aten_shape(PyObject *py_seq) { return result; } -PyObject* tensor_to_numpy(const at::Tensor& tensor) { +PyObject* tensor_to_numpy(const at::Tensor& tensor_) { + Tensor tensor = tensor_.decheckpoint(); if (tensor.device().type() != DeviceType::CPU) { throw TypeError( "can't convert %s device type tensor to numpy. Use Tensor.cpu() to " diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 4777206b90b..13037b8cbfe 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -3776,9 +3776,10 @@ def multi_head_attention_forward(query, # type: Tensor q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) if k is not None: - k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + # extremely bizarre, but got errors of dimension mismatch here and below unless I changed view -> reshape + k = k.contiguous().reshape(-1, bsz * num_heads, head_dim).transpose(0, 1) if v is not None: - v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + v = v.contiguous().reshape(-1, bsz * num_heads, head_dim).transpose(0, 1) if static_k is not None: assert static_k.size(0) == bsz * num_heads @@ -3825,7 +3826,8 @@ def multi_head_attention_forward(query, # type: Tensor attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] - attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + # also had to change view to reshape + attn_output = attn_output.transpose(0, 1).contiguous().reshape(tgt_len, bsz, embed_dim) attn_output = linear(attn_output, out_proj_weight, out_proj_bias) if need_weights:
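The hunks above register checkpoint_* kernels for the new Checkpoint backend in native_functions.yaml and make tensor_to_numpy decheckpoint its input. Below is a minimal, illustrative sketch of how that path can be driven from Python; it is not part of the patch, and it assumes nothing beyond the .checkpoint() method already used in test.py and operators that carry Checkpoint dispatch entries above (mm, relu, sum).

import torch

# Wrap a plain tensor in the Checkpoint backend, as test.py does.
x = torch.randn(4, 4).checkpoint()

# mm, relu and sum all have Checkpoint dispatch entries in the hunks above,
# so these calls are expected to route to checkpoint_mm, checkpoint_relu and
# checkpoint_sum respectively.
y = torch.relu(x.mm(x))
z = y.sum()

# tensor_to_numpy now calls decheckpoint() internally, so converting the
# checkpointed result back to numpy should work without an explicit unwrap.
print(z.numpy())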