From 23e38ff9788b52886b43e57da1be4c5ea327ce20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Fri, 6 Mar 2020 11:57:11 -0800 Subject: [PATCH 01/42] Backend Cleanroom port (#13) Implements CheckpointTensor backend --- c10/core/Backend.h | 92 ++++++++++++++++++++++++--------------- c10/core/DispatchKey.cpp | 2 + c10/core/DispatchKey.h | 7 +++ c10/core/DispatchKeySet.h | 2 +- 4 files changed, 66 insertions(+), 37 deletions(-) diff --git a/c10/core/Backend.h b/c10/core/Backend.h index 5f3d8c7733c..c25192d4e0b 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -25,7 +25,51 @@ namespace c10 { * or "SparseCUDA"; backend in torch.backends is something like "MKL" or * "CUDNN". */ -enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, XLA, QuantizedCPU, Undefined, MkldnnCPU, NumOptions }; +enum class Backend { + CPU, + CUDA, + HIP, + SparseCPU, + SparseCUDA, + SparseHIP, + MSNPU, + XLA, + QuantizedCPU, + Undefined, + MkldnnCPU, + CheckPoint, + NumOptions +}; + +// TODO: This probably shouldn't actually be static inline +static inline const char* toString(Backend b) { + switch (b) { + case Backend::CPU: + return "CPU"; + case Backend::CUDA: + return "CUDA"; + case Backend::HIP: + return "HIP"; + case Backend::MSNPU: + return "MSNPU"; + case Backend::XLA: + return "XLA"; + case Backend::SparseCPU: + return "SparseCPU"; + case Backend::SparseCUDA: + return "SparseCUDA"; + case Backend::SparseHIP: + return "SparseHIP"; + case Backend::MkldnnCPU: + return "MkldnnCPU"; + case Backend::QuantizedCPU: + return "QuantizedCPU"; + case Backend::CheckPoint: + return "CheckPoint"; + default: + return "UNKNOWN_BACKEND"; + } +} static inline Backend toSparse(Backend b) { switch (b) { @@ -42,7 +86,7 @@ static inline Backend toSparse(Backend b) { case Backend::SparseHIP: return Backend::SparseHIP; default: - throw std::runtime_error("Unknown backend"); + throw std::runtime_error(std::string("Unknown backend: ") + toString(b)); } } @@ -67,7 +111,7 @@ static inline Backend toDense(Backend b) { case Backend::QuantizedCPU: return Backend::QuantizedCPU; default: - throw std::runtime_error("Unknown backend"); + throw std::runtime_error(std::string("Unknown backend: ") + toString(b)); } } @@ -94,6 +138,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) { return Backend::QuantizedCPU; } else if (t == DispatchKey::Undefined) { return Backend::Undefined; + } else if (t == DispatchKey::CheckPointTensorId) { + return Backend::CheckPoint; } else { AT_ERROR("Unrecognized tensor type ID: ", t); } @@ -121,10 +167,12 @@ static inline DispatchKey backendToDispatchKey(Backend b) { return DispatchKey::MkldnnCPUTensorId; case Backend::QuantizedCPU: return DispatchKey::QuantizedCPUTensorId; + case Backend::CheckPoint: + return DispatchKey::CheckPointTensorId; case Backend::Undefined: return DispatchKey::Undefined; default: - throw std::runtime_error("Unknown backend"); + throw std::runtime_error(std::string("Unknown backend: ") + toString(b)); } } @@ -152,7 +200,7 @@ static inline DeviceType backendToDeviceType(Backend b) { case Backend::Undefined: AT_ERROR("Undefined backend is not a valid device type"); default: - AT_ERROR("Unknown backend"); + AT_ERROR(std::string("Unknown backend: ") + toString(b)); } } @@ -180,7 +228,7 @@ static inline Backend backendToCPU(Backend b) { case Backend::Undefined: return Backend::Undefined; default: - AT_ERROR("Unknown backend"); + AT_ERROR(std::string("Unknown backend: ") + toString(b)); } } @@ -199,7 +247,7 @@ static inline Backend backendToCUDA(Backend b) { case Backend::Undefined: return Backend::Undefined; default: - AT_ERROR("Unknown backend"); + AT_ERROR(std::string("Unknown backend: ") + toString(b)); } } @@ -218,35 +266,7 @@ static inline Backend backendToHIP(Backend b) { case Backend::Undefined: return Backend::Undefined; default: - AT_ERROR("Unknown backend"); - } -} - -// TODO: This probably shouldn't actually be static inline -static inline const char* toString(Backend b) { - switch (b) { - case Backend::CPU: - return "CPU"; - case Backend::CUDA: - return "CUDA"; - case Backend::HIP: - return "HIP"; - case Backend::MSNPU: - return "MSNPU"; - case Backend::XLA: - return "XLA"; - case Backend::SparseCPU: - return "SparseCPU"; - case Backend::SparseCUDA: - return "SparseCUDA"; - case Backend::SparseHIP: - return "SparseHIP"; - case Backend::MkldnnCPU: - return "MkldnnCPU"; - case Backend::QuantizedCPU: - return "QuantizedCPU"; - default: - return "UNKNOWN_BACKEND"; + AT_ERROR(std::string("Unknown backend: ") + toString(b)); } } diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index cf20e515c25..cd380872059 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -34,6 +34,8 @@ const char* toString(DispatchKey t) { return "MkldnnCPUTensorId"; case DispatchKey::QuantizedCPUTensorId: return "QuantizedCPUTensorId"; + case DispatchKey::CheckPointTensorId: + return "CheckPointTensorId"; case DispatchKey::VariableTensorId: return "VariableTensorId"; case DispatchKey::BackendSelect: diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index da7c3c564e1..700b67e44d0 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -112,6 +112,13 @@ enum class DispatchKey : uint8_t { // constructed by the output, and otherwise defers to the backend to // actually do the numeric computation. VariableTensorId contains // the bulk of this logic. + + // WARNING! If you add more "wrapper" style tensor ids (tensor + // ids which don't get kernels directly defined in native_functions.yaml; + // examples are tracing or profiling) here, you need to also adjust + // legacyExtractDispatchKey in c10/core/DispatchKeySet.h to mask them out. + CheckPointTensorId, + VariableTensorId, // Pre-autograd dispatch keys allow backends to override the autograd behavior diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index 50e3f024adb..8d6f63b35fa 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -130,7 +130,7 @@ static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) { // VariableTensorId is being excluded from a DispatchKeySet right after dispatching // (See variable_excluded_from_dispatch in TensorBody.h) // Now we are getting rid of BackendSelect. - return s.remove(DispatchKey::BackendSelect).highestPriorityTypeId(); + return s.remove(DispatchKey::BackendSelect).remove(DispatchKey::CheckPointTensorId).highestPriorityTypeId(); } } From 5ffafc20de782f21d902bc6edf1c6f63754137f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Sun, 15 Mar 2020 19:49:50 -0700 Subject: [PATCH 02/42] Hook up the overload with a passing test. (#15) * save * save * review * review * add back files --- aten/src/ATen/CheckpointTensorImpl.h | 33 ++++++++++++++++++++++ aten/src/ATen/gen.py | 6 ++++ aten/src/ATen/native/Checkpoint.cpp | 25 ++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 9 ++++++ aten/src/ATen/preprocess_declarations.py | 2 +- aten/src/ATen/templates/TensorBody.h | 3 ++ aten/src/ATen/templates/TensorMethods.h | 4 +++ c10/core/Backend.h | 14 ++++----- c10/core/DispatchKey.cpp | 4 +-- c10/core/DispatchKey.h | 2 +- c10/core/DispatchKeySet.h | 2 +- c10/core/TensorImpl.h | 7 ++++- test.py | 8 ++++++ 13 files changed, 106 insertions(+), 13 deletions(-) create mode 100644 aten/src/ATen/CheckpointTensorImpl.h create mode 100644 aten/src/ATen/native/Checkpoint.cpp create mode 100644 test.py diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h new file mode 100644 index 00000000000..22685effb3e --- /dev/null +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +struct CAFFE2_API CheckpointTensorImpl final : public TensorImpl { + Tensor t; + explicit CheckpointTensorImpl(const Tensor& t) : + TensorImpl(t.key_set(), t.dtype(), t.optional_device()), + t(t) { } +}; + +} diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py index e64bbd891a8..5a1707c6406 100644 --- a/aten/src/ATen/gen.py +++ b/aten/src/ATen/gen.py @@ -353,6 +353,8 @@ def generate_storage_type_and_tensor(backend, density, declarations, per_op_regi if env['DeviceType'] == 'CPU': top_env['cpu_type_headers'].append( '#include "ATen/{}.h"'.format(env['Type'])) + elif env['DeviceType'] == 'Checkpoint': + pass else: assert env['DeviceType'] == 'CUDA' top_env['cuda_type_headers'].append( @@ -411,6 +413,8 @@ def declare_outputs(): fname = gen_per_op_registration_filename(whitelisted_op) file_manager.will_write(fname) + file_manager.will_write("CheckpointType.h") + file_manager.will_write("CheckpointType.cpp") def filter_by_extension(files, *extensions): filtered_files = [] @@ -478,6 +482,8 @@ def generate_outputs(): generate_storage_type_and_tensor( backend, density, declarations, per_op_registrations) + generate_storage_type_and_tensor('Checkpoint', 'Dense', declarations, per_op_registrations) + core_files = { 'TensorBody.h': TENSOR_H, 'TensorMethods.h': TENSOR_METHODS_H, diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp new file mode 100644 index 00000000000..2cb5b568ff6 --- /dev/null +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -0,0 +1,25 @@ +#include +#include + +namespace at { namespace native { + +Tensor checkpoint(const Tensor& t) { + return Tensor(intrusive_ptr::make(t.detach())); +} + +Tensor decheckpoint(const Tensor& t) { + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + CHECK(cpti != nullptr); + return cpti->t; +} + +bool is_checkpoint(const Tensor& t) { + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + return cpti != nullptr; +} + +Tensor checkpoint_add(at::Tensor const& a, at::Tensor const& b, c10::Scalar c) { + return checkpoint(at::add(decheckpoint(a), decheckpoint(b), c)); +} + +}} diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a9e02ea8ec9..4561891d82f 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1,5 +1,13 @@ # See README.md in this directory for more guidance +- func: checkpoint(Tensor self) -> Tensor + variants: method + +- func: is_checkpoint(Tensor self) -> bool + variants: method + +- func: decheckpoint(Tensor self) -> Tensor + variants: method # Temporary type cast operators. These are needed to trace type-casts now since # Type's are not supported in the IR. Instead, we call down to these @@ -285,6 +293,7 @@ SparseCPU: add_sparse SparseCUDA: add_sparse MkldnnCPU: mkldnn_add + Checkpoint: checkpoint_add supports_named_tensor: True - func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) diff --git a/aten/src/ATen/preprocess_declarations.py b/aten/src/ATen/preprocess_declarations.py index b7d56125f4e..35925c17342 100644 --- a/aten/src/ATen/preprocess_declarations.py +++ b/aten/src/ATen/preprocess_declarations.py @@ -28,7 +28,7 @@ all_types = type_map['floating_point'] + type_map['integral'] + type_map['quantized'] type_map['all'] = all_types -all_backends = ['CPU', 'CUDA', 'SparseCPU', 'SparseCUDA', 'MkldnnCPU', 'QuantizedCPU'] +all_backends = ['CPU', 'CUDA', 'SparseCPU', 'SparseCUDA', 'MkldnnCPU', 'QuantizedCPU', 'Checkpoint'] default_backends = ['CPU', 'CUDA'] diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index de230612e19..cd31abf9588 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -287,6 +287,9 @@ class CAFFE2_API Tensor { /// Returns a `Tensor`'s device. Device device() const; + /// Returns a `Tensor`'s device. + c10::optional optional_device() const; + /// Returns a `Tensor`'s device index. int64_t get_device() const; diff --git a/aten/src/ATen/templates/TensorMethods.h b/aten/src/ATen/templates/TensorMethods.h index 33983ec6175..80c8b4767a0 100644 --- a/aten/src/ATen/templates/TensorMethods.h +++ b/aten/src/ATen/templates/TensorMethods.h @@ -68,6 +68,10 @@ inline Device Tensor::device() const { return impl_->device(); } +inline c10::optional Tensor::optional_device() const { + return impl_->optional_device(); +} + inline int64_t Tensor::get_device() const { // NB: this is not a native function to avoid dispatching overhead. return impl_->get_device(); diff --git a/c10/core/Backend.h b/c10/core/Backend.h index c25192d4e0b..f32dace2d04 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -37,7 +37,7 @@ enum class Backend { QuantizedCPU, Undefined, MkldnnCPU, - CheckPoint, + Checkpoint, NumOptions }; @@ -64,8 +64,8 @@ static inline const char* toString(Backend b) { return "MkldnnCPU"; case Backend::QuantizedCPU: return "QuantizedCPU"; - case Backend::CheckPoint: - return "CheckPoint"; + case Backend::Checkpoint: + return "Checkpoint"; default: return "UNKNOWN_BACKEND"; } @@ -138,8 +138,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) { return Backend::QuantizedCPU; } else if (t == DispatchKey::Undefined) { return Backend::Undefined; - } else if (t == DispatchKey::CheckPointTensorId) { - return Backend::CheckPoint; + } else if (t == DispatchKey::CheckpointTensorId) { + return Backend::Checkpoint; } else { AT_ERROR("Unrecognized tensor type ID: ", t); } @@ -167,8 +167,8 @@ static inline DispatchKey backendToDispatchKey(Backend b) { return DispatchKey::MkldnnCPUTensorId; case Backend::QuantizedCPU: return DispatchKey::QuantizedCPUTensorId; - case Backend::CheckPoint: - return DispatchKey::CheckPointTensorId; + case Backend::Checkpoint: + return DispatchKey::CheckpointTensorId; case Backend::Undefined: return DispatchKey::Undefined; default: diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp index cd380872059..d5696184422 100644 --- a/c10/core/DispatchKey.cpp +++ b/c10/core/DispatchKey.cpp @@ -34,8 +34,8 @@ const char* toString(DispatchKey t) { return "MkldnnCPUTensorId"; case DispatchKey::QuantizedCPUTensorId: return "QuantizedCPUTensorId"; - case DispatchKey::CheckPointTensorId: - return "CheckPointTensorId"; + case DispatchKey::CheckpointTensorId: + return "CheckpointTensorId"; case DispatchKey::VariableTensorId: return "VariableTensorId"; case DispatchKey::BackendSelect: diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index 700b67e44d0..87d28769e3a 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -117,7 +117,7 @@ enum class DispatchKey : uint8_t { // ids which don't get kernels directly defined in native_functions.yaml; // examples are tracing or profiling) here, you need to also adjust // legacyExtractDispatchKey in c10/core/DispatchKeySet.h to mask them out. - CheckPointTensorId, + CheckpointTensorId, VariableTensorId, diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index 8d6f63b35fa..a289c94cde8 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -130,7 +130,7 @@ static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) { // VariableTensorId is being excluded from a DispatchKeySet right after dispatching // (See variable_excluded_from_dispatch in TensorBody.h) // Now we are getting rid of BackendSelect. - return s.remove(DispatchKey::BackendSelect).remove(DispatchKey::CheckPointTensorId).highestPriorityTypeId(); + return s.remove(DispatchKey::BackendSelect).remove(DispatchKey::CheckpointTensorId).highestPriorityTypeId(); } } diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 94ebaa3bfe7..79c5ec375fa 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -465,6 +465,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return *device_opt_; } + c10::optional optional_device() const { + return device_opt_; + } + Layout layout() const { // NB: This method is not virtual and avoid dispatches for perf. if (is_sparse()) { @@ -858,7 +862,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { /** * One TensorImpl can be copied to another TensorImpl if they have the same - * DispatchKeySet. The only two special cases (for legacy reason) are: + * DispatchKeySet. + * Special cases (for legacy reason) are: * CPUTensorId is compatible with CUDATensorId and SparseCPUTensorId is * compatible with SparseCUDATensorId. */ diff --git a/test.py b/test.py new file mode 100644 index 00000000000..af2051c39ad --- /dev/null +++ b/test.py @@ -0,0 +1,8 @@ +import torch +x = torch.Tensor([1]).checkpoint() +y = torch.Tensor([2]).checkpoint() +z = x + y +print(z) +print(z.decheckpoint()) +print(z.is_checkpoint()) +print(z.decheckpoint().is_checkpoint()) From 64647525eb8d63c834d5c7c698ce0a7ccc478a84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Thu, 19 Mar 2020 21:22:50 -0700 Subject: [PATCH 03/42] Adding some overload (#16) * save * save * review * remove outdated file * save * can compile again * save * save * up up array * finsih the overloads * replace vec[n] with vec.at(n) --- aten/src/ATen/CheckpointTensorImpl.cpp | 83 ++++++++ aten/src/ATen/CheckpointTensorImpl.h | 80 +++++++- aten/src/ATen/native/Checkpoint.cpp | 226 ++++++++++++++++++++- aten/src/ATen/native/native_functions.yaml | 35 ++++ 4 files changed, 418 insertions(+), 6 deletions(-) create mode 100644 aten/src/ATen/CheckpointTensorImpl.cpp diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp new file mode 100644 index 00000000000..a3ddf28a5ef --- /dev/null +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -0,0 +1,83 @@ +#include +#include + +namespace at { + +struct DTRLogger { + std::ofstream out; + static std::string get_filename() { + std::time_t t = std::time(nullptr); + std::tm* tm = std::localtime(&t); + std::string str = + std::to_string(1900+tm->tm_year) + "-" + + std::to_string(1+tm->tm_mon) + "-" + + std::to_string(tm->tm_mday) + "-" + + std::to_string(tm->tm_hour) + "-" + + std::to_string(tm->tm_min) + "-" + + std::to_string(tm->tm_sec) + ".log"; + return str; + } + DTRLogger() : out(get_filename()) { } +}; + +void DTRLog(const std::string& str) { + static DTRLogger logger; + logger.out << str << std::endl; +} + +int CheckpointTensorCell::counter = 0; + +Tensors make_raw(const rematerialize_function_t& remat, + const strongs& input_values) { + std::vector input; + for (const strong& s: input_values) { + input.push_back(s->t); + } + auto output = remat(input); + Tensors ret; + for (const Tensor& o: output) { + ret.push_back(native::checkpoint(o)); + } + return ret; +} + +Tensors CheckpointTensorImpl::make(const std::string& name, + const rematerialize_function_t& remat, + const strongs& input_values) { + Tensors ret = make_raw(remat, input_values); + std::string log("("); + for (const Tensor& t: ret) { + log += cell_from_tensor(t)->value->name(); + log += ", "; + } + log += ") = "; + log += name; + log += "("; + for (const strong& s: input_values) { + log += s->name(); + log += ", "; + } + log += ")"; + DTRLog(log); + return ret; +} + +void CheckpointTensorImpl::mutate(const std::string& name, + const mutate_function_t& mutate, + const Tensors& inputs) { + auto remat = [=](const Tensors& t) -> Tensors { + auto t0 = t[0].clone(); + Tensors new_input_values = t; + new_input_values[0] = t0; + mutate(new_input_values); + return {t0}; + }; + strongs input_values; + for (const Tensor& t : inputs) { + input_values.push_back(from_tensor(t)); + } + auto modified = make_raw(remat, input_values)[0]; + cell_from_tensor(inputs[0])->value = cell_from_tensor(modified)->value; +} + +} diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index 22685effb3e..747590cdb71 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -23,11 +23,83 @@ namespace at { -struct CAFFE2_API CheckpointTensorImpl final : public TensorImpl { +void DTRLog(const std::string& str); + +struct CAFFE2_API CheckpointTensorCell : intrusive_ptr_target { Tensor t; - explicit CheckpointTensorImpl(const Tensor& t) : - TensorImpl(t.key_set(), t.dtype(), t.optional_device()), - t(t) { } + explicit CheckpointTensorCell(const Tensor& t) : t(t.detach()) { } + int id = gen_counter(); + static int counter; + static int gen_counter() { + return counter++; + } + std::string name() { + return std::string("x") + std::to_string(id); + } +}; + +struct CAFFE2_API CheckpointTensorImplCell : intrusive_ptr_target { + mutable intrusive_ptr value; + explicit CheckpointTensorImplCell(const intrusive_ptr& value) : value(value) { } + explicit CheckpointTensorImplCell(const Tensor& t) : value(intrusive_ptr::make(t)) { } + void release_resources() final { + value.reset(); + } +}; + +class CheckpointTensorCell; +using strong = intrusive_ptr; +using strongs = std::vector; +using weak = weak_intrusive_ptr; +using weaks = std::vector; +using Tensors = std::vector; +using rematerialize_function_t = std::function; +using mutate_function_t = std::function; + +inline DispatchKeySet convert_key_set(const DispatchKeySet& t) { + auto ret = t.add(DispatchKey::CheckpointTensorId); + CHECK(!ret.has(DispatchKey::VariableTensorId)); + return ret; +} + +struct CAFFE2_API CheckpointTensorImpl : TensorImpl { + intrusive_ptr ref; + void release_resources() final { + ref.reset(); + } + explicit CheckpointTensorImpl(const intrusive_ptr& ref) : TensorImpl(convert_key_set(ref->value->t.key_set()), + ref->value->t.dtype(), + ref->value->t.optional_device()), ref(ref) { } + explicit CheckpointTensorImpl(const Tensor& t) : CheckpointTensorImpl(intrusive_ptr::make(t)) { } + static Tensors make(const std::string& name, + const rematerialize_function_t& remat, + const strongs& input_values); + static void mutate(const std::string& name, + const mutate_function_t& mutate, + const Tensors& input_values); }; +inline CheckpointTensorImpl* get_cpti(const Tensor& t) { + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + TORCH_CHECK(cpti != nullptr); + return cpti; +} + +inline strong from_tensor(const Tensor& t) { + auto* cpt = dynamic_cast(t.unsafeGetTensorImpl()); + if(cpt != nullptr) { + return cpt->ref->value; + } else { + return get_cpti(native::checkpoint(t))->ref->value; + } +} + +inline Tensor get(const strong& s) { + return s->t; +} + +inline intrusive_ptr cell_from_tensor(const Tensor& t) { + return get_cpti(t)->ref; +} + } diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index 2cb5b568ff6..15099b1a497 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -10,7 +10,7 @@ Tensor checkpoint(const Tensor& t) { Tensor decheckpoint(const Tensor& t) { auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); CHECK(cpti != nullptr); - return cpti->t; + return cpti->ref->value->t; } bool is_checkpoint(const Tensor& t) { @@ -19,7 +19,229 @@ bool is_checkpoint(const Tensor& t) { } Tensor checkpoint_add(at::Tensor const& a, at::Tensor const& b, c10::Scalar c) { - return checkpoint(at::add(decheckpoint(a), decheckpoint(b), c)); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::add(vec.at(0), vec.at(1), c)}; + }; + strongs s = {from_tensor(a), from_tensor(b)}; + return CheckpointTensorImpl::make("add", rt, s)[0]; +} + +Tensor& checkpoint_add_(Tensor& a, const Tensor& b, Scalar c) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).add_(vec.at(1), c); + }; + CheckpointTensorImpl::mutate("add_", mt, {a, b}); + return a; +} + +Tensor checkpoint_abs(at::Tensor const& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::abs(vec.at(0))}; + }; + strongs s = {from_tensor(a)}; + return CheckpointTensorImpl::make("abs", rt, s)[0]; +} + +Tensor checkpoint_div(at::Tensor const& a, at::Tensor const& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::div(vec.at(0), vec.at(1))}; + }; + strongs s = {from_tensor(a), from_tensor(b)}; + return CheckpointTensorImpl::make("div", rt, s)[0]; +} + +Tensor& checkpoint_div_(Tensor& a, const Tensor& b) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).div_(vec.at(1)); + }; + CheckpointTensorImpl::mutate("div_", mt, {a, b}); + return a; +} + +Tensor checkpoint_constant_pad_nd(Tensor const& a, c10::ArrayRef b, c10::Scalar c) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::constant_pad_nd(vec.at(0), b_, c)}; + }; + strongs s = {from_tensor(a)}; + return CheckpointTensorImpl::make("constant_pad_nd", rt, s)[0]; +} + +Tensor checkpoint_binary_cross_entropy(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, long d) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::binary_cross_entropy(vec.at(0), vec.at(1), vec.at(2), d)}; + }; + strongs s = {from_tensor(a), from_tensor(b), from_tensor(c)}; + return CheckpointTensorImpl::make("binary_cross_entropy", rt, s)[0]; +} + +Tensor& checkpoint_binary_cross_entropy_out(at::Tensor& a, at::Tensor const& b, at::Tensor const& c, at::Tensor const& d, long e) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::binary_cross_entropy_out(self, vec.at(1), vec.at(2), vec.at(3), e); + }; + CheckpointTensorImpl::mutate("binary_cross_entropy_out", mt, {a, b, c, d}); + return a; +} + +Tensor checkpoint_binary_cross_entropy_backward(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, at::Tensor const& d, long e) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::binary_cross_entropy_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), e)}; + }; + strongs s = {from_tensor(a), from_tensor(b), from_tensor(c), from_tensor(d)}; + return CheckpointTensorImpl::make("binary_cross_entropy_backward", rt, s)[0]; +} + +Tensor& checkpoint_binary_cross_entropy_backward_out(at::Tensor& a, at::Tensor const& b, at::Tensor const& c, at::Tensor const& d, at::Tensor const& e, long f) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::binary_cross_entropy_backward_out(self, vec.at(1), vec.at(2), vec.at(3), vec.at(4), f); + }; + CheckpointTensorImpl::mutate("binary_cross_entropy_backward_out", mt, {a, b, c, d, e}); + return a; +} + +Tensor checkpoint_embedding(at::Tensor const& a, at::Tensor const& b, long c, bool d, bool e) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::embedding(vec.at(0), vec.at(1), c, d, e)}; + }; + strongs s = {from_tensor(a), from_tensor(b)}; + return CheckpointTensorImpl::make("embedding", rt, s)[0]; +} + +Tensor checkpoint_embedding_backward(at::Tensor const& a, at::Tensor const& b, long c, long d, bool e, bool f) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::embedding_backward(vec.at(0), vec.at(1), c, d, e, f)}; + }; + strongs s = {from_tensor(a), from_tensor(b)}; + return CheckpointTensorImpl::make("embedding", rt, s)[0]; +} + +std::tuple +checkpoint_cudnn_batch_norm(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, at::Tensor const& d, at::Tensor const& e, bool f, double g, double h) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::cudnn_batch_norm(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), f, g, h); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret), std::get<3>(ret)}; + }; + strongs s = {from_tensor(a), from_tensor(b), from_tensor(c), from_tensor(d), from_tensor(e)}; + auto ret = CheckpointTensorImpl::make("cudnn_batch_norm", rt, s)[0]; + return {ret[0], ret[1], ret[2], ret[3]}; +} + +Tensor checkpoint_as_strided(at::Tensor const& a, c10::ArrayRef b, c10::ArrayRef c, c10::optional d) { + std::vector b_ = b.vec(), c_ = c.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::as_strided(vec.at(0), b_, c_, d)}; + }; + strongs s = {from_tensor(a)}; + return CheckpointTensorImpl::make("as_strided", rt, s)[0]; +} + +Tensor checkpoint__masked_scale(at::Tensor const& a, at::Tensor const& b, double c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::_masked_scale(vec.at(0), vec.at(1), c)}; + }; + strongs s = {from_tensor(a), from_tensor(b)}; + return CheckpointTensorImpl::make("_masked_scale", rt, s)[0]; +} + +Tensor checkpoint_cudnn_convolution(at::Tensor const& a, at::Tensor const& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, long f, bool g, bool h) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution(vec.at(0), vec.at(1), c_, d_, e_, f, g, h)}; + }; + strongs s = {from_tensor(a), from_tensor(b)}; + return CheckpointTensorImpl::make("cudnn_convolution", rt, s)[0]; +} + +Tensor checkpoint_cudnn_convolution_transpose(at::Tensor const& a, at::Tensor const& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution_transpose(vec.at(0), vec.at(1), c_, d_, e_, f_, g, h, i)}; + }; + strongs s = {from_tensor(a), from_tensor(b)}; + return CheckpointTensorImpl::make("cudnn_convolution_transpose", rt, s)[0]; +} + +std::tuple checkpoint_cudnn_convolution_backward(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i, std::array j) { + std::vector d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::cudnn_convolution_backward(vec.at(0), vec.at(1), vec.at(2), d_, e_, f_, g, h, i, j); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + strongs s = {from_tensor(a), from_tensor(b), from_tensor(c)}; + auto ret = CheckpointTensorImpl::make("cudnn_convolution_backward", rt, s); + return {ret[0], ret[1]}; +} + +std::tuple checkpoint_cudnn_convolution_transpose_backward(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, c10::ArrayRef g, long h, bool i, bool j, std::array k) { + std::vector d_ = d.vec(), e_ = e.vec(), f_ = f.vec(), g_ = g.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::cudnn_convolution_transpose_backward(vec.at(0), vec.at(1), vec.at(2), d_, e_, f_, g_, h, i, j, k); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + strongs s = {from_tensor(a), from_tensor(b), from_tensor(c)}; + auto ret = CheckpointTensorImpl::make("cudnn_convolution_transpose_backward", rt, s); + return {ret[0], ret[1]}; +} + +Tensor checkpoint_cudnn_convolution_backward_input(c10::ArrayRef a, at::Tensor const& b, at::Tensor const& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { + std::vector a_ = a.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution_backward_input(a_, vec.at(0), vec.at(1), d_, e_, f_, g, h, i)}; + }; + strongs s = {from_tensor(b), from_tensor(c)}; + return CheckpointTensorImpl::make("cudnn_convolution_backward_input", rt, s)[0]; +} + +Tensor checkpoint_cudnn_convolution_transpose_backward_input(at::Tensor const& a, at::Tensor const& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, long f, bool g, bool h) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution_transpose_backward_input(vec.at(0), vec.at(1), c_, d_, e_, f, g, h)}; + }; + strongs s = {from_tensor(a), from_tensor(b)}; + return CheckpointTensorImpl::make("cudnn_convolution_transpose_backward_input", rt, s)[0]; +} + +Tensor checkpoint_cudnn_convolution_backward_weight(c10::ArrayRef a, at::Tensor const& b, at::Tensor const& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { + std::vector a_ = a.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution_backward_weight(a_, vec.at(0), vec.at(1), d_, e_, f_, g, h, i)}; + }; + strongs s = {from_tensor(b), from_tensor(c)}; + return CheckpointTensorImpl::make("cudnn_convolution_backward_weight", rt, s)[0]; +} + +Tensor checkpoint_cudnn_convolution_transpose_backward_weight(c10::ArrayRef a, at::Tensor const& b, at::Tensor const& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { + std::vector a_ = a.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cudnn_convolution_transpose_backward_weight(a_, vec.at(0), vec.at(1), d_, e_, f_, g, h, i)}; + }; + strongs s = {from_tensor(b), from_tensor(c)}; + return CheckpointTensorImpl::make("cudnn_convolution_transpose_backward_weight", rt, s)[0]; } }} diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 4561891d82f..683ef1c5dbc 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -179,6 +179,7 @@ variants: function dispatch: CUDA: masked_scale_cuda + Checkpoint: checkpoint__masked_scale - func: _sobol_engine_draw(Tensor quasi, int n, Tensor sobolstate, int dimension, int num_generated, ScalarType? dtype) -> (Tensor, Tensor) @@ -225,6 +226,10 @@ use_c10_dispatcher: full variants: function, method supports_named_tensor: True + dispatch: + CUDA: abs + CPU: abs + Checkpoint: checkpoint_abs - func: abs_(Tensor(a!) self) -> Tensor(a!) variants: function, method @@ -304,6 +309,7 @@ SparseCPU: add_sparse_ SparseCUDA: add_sparse_ MkldnnCPU: mkldnn_add_ + Checkpoint: checkpoint_add_ supports_named_tensor: True - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) @@ -426,6 +432,7 @@ CPU: as_strided_tensorimpl CUDA: as_strided_tensorimpl QuantizedCPU: as_strided_qtensorimpl + Checkpoint: checkpoint_as_strided device_guard: False supports_named_tensor: True @@ -537,6 +544,7 @@ dispatch: CPU: binary_cross_entropy_cpu CUDA: binary_cross_entropy_cuda + Checkpoint: checkpoint_binary_cross_entropy - func: binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -544,6 +552,7 @@ dispatch: CPU: binary_cross_entropy_out_cpu CUDA: binary_cross_entropy_out_cuda + Checkpoint: checkpoint_binary_cross_entropy_out - func: binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor python_module: nn @@ -551,6 +560,7 @@ dispatch: CPU: binary_cross_entropy_backward_cpu CUDA: binary_cross_entropy_backward_cuda + Checkpoint: checkpoint_binary_cross_entropy_backward - func: binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn @@ -558,6 +568,7 @@ dispatch: CPU: binary_cross_entropy_backward_out_cpu CUDA: binary_cross_entropy_backward_out_cuda + Checkpoint: checkpoint_binary_cross_entropy_backward_out - func: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor variants: function @@ -763,6 +774,10 @@ - func: constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> Tensor variants: function + dispatch: + CPU: constant_pad_nd + CUDA: constant_pad_nd + Checkpoint: checkpoint_constant_pad_nd - func: contiguous(Tensor self, *, MemoryFormat memory_format=contiguous_format) -> Tensor variants: method @@ -865,11 +880,13 @@ - func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor) dispatch: CUDA: cudnn_batch_norm + Checkpoint: checkpoint_cudnn_batch_norm # NB: You can only use this if you used cudnn_batch_norm training=True - func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor) dispatch: CUDA: cudnn_batch_norm_backward + Checkpoint: cudnn_batch_norm_backward - func: cudnn_convolution.deprecated(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: @@ -878,18 +895,22 @@ - func: cudnn_convolution(Tensor self, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: CUDA: cudnn_convolution + Checkpoint: checkpoint_cudnn_convolution - func: cudnn_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: CUDA: cudnn_convolution_backward_input + Checkpoint: checkpoint_cudnn_convolution_backward_input - func: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[2] output_mask) -> (Tensor, Tensor) dispatch: CUDA: cudnn_convolution_backward + Checkpoint: checkpoint_cudnn_convolution_backward - func: cudnn_convolution_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: CUDA: cudnn_convolution_backward_weight + Checkpoint: checkpoint_cudnn_convolution_backward_weight - func: cudnn_convolution_transpose.deprecated(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: @@ -898,20 +919,24 @@ - func: cudnn_convolution_transpose(Tensor self, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: CUDA: cudnn_convolution_transpose + Checkpoint: checkpoint_cudnn_convolution_transpose # NB: output_padding not strictly needed here, but it's helpful for the float # backwards - func: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[2] output_mask) -> (Tensor, Tensor) dispatch: CUDA: cudnn_convolution_transpose_backward + Checkpoint: checkpoint_cudnn_convolution_transpose_backward - func: cudnn_convolution_transpose_backward_input(Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: CUDA: cudnn_convolution_transpose_backward_input + Checkpoint: checkpoint_cudnn_convolution_transpose_backward_input - func: cudnn_convolution_transpose_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: CUDA: cudnn_convolution_transpose_backward_weight + Checkpoint: checkpoint_cudnn_convolution_transpose_backward_weight # NB: input is special cased in a way I don't quite understand - func: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output @@ -1038,6 +1063,7 @@ CUDA: div SparseCPU: div_sparse SparseCUDA: div_sparse + Checkpoint: checkpoint_div supports_named_tensor: True - func: div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) @@ -1047,6 +1073,7 @@ CUDA: div_ SparseCPU: div_sparse_ SparseCUDA: div_sparse_ + Checkpoint: checkpoint_div_ supports_named_tensor: True - func: div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) @@ -1082,9 +1109,17 @@ - func: embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor use_c10_dispatcher: full + dispatch: + CPU: embedding + CUDA: embedding + Checkpoint: checkpoint_embedding - func: embedding_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor use_c10_dispatcher: full + dispatch: + CPU: embedding_backward + CUDA: embedding_backward + Checkpoint: checkpoint_embedding_backward - func: embedding_dense_backward(Tensor grad_output, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor use_c10_dispatcher: full From d09fb752417fe3720642ed7fab95a7f8e4e8d190 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Wed, 25 Mar 2020 17:16:33 -0700 Subject: [PATCH 04/42] make resnet pass again, adding more logs for logging simulator (#17) * commit * add more overloads * fix log * save --- aten/src/ATen/CheckpointTensorImpl.cpp | 51 +- aten/src/ATen/CheckpointTensorImpl.h | 51 +- aten/src/ATen/native/Activation.cpp | 12 + aten/src/ATen/native/Checkpoint.cpp | 512 ++++++++++++++++++--- aten/src/ATen/native/native_functions.yaml | 71 ++- c10/core/TensorImpl.h | 3 + tools/autograd/templates/Functions.cpp | 12 - 7 files changed, 605 insertions(+), 107 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index a3ddf28a5ef..085fe73ee59 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -25,12 +25,13 @@ void DTRLog(const std::string& str) { logger.out << str << std::endl; } -int CheckpointTensorCell::counter = 0; +int CheckpointTensorImpl::counter = 0; Tensors make_raw(const rematerialize_function_t& remat, const strongs& input_values) { std::vector input; for (const strong& s: input_values) { + CHECK(!s->t.key_set().has(DispatchKey::CheckpointTensorId)); input.push_back(s->t); } auto output = remat(input); @@ -43,41 +44,55 @@ Tensors make_raw(const rematerialize_function_t& remat, Tensors CheckpointTensorImpl::make(const std::string& name, const rematerialize_function_t& remat, - const strongs& input_values) { + const Tensors& input) { + strongs input_values; + std::string arg = name + "("; + for (const Tensor& t: input) { + auto ft = from_tensor(t); + input_values.push_back(std::get<0>(ft)); + arg += std::get<1>(ft); + arg += ", "; + } + arg += ")"; + std::string log = "("; Tensors ret = make_raw(remat, input_values); - std::string log("("); for (const Tensor& t: ret) { - log += cell_from_tensor(t)->value->name(); + log += get_cpti(t)->counter_name(); log += ", "; } log += ") = "; - log += name; - log += "("; - for (const strong& s: input_values) { - log += s->name(); - log += ", "; - } - log += ")"; + log += arg; DTRLog(log); return ret; } void CheckpointTensorImpl::mutate(const std::string& name, const mutate_function_t& mutate, - const Tensors& inputs) { + const Tensors& inputs, + const std::vector& mutate_idx) { auto remat = [=](const Tensors& t) -> Tensors { - auto t0 = t[0].clone(); Tensors new_input_values = t; - new_input_values[0] = t0; + for (size_t idx: mutate_idx) { + new_input_values[idx] = t[idx].clone(); + } mutate(new_input_values); - return {t0}; + return new_input_values; }; strongs input_values; + std::string log = name; + log += "("; for (const Tensor& t : inputs) { - input_values.push_back(from_tensor(t)); + auto ft = from_tensor(t); + log += std::get<1>(ft); + log += ", "; + input_values.push_back(std::get<0>(ft)); + } + log += ")"; + DTRLog(log); + auto modified = make_raw(remat, input_values); + for (size_t idx: mutate_idx) { + cell_from_tensor(inputs[idx])->value = cell_from_tensor(modified[idx])->value; } - auto modified = make_raw(remat, input_values)[0]; - cell_from_tensor(inputs[0])->value = cell_from_tensor(modified)->value; } } diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index 747590cdb71..de83e60ab14 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -28,14 +28,6 @@ void DTRLog(const std::string& str); struct CAFFE2_API CheckpointTensorCell : intrusive_ptr_target { Tensor t; explicit CheckpointTensorCell(const Tensor& t) : t(t.detach()) { } - int id = gen_counter(); - static int counter; - static int gen_counter() { - return counter++; - } - std::string name() { - return std::string("x") + std::to_string(id); - } }; struct CAFFE2_API CheckpointTensorImplCell : intrusive_ptr_target { @@ -57,12 +49,21 @@ using rematerialize_function_t = std::function; using mutate_function_t = std::function; inline DispatchKeySet convert_key_set(const DispatchKeySet& t) { + CHECK(!t.has(DispatchKey::CheckpointTensorId)); auto ret = t.add(DispatchKey::CheckpointTensorId); CHECK(!ret.has(DispatchKey::VariableTensorId)); return ret; } struct CAFFE2_API CheckpointTensorImpl : TensorImpl { + int id = gen_counter(); + static int counter; + static int gen_counter() { + return counter++; + } + std::string counter_name() { + return std::string("x") + std::to_string(id); + } intrusive_ptr ref; void release_resources() final { ref.reset(); @@ -73,10 +74,34 @@ struct CAFFE2_API CheckpointTensorImpl : TensorImpl { explicit CheckpointTensorImpl(const Tensor& t) : CheckpointTensorImpl(intrusive_ptr::make(t)) { } static Tensors make(const std::string& name, const rematerialize_function_t& remat, - const strongs& input_values); + const Tensors& inputs); + // mutate_idx indicate which of the inputs will get mutated. static void mutate(const std::string& name, const mutate_function_t& mutate, - const Tensors& input_values); + const Tensors& inputs, + const std::vector& mutate_idx); + intrusive_ptr shallow_copy_and_detach(const VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override { + return intrusive_ptr::make(ref); + } + int64_t dim() const override { + return ref->value->t.dim(); + } + int64_t numel() const override { + return ref->value->t.numel(); + } + IntArrayRef sizes() const override { + return ref->value->t.sizes(); + } + int64_t size(int64_t d) const override { + return ref->value->t.size(d); + } + IntArrayRef strides() const override { + return ref->value->t.strides(); + } + bool has_storage() const override { + return false; + } }; inline CheckpointTensorImpl* get_cpti(const Tensor& t) { @@ -85,12 +110,12 @@ inline CheckpointTensorImpl* get_cpti(const Tensor& t) { return cpti; } -inline strong from_tensor(const Tensor& t) { +inline std::tuple from_tensor(const Tensor& t) { auto* cpt = dynamic_cast(t.unsafeGetTensorImpl()); if(cpt != nullptr) { - return cpt->ref->value; + return {cpt->ref->value, cpt->counter_name()}; } else { - return get_cpti(native::checkpoint(t))->ref->value; + return from_tensor(native::checkpoint(t)); } } diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp index c97aa5a941b..a5e64344946 100644 --- a/aten/src/ATen/native/Activation.cpp +++ b/aten/src/ATen/native/Activation.cpp @@ -712,4 +712,16 @@ Tensor& log_sigmoid_backward_out_cpu( DEFINE_DISPATCH(GeluKernel); DEFINE_DISPATCH(GeluBackwardKernel); +Tensor slice_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { + auto grad_input = at::zeros(input_sizes, grad.options()); + grad_input.slice(dim, start, end, step).copy_(grad); + return grad_input; +} + +Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { + auto grad_input = at::zeros(input_sizes, grad.options()); + grad_input.select(dim, index).copy_(grad); + return grad_input; +} + }} // namespace at::native diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index 15099b1a497..f09a7f7e4e7 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -18,13 +18,12 @@ bool is_checkpoint(const Tensor& t) { return cpti != nullptr; } -Tensor checkpoint_add(at::Tensor const& a, at::Tensor const& b, c10::Scalar c) { +Tensor checkpoint_add(const Tensor& a, const Tensor& b, c10::Scalar c) { rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { return {at::add(vec.at(0), vec.at(1), c)}; }; - strongs s = {from_tensor(a), from_tensor(b)}; - return CheckpointTensorImpl::make("add", rt, s)[0]; + return CheckpointTensorImpl::make("add", rt, {a, b})[0]; } Tensor& checkpoint_add_(Tensor& a, const Tensor& b, Scalar c) { @@ -32,26 +31,24 @@ Tensor& checkpoint_add_(Tensor& a, const Tensor& b, Scalar c) { [=](const Tensors& vec) { vec.at(0).add_(vec.at(1), c); }; - CheckpointTensorImpl::mutate("add_", mt, {a, b}); + CheckpointTensorImpl::mutate("add_", mt, {a, b}, {0}); return a; } -Tensor checkpoint_abs(at::Tensor const& a) { +Tensor checkpoint_abs(const Tensor& a) { rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { return {at::abs(vec.at(0))}; }; - strongs s = {from_tensor(a)}; - return CheckpointTensorImpl::make("abs", rt, s)[0]; + return CheckpointTensorImpl::make("abs", rt, {a})[0]; } -Tensor checkpoint_div(at::Tensor const& a, at::Tensor const& b) { +Tensor checkpoint_div(const Tensor& a, const Tensor& b) { rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { return {at::div(vec.at(0), vec.at(1))}; }; - strongs s = {from_tensor(a), from_tensor(b)}; - return CheckpointTensorImpl::make("div", rt, s)[0]; + return CheckpointTensorImpl::make("div", rt, {a, b})[0]; } Tensor& checkpoint_div_(Tensor& a, const Tensor& b) { @@ -59,7 +56,7 @@ Tensor& checkpoint_div_(Tensor& a, const Tensor& b) { [=](const Tensors& vec) { vec.at(0).div_(vec.at(1)); }; - CheckpointTensorImpl::mutate("div_", mt, {a, b}); + CheckpointTensorImpl::mutate("div_", mt, {a, b}, {0}); return a; } @@ -69,179 +66,568 @@ Tensor checkpoint_constant_pad_nd(Tensor const& a, c10::ArrayRef b, c10::S [=](const Tensors& vec) -> Tensors { return {at::constant_pad_nd(vec.at(0), b_, c)}; }; - strongs s = {from_tensor(a)}; - return CheckpointTensorImpl::make("constant_pad_nd", rt, s)[0]; + return CheckpointTensorImpl::make("constant_pad_nd", rt, {a})[0]; } -Tensor checkpoint_binary_cross_entropy(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, long d) { +Tensor checkpoint_binary_cross_entropy(const Tensor& a, const Tensor& b, const Tensor& c, long d) { rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { return {at::binary_cross_entropy(vec.at(0), vec.at(1), vec.at(2), d)}; }; - strongs s = {from_tensor(a), from_tensor(b), from_tensor(c)}; - return CheckpointTensorImpl::make("binary_cross_entropy", rt, s)[0]; + return CheckpointTensorImpl::make("binary_cross_entropy", rt, {a, b, c})[0]; } -Tensor& checkpoint_binary_cross_entropy_out(at::Tensor& a, at::Tensor const& b, at::Tensor const& c, at::Tensor const& d, long e) { +Tensor& checkpoint_binary_cross_entropy_out(Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, long e) { mutate_function_t mt = [=](const Tensors& vec) { Tensor self = vec.at(0); at::binary_cross_entropy_out(self, vec.at(1), vec.at(2), vec.at(3), e); }; - CheckpointTensorImpl::mutate("binary_cross_entropy_out", mt, {a, b, c, d}); + CheckpointTensorImpl::mutate("binary_cross_entropy_out", mt, {a, b, c, d}, {0}); return a; } -Tensor checkpoint_binary_cross_entropy_backward(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, at::Tensor const& d, long e) { +Tensor checkpoint_binary_cross_entropy_backward(const Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, long e) { rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { return {at::binary_cross_entropy_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), e)}; }; - strongs s = {from_tensor(a), from_tensor(b), from_tensor(c), from_tensor(d)}; - return CheckpointTensorImpl::make("binary_cross_entropy_backward", rt, s)[0]; + return CheckpointTensorImpl::make("binary_cross_entropy_backward", rt, {a, b, c, d})[0]; } -Tensor& checkpoint_binary_cross_entropy_backward_out(at::Tensor& a, at::Tensor const& b, at::Tensor const& c, at::Tensor const& d, at::Tensor const& e, long f) { +Tensor& checkpoint_binary_cross_entropy_backward_out(Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, const Tensor& e, long f) { mutate_function_t mt = [=](const Tensors& vec) { Tensor self = vec.at(0); at::binary_cross_entropy_backward_out(self, vec.at(1), vec.at(2), vec.at(3), vec.at(4), f); }; - CheckpointTensorImpl::mutate("binary_cross_entropy_backward_out", mt, {a, b, c, d, e}); + CheckpointTensorImpl::mutate("binary_cross_entropy_backward_out", mt, {a, b, c, d, e}, {0}); return a; } -Tensor checkpoint_embedding(at::Tensor const& a, at::Tensor const& b, long c, bool d, bool e) { +Tensor checkpoint_embedding(const Tensor& a, const Tensor& b, long c, bool d, bool e) { rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { return {at::embedding(vec.at(0), vec.at(1), c, d, e)}; }; - strongs s = {from_tensor(a), from_tensor(b)}; - return CheckpointTensorImpl::make("embedding", rt, s)[0]; + return CheckpointTensorImpl::make("embedding", rt, {a, b})[0]; } -Tensor checkpoint_embedding_backward(at::Tensor const& a, at::Tensor const& b, long c, long d, bool e, bool f) { +Tensor checkpoint_embedding_backward(const Tensor& a, const Tensor& b, long c, long d, bool e, bool f) { rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { return {at::embedding_backward(vec.at(0), vec.at(1), c, d, e, f)}; }; - strongs s = {from_tensor(a), from_tensor(b)}; - return CheckpointTensorImpl::make("embedding", rt, s)[0]; + return CheckpointTensorImpl::make("embedding", rt, {a, b})[0]; } -std::tuple -checkpoint_cudnn_batch_norm(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, at::Tensor const& d, at::Tensor const& e, bool f, double g, double h) { +std::tuple +checkpoint_cudnn_batch_norm(const Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, const Tensor& e, bool f, double g, double h) { rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { auto ret = at::cudnn_batch_norm(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), f, g, h); return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret), std::get<3>(ret)}; }; - strongs s = {from_tensor(a), from_tensor(b), from_tensor(c), from_tensor(d), from_tensor(e)}; - auto ret = CheckpointTensorImpl::make("cudnn_batch_norm", rt, s)[0]; + auto ret = CheckpointTensorImpl::make("cudnn_batch_norm", rt, {a, b, c, d, e}); return {ret[0], ret[1], ret[2], ret[3]}; } -Tensor checkpoint_as_strided(at::Tensor const& a, c10::ArrayRef b, c10::ArrayRef c, c10::optional d) { +std::tuple checkpoint_cudnn_batch_norm_backward(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, at::Tensor const& d, at::Tensor const& e, at::Tensor const& f, at::Tensor const& g, double h, at::Tensor const& i) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::cudnn_batch_norm_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), vec.at(5), vec.at(6), h, vec.at(7)); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("cudnn_batch_norm_backward", rt, {a, b, c, d, e, f, g, i}); + return {ret[0], ret[1], ret[2]}; +} + +Tensor checkpoint_as_strided(const Tensor& a, c10::ArrayRef b, c10::ArrayRef c, c10::optional d) { std::vector b_ = b.vec(), c_ = c.vec(); rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { return {at::as_strided(vec.at(0), b_, c_, d)}; }; - strongs s = {from_tensor(a)}; - return CheckpointTensorImpl::make("as_strided", rt, s)[0]; + return CheckpointTensorImpl::make("as_strided", rt, {a})[0]; } -Tensor checkpoint__masked_scale(at::Tensor const& a, at::Tensor const& b, double c) { +Tensor checkpoint__masked_scale(const Tensor& a, const Tensor& b, double c) { rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { return {at::_masked_scale(vec.at(0), vec.at(1), c)}; }; - strongs s = {from_tensor(a), from_tensor(b)}; - return CheckpointTensorImpl::make("_masked_scale", rt, s)[0]; + return CheckpointTensorImpl::make("_masked_scale", rt, {a, b})[0]; } -Tensor checkpoint_cudnn_convolution(at::Tensor const& a, at::Tensor const& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, long f, bool g, bool h) { +Tensor checkpoint_cudnn_convolution(const Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, long f, bool g, bool h) { std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(); rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { return {at::cudnn_convolution(vec.at(0), vec.at(1), c_, d_, e_, f, g, h)}; }; - strongs s = {from_tensor(a), from_tensor(b)}; - return CheckpointTensorImpl::make("cudnn_convolution", rt, s)[0]; + return CheckpointTensorImpl::make("cudnn_convolution", rt, {a, b})[0]; } -Tensor checkpoint_cudnn_convolution_transpose(at::Tensor const& a, at::Tensor const& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { +Tensor checkpoint_cudnn_convolution_transpose(const Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { return {at::cudnn_convolution_transpose(vec.at(0), vec.at(1), c_, d_, e_, f_, g, h, i)}; }; - strongs s = {from_tensor(a), from_tensor(b)}; - return CheckpointTensorImpl::make("cudnn_convolution_transpose", rt, s)[0]; + return CheckpointTensorImpl::make("cudnn_convolution_transpose", rt, {a, b})[0]; } -std::tuple checkpoint_cudnn_convolution_backward(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i, std::array j) { +std::tuple checkpoint_cudnn_convolution_backward(const Tensor& a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i, std::array j) { std::vector d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { auto ret = at::cudnn_convolution_backward(vec.at(0), vec.at(1), vec.at(2), d_, e_, f_, g, h, i, j); return {std::get<0>(ret), std::get<1>(ret)}; }; - strongs s = {from_tensor(a), from_tensor(b), from_tensor(c)}; - auto ret = CheckpointTensorImpl::make("cudnn_convolution_backward", rt, s); + auto ret = CheckpointTensorImpl::make("cudnn_convolution_backward", rt, {a, b, c}); return {ret[0], ret[1]}; } -std::tuple checkpoint_cudnn_convolution_transpose_backward(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, c10::ArrayRef g, long h, bool i, bool j, std::array k) { +std::tuple checkpoint_cudnn_convolution_transpose_backward(const Tensor& a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, c10::ArrayRef g, long h, bool i, bool j, std::array k) { std::vector d_ = d.vec(), e_ = e.vec(), f_ = f.vec(), g_ = g.vec(); rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { auto ret = at::cudnn_convolution_transpose_backward(vec.at(0), vec.at(1), vec.at(2), d_, e_, f_, g_, h, i, j, k); return {std::get<0>(ret), std::get<1>(ret)}; }; - strongs s = {from_tensor(a), from_tensor(b), from_tensor(c)}; - auto ret = CheckpointTensorImpl::make("cudnn_convolution_transpose_backward", rt, s); + auto ret = CheckpointTensorImpl::make("cudnn_convolution_transpose_backward", rt, {a, b, c}); return {ret[0], ret[1]}; } -Tensor checkpoint_cudnn_convolution_backward_input(c10::ArrayRef a, at::Tensor const& b, at::Tensor const& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { +Tensor checkpoint_cudnn_convolution_backward_input(c10::ArrayRef a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { std::vector a_ = a.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { return {at::cudnn_convolution_backward_input(a_, vec.at(0), vec.at(1), d_, e_, f_, g, h, i)}; }; - strongs s = {from_tensor(b), from_tensor(c)}; - return CheckpointTensorImpl::make("cudnn_convolution_backward_input", rt, s)[0]; + return CheckpointTensorImpl::make("cudnn_convolution_backward_input", rt, {b, c})[0]; } -Tensor checkpoint_cudnn_convolution_transpose_backward_input(at::Tensor const& a, at::Tensor const& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, long f, bool g, bool h) { +Tensor checkpoint_cudnn_convolution_transpose_backward_input(const Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, long f, bool g, bool h) { std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(); rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { return {at::cudnn_convolution_transpose_backward_input(vec.at(0), vec.at(1), c_, d_, e_, f, g, h)}; }; - strongs s = {from_tensor(a), from_tensor(b)}; - return CheckpointTensorImpl::make("cudnn_convolution_transpose_backward_input", rt, s)[0]; + return CheckpointTensorImpl::make("cudnn_convolution_transpose_backward_input", rt, {a, b})[0]; } -Tensor checkpoint_cudnn_convolution_backward_weight(c10::ArrayRef a, at::Tensor const& b, at::Tensor const& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { +Tensor checkpoint_cudnn_convolution_backward_weight(c10::ArrayRef a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { std::vector a_ = a.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { return {at::cudnn_convolution_backward_weight(a_, vec.at(0), vec.at(1), d_, e_, f_, g, h, i)}; }; - strongs s = {from_tensor(b), from_tensor(c)}; - return CheckpointTensorImpl::make("cudnn_convolution_backward_weight", rt, s)[0]; + return CheckpointTensorImpl::make("cudnn_convolution_backward_weight", rt, {b, c})[0]; } -Tensor checkpoint_cudnn_convolution_transpose_backward_weight(c10::ArrayRef a, at::Tensor const& b, at::Tensor const& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { +Tensor checkpoint_cudnn_convolution_transpose_backward_weight(c10::ArrayRef a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, long g, bool h, bool i) { std::vector a_ = a.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { return {at::cudnn_convolution_transpose_backward_weight(a_, vec.at(0), vec.at(1), d_, e_, f_, g, h, i)}; }; - strongs s = {from_tensor(b), from_tensor(c)}; - return CheckpointTensorImpl::make("cudnn_convolution_transpose_backward_weight", rt, s)[0]; + return CheckpointTensorImpl::make("cudnn_convolution_transpose_backward_weight", rt, {b, c})[0]; +} + +Tensor checkpoint_relu(const Tensor& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::relu(vec.at(0))}; + }; + return CheckpointTensorImpl::make("relu", rt, {a})[0]; +} + +Tensor& checkpoint_relu_(Tensor& a) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).relu_(); + }; + CheckpointTensorImpl::mutate("relu_", mt, {a}, {0}); + return a; +} + +std::tuple checkpoint_max_pool2d_with_indices_out(Tensor& a, Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, c10::ArrayRef g, bool h) { + std::vector d_ = d.vec(), e_ = e.vec(), f_ = f.vec(), g_ = g.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0), b_ = vec.at(1); + at::max_pool2d_with_indices_out(a_, b_, vec.at(2), d_, e_, f_, g_, h); + }; + CheckpointTensorImpl::mutate("max_pool2d_with_indices_out", mt, {a, b, c}, {0, 1}); + return {a, b}; +} + +Tensor checkpoint_avg_pool2d(const Tensor& a, c10::ArrayRef b, c10::ArrayRef c, c10::ArrayRef d, bool e, bool f, c10::optional g) { + std::vector b_ = b.vec(), c_ = c.vec(), d_ = d.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::avg_pool2d(vec.at(0), b_, c_, d_, e, f, g)}; + }; + return CheckpointTensorImpl::make("avg_pool2d", rt, {a})[0]; +} + +Tensor checkpoint_avg_pool2d_backward(const Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, bool f, bool g, c10::optional h) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::avg_pool2d_backward(vec.at(0), vec.at(1), c_, d_, e_, f, g, h)}; + }; + return CheckpointTensorImpl::make("avg_pool2d_backward", rt, {a, b})[0]; +} + +Tensor& checkpoint_avg_pool2d_out(Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, bool f, bool g, c10::optional h) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::avg_pool2d_out(a_, vec.at(1), c_, d_, e_, f, g, h); + }; + CheckpointTensorImpl::mutate("avg_pool2d_out", mt, {a, b}, {0}); + return a; +} + +Tensor& checkpoint_avg_pool2d_backward_grad_input(Tensor& a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, bool g, bool h, c10::optional i) { + std::vector d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::avg_pool2d_backward_out(a_, vec.at(1), vec.at(2), d_, e_, f_, g, h, i); + }; + CheckpointTensorImpl::mutate("avg_pool2d_backward_grad_input", mt, {a, b, c}, {0}); + return a; +} + +std::tuple checkpoint_max_pool2d_with_indices(const Tensor& a, c10::ArrayRef b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, bool f) { + std::vector b_ = b.vec(), c_ = c.vec(), d_ = d.vec(), e_ = e.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::max_pool2d_with_indices(vec.at(0), b_, c_, d_, e_, f); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("max_pool2d_backward", rt, {a}); + return {ret[0], ret[1]}; +} + +Tensor& checkpoint_max_pool2d_with_indices_backward_grad_input(Tensor& a, const Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, c10::ArrayRef g, bool h, const Tensor& i) { + std::vector d_ = d.vec(), e_ = e.vec(), f_ = f.vec(), g_ = g.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::max_pool2d_with_indices_backward_out(a_, vec.at(1), vec.at(2), d_, e_, f_, g, h, vec.at(3)); + }; + CheckpointTensorImpl::mutate("max_pool2d_with_indices_backward_grad_input", mt, {a, b, c, i}, {0}); + return a; +} + +Tensor checkpoint_max_pool2d_with_indices_backward(const Tensor& a, const Tensor& b, c10::ArrayRef c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, bool g, const Tensor& h) { + std::vector c_ = c.vec(), d_ = d.vec(), e_ = e.vec(), f_ = f.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::max_pool2d_with_indices_backward(vec.at(0), vec.at(1), c_, d_, e_, f_, g, vec.at(2))}; + }; + return CheckpointTensorImpl::make("max_pool2d_with_indices_backward", rt, {a, b, h})[0]; +} + +Tensor checkpoint_view(const Tensor& a, c10::ArrayRef b) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {vec.at(0).view(b_)}; + }; + return CheckpointTensorImpl::make("view", rt, {a})[0]; +} + +Tensor checkpoint_ne_Scalar(const Tensor& a, c10::Scalar b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::ne(vec.at(0), b)}; + }; + return CheckpointTensorImpl::make("ne_Scalar", rt, {a})[0]; +} + +Tensor& checkpoint_ne_Scalar_out(Tensor& a, const Tensor& b, c10::Scalar c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::ne_out(a_, vec.at(1), c); + }; + CheckpointTensorImpl::mutate("ne_Scalar_out", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_ne_Tensor(const Tensor& a, const Tensor& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::ne(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("ne_Tensor", rt, {a, b})[0]; +} + +Tensor& checkpoint_ne_Tensor_out(Tensor& a, const Tensor& b, const Tensor& c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::ne_out(a_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("ne_Tensor_out", mt, {a, b, c}, {0}); + return a; +} + +Tensor checkpoint_eq_Scalar(const Tensor& a, c10::Scalar b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::eq(vec.at(0), b)}; + }; + return CheckpointTensorImpl::make("eq_Scalar", rt, {a})[0]; +} + +Tensor& checkpoint_eq_Scalar_out(Tensor& a, const Tensor& b, c10::Scalar c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::eq_out(a_, vec.at(1), c); + }; + CheckpointTensorImpl::mutate("eq_Scalar_out", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_eq_Tensor(const Tensor& a, const Tensor& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::eq(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("eq_Tensor", rt, {a, b})[0]; +} + +Tensor& checkpoint_eq_Tensor_out(Tensor& a, const Tensor& b, const Tensor& c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::eq_out(a_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("eq_Tensor_out", mt, {a, b, c}, {0}); + return a; +} + +Tensor checkpoint_addmm(const Tensor& a, const Tensor& b, const Tensor& c, c10::Scalar d, c10::Scalar e) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::addmm(vec.at(0), vec.at(1), vec.at(2), d, e)}; + }; + return CheckpointTensorImpl::make("addmm", rt, {a, b, c})[0]; +} + +Tensor& checkpoint_addmm_out(Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, c10::Scalar e, c10::Scalar f) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::addmm_out(a_, vec.at(1), vec.at(2), d, e, f); + }; + CheckpointTensorImpl::mutate("addmm_out", mt, {a, b, c}, {0}); + return a; +} + +Tensor& checkpoint_addmm_(Tensor& a, const Tensor& b, const Tensor& c, c10::Scalar d, c10::Scalar e) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + a.addmm_(vec.at(1), vec.at(2), d, e); + }; + CheckpointTensorImpl::mutate("addmm_", mt, {a, b, c}, {0}); + return a; +} + +Tensor checkpoint_sigmoid(const Tensor& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sigmoid(vec.at(0))}; + }; + return CheckpointTensorImpl::make("sigmoid", rt, {a})[0]; +} + +Tensor& checkpoint_sigmoid_(Tensor& a) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + a.sigmoid_(); + }; + CheckpointTensorImpl::mutate("sigmoid_", mt, {a}, {0}); + return a; +} + +Tensor checkpoint__log_softmax(const Tensor& a, long b, bool c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::_log_softmax(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("_log_softmax", rt, {a})[0]; +} + +Tensor checkpoint__log_softmax_backward_data(const Tensor& a, const Tensor& b, long c, const Tensor& d) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::_log_softmax_backward_data(vec.at(0), vec.at(1), c, vec.at(2))}; + }; + return CheckpointTensorImpl::make("_log_softmax_backward_data", rt, {a, b, d})[0]; +} + +std::tuple checkpoint_nll_loss_forward(const Tensor& a, const Tensor& b, const Tensor& c, long d, long e) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::nll_loss_forward(vec.at(0), vec.at(1), vec.at(2), d, e); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("nll_loss_forward", rt, {a, b, c}); + return {ret[0], ret[1]}; +} + +std::tuple checkpoint_nll_loss_forward_out(Tensor& a, Tensor& b, const Tensor& c, const Tensor& d, const Tensor& e, long f, long g) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + Tensor b_ = vec.at(1); + at::nll_loss_forward_out(a_, b_, vec.at(2), vec.at(3), vec.at(4), f, g); + }; + CheckpointTensorImpl::mutate("nll_loss_forward_out", mt, {a, b, c, d, e}, {0, 1}); + return {a, b}; +} + +Tensor checkpoint_nll_loss_backward(const Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, long e, long f, const Tensor& g) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::nll_loss_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), e, f, vec.at(4))}; + }; + return CheckpointTensorImpl::make("nll_loss_backward", rt, {a, b, c, d, g})[0]; +} + +Tensor& checkpoint_nll_loss_backward_grad_input(Tensor& a, const Tensor& b, const Tensor& c, const Tensor& d, const Tensor& e, long f, long g, const Tensor& h) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::nll_loss_backward_out(a_, vec.at(1), vec.at(2), vec.at(3), vec.at(4), f, g, vec.at(5)); + }; + CheckpointTensorImpl::mutate("nll_loss_backward_grad_input", mt, {a, b, c, d, e, h}, {0}); + return a; +} + +Tensor checkpoint_mm(const Tensor& a, const Tensor& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::mm(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("mm", rt, {a, b})[0]; +} + +Tensor& checkpoint_mm_out(Tensor& a, const Tensor& b, const Tensor& c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::mm_out(a_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("mm_out", mt, {a, b, c}, {0}); + return a; +} + +Tensor checkpoint_sum(const Tensor& a, c10::optional b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sum(vec.at(0), b)}; + }; + return CheckpointTensorImpl::make("sum", rt, {a})[0]; +} + +Tensor checkpoint_sum_dim_IntList(const Tensor& a, c10::ArrayRef b, bool c, c10::optional d) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sum(vec.at(0), b_, c, d)}; + }; + return CheckpointTensorImpl::make("sum_dim_IntList", rt, {a})[0]; +} + +Tensor checkpoint_threshold(const Tensor& a, c10::Scalar b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::threshold(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("threshold", rt, {a})[0]; +} + +Tensor& checkpoint_threshold_(Tensor& a, c10::Scalar b, c10::Scalar c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::threshold_(a_, b, c); + }; + CheckpointTensorImpl::mutate("threshold_", mt, {a}, {0}); + return a; +} + +Tensor& checkpoint_threshold_out(Tensor& a, const Tensor& b, c10::Scalar c, c10::Scalar d) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::threshold_out(a_, b, c, d); + }; + CheckpointTensorImpl::mutate("threshold_out", mt, {a}, {0}); + return a; +} + +Tensor checkpoint_threshold_backward(const Tensor& a, const Tensor& b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::threshold_backward(vec.at(0), vec.at(1), c)}; + }; + return CheckpointTensorImpl::make("threshold_backward", rt, {a, b})[0]; +} + +Tensor checkpoint_select(const Tensor& a, long b, long c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::select(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("select", rt, {a})[0]; +} + +Tensor checkpoint_select_backward(const Tensor& a, c10::ArrayRef b, long c, long d) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::select_backward(vec.at(0), b_, c, d)}; + }; + return CheckpointTensorImpl::make("select_backward", rt, {a})[0]; +} + +Tensor checkpoint_slice(const Tensor& a, long b, long c, long d, long e) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::slice(vec.at(0), b, c, d, e)}; + }; + return CheckpointTensorImpl::make("slice", rt, {a})[0]; +} + +Tensor checkpoint_slice_backward(const Tensor& a, c10::ArrayRef b, long c, long d, long e, long f) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::slice_backward(vec.at(0), b_, c, d, e, f)}; + }; + return CheckpointTensorImpl::make("slice_backward", rt, {a})[0]; +} + +Tensor& checkpoint_zero_(Tensor& a) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).zero_(); + }; + CheckpointTensorImpl::mutate("zero_", mt, {a}, {0}); + return a; } }} diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 683ef1c5dbc..e9b6b23f4a6 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -886,7 +886,7 @@ - func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor) dispatch: CUDA: cudnn_batch_norm_backward - Checkpoint: cudnn_batch_norm_backward + Checkpoint: checkpoint_cudnn_batch_norm_backward - func: cudnn_convolution.deprecated(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor dispatch: @@ -1747,12 +1747,14 @@ dispatch: CPU: log_softmax_cpu CUDA: log_softmax_cuda + Checkpoint: checkpoint__log_softmax - func: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor use_c10_dispatcher: full dispatch: CPU: log_softmax_backward_cpu CUDA: log_softmax_backward_cuda + Checkpoint: checkpoint__log_softmax_backward_data - func: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor supports_named_tensor: True @@ -1985,6 +1987,7 @@ CUDA: legacy::cuda::_th_mm SparseCPU: _sparse_mm SparseCUDA: _sparse_mm + Checkpoint: checkpoint_mm supports_named_tensor: True - func: mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) @@ -1993,6 +1996,7 @@ CUDA: legacy::cuda::_th_mm_out SparseCPU: _sparse_mm_out SparseCUDA: _sparse_mm_out + Checkpoint: checkpoint_mm_out supports_named_tensor: True - func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor @@ -2382,6 +2386,7 @@ CUDA: relu MkldnnCPU: mkldnn_relu QuantizedCPU: quantized_relu + Checkpoint: checkpoint_relu supports_named_tensor: True - func: relu_(Tensor(a!) self) -> Tensor(a!) @@ -2392,6 +2397,7 @@ CUDA: relu_ MkldnnCPU: mkldnn_relu_ QuantizedCPU: quantized_relu_ + Checkpoint: checkpoint_relu_ - func: prelu(Tensor self, Tensor weight) -> Tensor use_c10_dispatcher: full @@ -2452,6 +2458,18 @@ variants: function, method device_guard: False supports_named_tensor: True + dispatch: + CPU: select + CUDA: select + Checkpoint: checkpoint_select + +- func: select_backward(Tensor grad, int[] sizes, int dim, int index) -> Tensor + variants: function + device_guard: False + dispatch: + CPU: select_backward + CUDA: select_backward + Checkpoint: checkpoint_select_backward - func: selu(Tensor self) -> Tensor use_c10_dispatcher: full @@ -2473,6 +2491,7 @@ CUDA: sigmoid QuantizedCPU: quantized_sigmoid MkldnnCPU: mkldnn_sigmoid + Checkpoint: checkpoint_sigmoid - func: sigmoid_(Tensor(a!) self) -> Tensor(a!) supports_named_tensor: True @@ -2481,6 +2500,7 @@ CPU: sigmoid_ CUDA: sigmoid_ MkldnnCPU: mkldnn_sigmoid_ + Checkpoint: checkpoint_sigmoid_ - func: sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True @@ -2552,6 +2572,18 @@ variants: function, method device_guard: False supports_named_tensor: True + dispatch: + CPU: slice + CUDA: slice + Checkpoint: checkpoint_slice + +- func: slice_backward(Tensor grad, int[] input_sizes, int dim, int start, int end, int step) -> Tensor + variants: function, method + device_guard: False + dispatch: + CPU: slice_backward + CUDA: slice_backward + Checkpoint: checkpoint_slice_backward - func: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) variants: function, method @@ -2655,10 +2687,18 @@ - func: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor variants: function, method supports_named_tensor: True + dispatch: + CPU: sum + CUDA: sum + Checkpoint: checkpoint_sum - func: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method supports_named_tensor: True + dispatch: + CPU: sum + CUDA: sum + Checkpoint: checkpoint_sum_dim_IntList - func: sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method @@ -2805,6 +2845,7 @@ dispatch: CPU: threshold CUDA: threshold_cuda + Checkpoint: checkpoint_threshold - func: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!) variants: function @@ -2812,12 +2853,14 @@ dispatch: CPU: threshold_ CUDA: threshold__cuda + Checkpoint: checkpoint_threshold_ - func: threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True dispatch: CPU: threshold_out CUDA: threshold_out_cuda + Checkpoint: checkpoint_threshold_out - func: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor use_c10_dispatcher: full @@ -2825,6 +2868,7 @@ dispatch: CPU: threshold_backward CUDA: threshold_backward_cuda + Checkpoint: checkpoint_threshold_backward - func: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) variants: function, method @@ -3182,6 +3226,7 @@ SparseCPU: zero_sparse_ SparseCUDA: zero_sparse_ MkldnnCPU: mkldnn_zero_ + Checkpoint: checkpoint_zero_ - func: sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -3243,6 +3288,7 @@ CUDA: legacy::cuda::_th_addmm_out SparseCPU: addmm_out_sparse_dense_cpu SparseCUDA: addmm_out_sparse_dense_cuda + Checkpoint: checkpoint_addmm_out supports_named_tensor: True - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor @@ -3253,6 +3299,7 @@ CUDA: legacy::cuda::_th_addmm SparseCPU: addmm_sparse_dense_cpu SparseCUDA: addmm_sparse_dense_cuda + Checkpoint: checkpoint_addmm supports_named_tensor: True - func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) @@ -3264,6 +3311,7 @@ # broadcasting SparseCPU: s_addmm_sparse_dense_cpu_ SparseCUDA: s_addmm_sparse_dense_cuda_ + Checkpoint: checkpoint_addmm_ supports_named_tensor: True @@ -3915,6 +3963,7 @@ CUDA: view MkldnnCPU: mkldnn_view QuantizedCPU: view + Checkpoint: checkpoint_view - func: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) variants: method @@ -4421,6 +4470,7 @@ CPU: ne_out CUDA: ne_out QuantizedCPU: ne_out_quantized_cpu + Checkpoint: checkpoint_ne_Scalar_out - func: ne.Scalar(Tensor self, Scalar other) -> Tensor supports_named_tensor: True @@ -4430,6 +4480,7 @@ CPU: ne CUDA: ne QuantizedCPU: ne_quantized_cpu + Checkpoint: checkpoint_ne_Scalar - func: ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True @@ -4437,6 +4488,7 @@ CPU: ne_out CUDA: ne_out QuantizedCPU: ne_out_quantized_cpu + Checkpoint: checkpoint_ne_Tensor_out - func: ne.Tensor(Tensor self, Tensor other) -> Tensor supports_named_tensor: True @@ -4446,6 +4498,7 @@ CPU: ne CUDA: ne QuantizedCPU: ne_quantized_cpu + Checkpoint: checkpoint_ne_Tensor - func: eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True @@ -4453,6 +4506,7 @@ CPU: eq_out CUDA: eq_out QuantizedCPU: eq_out_quantized_cpu + Checkpoint: checkpoint_eq_Scalar_out - func: eq.Scalar(Tensor self, Scalar other) -> Tensor supports_named_tensor: True @@ -4462,6 +4516,7 @@ CPU: eq CUDA: eq QuantizedCPU: eq_quantized_cpu + Checkpoint: checkpoint_eq_Scalar - func: eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True @@ -4469,6 +4524,7 @@ CPU: eq_out CUDA: eq_out QuantizedCPU: eq_out_quantized_cpu + Checkpoint: checkpoint_eq_Tensor_out - func: eq.Tensor(Tensor self, Tensor other) -> Tensor supports_named_tensor: True @@ -4478,6 +4534,7 @@ CPU: eq CUDA: eq QuantizedCPU: eq_quantized_cpu + Checkpoint: checkpoint_eq_Tensor - func: ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True @@ -5461,24 +5518,28 @@ dispatch: CPU: nll_loss_forward_out_cpu CUDA: legacy::cuda::_thnn_nll_loss_forward_out + Checkpoint: checkpoint_nll_loss_forward_out - func: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> (Tensor output, Tensor total_weight) python_module: nn dispatch: CPU: nll_loss_forward_cpu CUDA: legacy::cuda::_thnn_nll_loss_forward + Checkpoint: checkpoint_nll_loss_forward - func: nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: CPU: nll_loss_backward_out_cpu CUDA: legacy::cuda::_thnn_nll_loss_backward_out + Checkpoint: checkpoint_nll_loss_backward_grad_input - func: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight) -> Tensor python_module: nn dispatch: CPU: nll_loss_backward_cpu CUDA: legacy::cuda::_thnn_nll_loss_backward + Checkpoint: checkpoint_nll_loss_backward - func: nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, int ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -5837,6 +5898,7 @@ CPU: avg_pool2d_out_cpu CUDA: avg_pool2d_out_cuda MkldnnCPU: mkldnn_avg_pool2d_out + Checkpoint: checkpoint_avg_pool2d_out - func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor python_module: nn @@ -5845,18 +5907,21 @@ CUDA: avg_pool2d_cuda MkldnnCPU: mkldnn_avg_pool2d QuantizedCPU: quantized_avg_pool2d + Checkpoint: checkpoint_avg_pool2d - func: avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: CPU: avg_pool2d_backward_out_cpu CUDA: avg_pool2d_backward_out_cuda + Checkpoint: checkpoint_avg_pool2d_backward_grad_input - func: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor python_module: nn dispatch: CPU: avg_pool2d_backward_cpu CUDA: avg_pool2d_backward_cuda + Checkpoint: checkpoint_avg_pool2d_backward - func: avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -5940,6 +6005,7 @@ dispatch: CPU: max_pool2d_with_indices_out_cpu CUDA: max_pool2d_with_indices_out_cuda + Checkpoint: checkpoint_max_pool2d_with_indices_out # Return: (Tensor output, Tensor indices) - func: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) @@ -5947,6 +6013,7 @@ dispatch: CPU: max_pool2d_with_indices_cpu CUDA: max_pool2d_with_indices_cuda + Checkpoint: checkpoint_max_pool2d_with_indices supports_named_tensor: True - func: max_pool2d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) @@ -5954,12 +6021,14 @@ dispatch: CPU: max_pool2d_with_indices_backward_out_cpu CUDA: max_pool2d_with_indices_backward_out_cuda + Checkpoint: checkpoint_max_pool2d_with_indices_backward_grad_input - func: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor python_module: nn dispatch: CPU: max_pool2d_with_indices_backward_cpu CUDA: max_pool2d_with_indices_backward_cuda + Checkpoint: checkpoint_max_pool2d_with_indices_backward # Return: (Tensor output, Tensor indices) - func: max_pool3d_with_indices.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 79c5ec375fa..a024be2e4ab 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -868,6 +868,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * compatible with SparseCUDATensorId. */ inline bool has_compatible_shallow_copy_type(DispatchKeySet from) { + if (key_set_.has(DispatchKey::CheckpointTensorId) || from.has(DispatchKey::CheckpointTensorId)) { + return false; + } auto is_dense = [](DispatchKeySet ts) { return ts.has(DispatchKey::CPUTensorId) || ts.has(DispatchKey::CUDATensorId) || diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp index 70278a08190..03b6fc5b4cc 100644 --- a/tools/autograd/templates/Functions.cpp +++ b/tools/autograd/templates/Functions.cpp @@ -656,18 +656,6 @@ Tensor index_select_backward(Tensor grad, int64_t dim, Tensor indices, IntArrayR return at::zeros(sizes, grad.options()).scatter_(dim, indices, grad); } -Tensor slice_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { - auto grad_input = at::zeros(input_sizes, grad.options()); - grad_input.slice(dim, start, end, step).copy_(grad); - return grad_input; -} - -Tensor select_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { - auto grad_input = at::zeros(input_sizes, grad.options()); - grad_input.select(dim, index).copy_(grad); - return grad_input; -} - Tensor trace_backward(const Tensor & grad, IntArrayRef sizes) { if (sizes.size() != 2) { throw std::runtime_error("expected matrix input"); From 6e6c4e1a2d397c7af0d52c60382c747f9162984f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Tue, 7 Apr 2020 20:53:19 -0700 Subject: [PATCH 05/42] Logging everything for resnet (#18) * save * save * save * save * use release instead of free --- .gitmodules | 3 + aten/src/ATen/CheckpointTensorImpl.cpp | 211 +++++++++++++++++++++---- aten/src/ATen/CheckpointTensorImpl.h | 16 +- aten/src/ATen/native/Checkpoint.cpp | 15 -- third_party/json | 1 + 5 files changed, 197 insertions(+), 49 deletions(-) create mode 160000 third_party/json diff --git a/.gitmodules b/.gitmodules index 3ae80c83792..6be1789ef76 100644 --- a/.gitmodules +++ b/.gitmodules @@ -122,3 +122,6 @@ ignore = dirty path = third_party/XNNPACK url = https://github.com/google/XNNPACK.git +[submodule "third_party/json"] + path = third_party/json + url = git@github.com:nlohmann/json.git diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 085fe73ee59..67e5748b559 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -1,5 +1,6 @@ #include #include +#include <../../../third_party/json/single_include/nlohmann/json.hpp> namespace at { @@ -25,45 +26,187 @@ void DTRLog(const std::string& str) { logger.out << str << std::endl; } +using json = nlohmann::json;using json = nlohmann::json; +bool log_json = true; +const std::string INSTRUCTION = "INSTRUCTION"; +const std::string RELEASE = "RELEASE"; +const std::string TIME = "TIME"; +const std::string ARGS = "ARGS"; +const std::string MEMORY = "MEMORY"; +const std::string NAME = "NAME"; +const std::string CONSTANT = "CONSTANT"; + +void DTRLogConstant(const std::string& name) { + if (log_json) { + json j; + j[INSTRUCTION] = CONSTANT; + j[NAME] = name; + DTRLog(j.dump()); + } else { + DTRLog(CONSTANT + " " + name); + } +} + +void DTRLogMemory(const std::string& name, size_t memory) { + if (log_json) { + json j; + j[INSTRUCTION] = MEMORY; + j[NAME] = name; + j[MEMORY] = std::to_string(memory); + DTRLog(j.dump()); + } else { + DTRLog(name + " " + MEMORY + ": " + std::to_string(memory)); + } +} + +namespace native { + +Tensor checkpoint(const Tensor& t) { + auto cpti = intrusive_ptr::make(t.detach()); + DTRLogConstant(cpti->counter_name()); + DTRLogMemory(cpti->counter_name(), cpti->ref->value->memory()); + return Tensor(cpti); +} + +Tensor decheckpoint(const Tensor& t) { + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + CHECK(cpti != nullptr); + return cpti->ref->value->t; +} + +bool is_checkpoint(const Tensor& t) { + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + return cpti != nullptr; +} + +} + +void DTRLogCopy(const std::string& new_name, const std::string& old_name) { + if (log_json) { + json j; + j[INSTRUCTION] = "COPY"; + j["DST"] = new_name; + j["SRC"] = old_name; + DTRLog(j.dump()); + } else { + DTRLog(new_name + " = " + old_name); + } +} + +intrusive_ptr CheckpointTensorImpl::shallow_copy_and_detach(const VariableVersion& version_counter, + bool allow_tensor_metadata_change) const { + auto ret = intrusive_ptr::make(ref); + DTRLogCopy(ret->counter_name(), counter_name()); + return ret; +} + int CheckpointTensorImpl::counter = 0; -Tensors make_raw(const rematerialize_function_t& remat, - const strongs& input_values) { +Tensor checkpoint_raw(const Tensor& t) { + return Tensor(intrusive_ptr::make(t.detach())); +} + +std::tuple make_raw(const rematerialize_function_t& remat, + const strongs& input_values) { std::vector input; for (const strong& s: input_values) { CHECK(!s->t.key_set().has(DispatchKey::CheckpointTensorId)); input.push_back(s->t); } + time_t pre = std::chrono::system_clock::now(); auto output = remat(input); - Tensors ret; - for (const Tensor& o: output) { - ret.push_back(native::checkpoint(o)); + time_t post = std::chrono::system_clock::now(); + return {output, post - pre}; +} + +std::string from_time(duration_t t) { + return std::to_string(std::chrono::nanoseconds(t).count()); +} + +void DTRLogCall(const std::vector& res, const std::string& name, const std::vector& args, const std::string& time) { + if (log_json) { + json j; + j[INSTRUCTION] = "CALL"; + j[NAME] = name; + j["RESULT"] = res; + j[ARGS] = args; + j[TIME] = time; + DTRLog(j.dump()); + } else { + std::string arg = name + "("; + for (const auto& s : arg) { + arg += s; + arg += ", "; + } + arg += ")"; + std::string log = "("; + for (const auto& s: res) { + log += s; + log += ", "; + } + log += ") = "; + log += arg; + log += " TIME: "; + log += time; + DTRLog(log); } - return ret; } Tensors CheckpointTensorImpl::make(const std::string& name, const rematerialize_function_t& remat, const Tensors& input) { strongs input_values; - std::string arg = name + "("; + std::vector args; for (const Tensor& t: input) { auto ft = from_tensor(t); input_values.push_back(std::get<0>(ft)); - arg += std::get<1>(ft); - arg += ", "; + args.push_back(std::get<1>(ft)); } - arg += ")"; - std::string log = "("; - Tensors ret = make_raw(remat, input_values); - for (const Tensor& t: ret) { - log += get_cpti(t)->counter_name(); - log += ", "; + std::vector res; + auto ret = make_raw(remat, input_values); + Tensors tensors; + for (const Tensor& t: std::get<0>(ret)) { + auto cp = checkpoint_raw(t); + tensors.push_back(cp); + res.push_back(get_cpti(cp)->counter_name()); + } + DTRLogCall(res, name, args, from_time(std::get<1>(ret))); + for (const Tensor& t: tensors) { + auto cpti = get_cpti(t); + DTRLogMemory(cpti->counter_name(), cpti->ref->value->memory()); + } + return tensors; +} + +void DTRLogMutate(const std::string& name, const std::vector& args, const std::vector& mutate, const std::string& time) { + if (log_json) { + json j; + j[INSTRUCTION] = "MUTATE"; + j[NAME] = name; + j[ARGS] = args; + j["MUTATE"] = mutate; + j[TIME] = time; + DTRLog(j.dump()); + } else { + std::string log = name; + log += "("; + for (const auto& s : args) { + log += s; + log += ", "; + } + log += ") "; + log += " MUTATING: "; + log += "("; + for (const size_t i : mutate) { + log += std::to_string(i); + log += ", "; + } + log += ") "; + log += TIME; + log += ": "; + log += time; + DTRLog(log); } - log += ") = "; - log += arg; - DTRLog(log); - return ret; } void CheckpointTensorImpl::mutate(const std::string& name, @@ -79,20 +222,34 @@ void CheckpointTensorImpl::mutate(const std::string& name, return new_input_values; }; strongs input_values; - std::string log = name; - log += "("; + std::vector args; for (const Tensor& t : inputs) { auto ft = from_tensor(t); - log += std::get<1>(ft); - log += ", "; + args.push_back(std::get<1>(ft)); input_values.push_back(std::get<0>(ft)); } - log += ")"; - DTRLog(log); - auto modified = make_raw(remat, input_values); + auto ret = make_raw(remat, input_values); + const auto& modified = std::get<0>(ret); for (size_t idx: mutate_idx) { - cell_from_tensor(inputs[idx])->value = cell_from_tensor(modified[idx])->value; + cell_from_tensor(inputs[idx])->value = intrusive_ptr::make(modified[idx]); } + DTRLogMutate(name, args, mutate_idx, from_time(std::get<1>(ret))); +} + +void DTRLogRelease(const std::string& counter_name) { + if (log_json) { + json j; + j[INSTRUCTION] = RELEASE; + j[NAME] = counter_name; + DTRLog(j.dump()); + } else { + DTRLog(RELEASE + ": " + counter_name); + } +} + +void CheckpointTensorImpl::release_resources() { + DTRLogRelease(counter_name()); + ref.reset(); } } diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index de83e60ab14..e4c32a973df 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -28,6 +28,9 @@ void DTRLog(const std::string& str); struct CAFFE2_API CheckpointTensorCell : intrusive_ptr_target { Tensor t; explicit CheckpointTensorCell(const Tensor& t) : t(t.detach()) { } + size_t memory() { + return t.defined() ? t.numel() * t.itemsize() : 0; + } }; struct CAFFE2_API CheckpointTensorImplCell : intrusive_ptr_target { @@ -48,6 +51,9 @@ using Tensors = std::vector; using rematerialize_function_t = std::function; using mutate_function_t = std::function; +using time_t = std::chrono::time_point; +using duration_t = std::chrono::system_clock::duration; + inline DispatchKeySet convert_key_set(const DispatchKeySet& t) { CHECK(!t.has(DispatchKey::CheckpointTensorId)); auto ret = t.add(DispatchKey::CheckpointTensorId); @@ -61,13 +67,11 @@ struct CAFFE2_API CheckpointTensorImpl : TensorImpl { static int gen_counter() { return counter++; } - std::string counter_name() { + std::string counter_name() const { return std::string("x") + std::to_string(id); } intrusive_ptr ref; - void release_resources() final { - ref.reset(); - } + void release_resources() final; explicit CheckpointTensorImpl(const intrusive_ptr& ref) : TensorImpl(convert_key_set(ref->value->t.key_set()), ref->value->t.dtype(), ref->value->t.optional_device()), ref(ref) { } @@ -81,9 +85,7 @@ struct CAFFE2_API CheckpointTensorImpl : TensorImpl { const Tensors& inputs, const std::vector& mutate_idx); intrusive_ptr shallow_copy_and_detach(const VariableVersion& version_counter, - bool allow_tensor_metadata_change) const override { - return intrusive_ptr::make(ref); - } + bool allow_tensor_metadata_change) const override; int64_t dim() const override { return ref->value->t.dim(); } diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index f09a7f7e4e7..92ae62315ce 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -3,21 +3,6 @@ namespace at { namespace native { -Tensor checkpoint(const Tensor& t) { - return Tensor(intrusive_ptr::make(t.detach())); -} - -Tensor decheckpoint(const Tensor& t) { - auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); - CHECK(cpti != nullptr); - return cpti->ref->value->t; -} - -bool is_checkpoint(const Tensor& t) { - auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); - return cpti != nullptr; -} - Tensor checkpoint_add(const Tensor& a, const Tensor& b, c10::Scalar c) { rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { diff --git a/third_party/json b/third_party/json new file mode 160000 index 00000000000..19843b038ca --- /dev/null +++ b/third_party/json @@ -0,0 +1 @@ +Subproject commit 19843b038caa463164d6f89ea1b2765fae7552e9 From dab0aa8f6df3f073bee7abc3408e72b7462a7f80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Wed, 15 Apr 2020 12:16:48 -0700 Subject: [PATCH 06/42] Fix logger (special handling for small constants) (#19) * save * fix comment --- aten/src/ATen/CheckpointTensorImpl.cpp | 78 ++++++++++++++++++++------ aten/src/ATen/CheckpointTensorImpl.h | 9 --- 2 files changed, 61 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 67e5748b559..03891985460 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -35,6 +35,7 @@ const std::string ARGS = "ARGS"; const std::string MEMORY = "MEMORY"; const std::string NAME = "NAME"; const std::string CONSTANT = "CONSTANT"; +const std::string CONSTANTS = "CONSTANTS"; void DTRLogConstant(const std::string& name) { if (log_json) { @@ -106,12 +107,27 @@ Tensor checkpoint_raw(const Tensor& t) { return Tensor(intrusive_ptr::make(t.detach())); } +// remat take a single vector of tensors, +// while there are two vector, one storing nonconstants and one storing constants. +// the constants are small and they will not be considered for eviction. +// however, we have to stitch the two vectors together to pass it in remat. +// the size_t in constants decide the location to stitch them in, while input_values fill in the rest. std::tuple make_raw(const rematerialize_function_t& remat, - const strongs& input_values) { + const strongs& input_values, + const std::vector>& constants) { std::vector input; - for (const strong& s: input_values) { - CHECK(!s->t.key_set().has(DispatchKey::CheckpointTensorId)); - input.push_back(s->t); + size_t i = 0, j = 0; + while (i != input_values.size() || j != constants.size()) { + if (j < constants.size() && std::get<1>(constants[j]) == input.size()) { + input.push_back(std::get<0>(constants[j])); + ++j; + } + else { + CHECK(i < input_values.size()); + CHECK(!input_values[i]->t.key_set().has(DispatchKey::CheckpointTensorId)); + input.push_back(input_values[i]->t); + ++i; + } } time_t pre = std::chrono::system_clock::now(); auto output = remat(input); @@ -123,16 +139,22 @@ std::string from_time(duration_t t) { return std::to_string(std::chrono::nanoseconds(t).count()); } -void DTRLogCall(const std::vector& res, const std::string& name, const std::vector& args, const std::string& time) { +void DTRLogCall(const std::vector& res, + const std::string& name, + const std::vector& args, + const std::vector& constants, + const std::string& time) { if (log_json) { json j; j[INSTRUCTION] = "CALL"; j[NAME] = name; j["RESULT"] = res; j[ARGS] = args; + j[CONSTANTS] = constants; j[TIME] = time; DTRLog(j.dump()); } else { + CHECK(constants.size() == 0); //TODO: implement. std::string arg = name + "("; for (const auto& s : arg) { arg += s; @@ -156,21 +178,29 @@ Tensors CheckpointTensorImpl::make(const std::string& name, const rematerialize_function_t& remat, const Tensors& input) { strongs input_values; + std::vector> constants; + std::vector constant_idx; std::vector args; for (const Tensor& t: input) { - auto ft = from_tensor(t); - input_values.push_back(std::get<0>(ft)); - args.push_back(std::get<1>(ft)); + if (auto* cpt = dynamic_cast(t.unsafeGetTensorImpl())) { + input_values.push_back(cpt->ref->value); + args.push_back(cpt->counter_name()); + } + else { + size_t idx = input_values.size() + constants.size(); + constants.push_back({t, idx}); + constant_idx.push_back(idx); + } } std::vector res; - auto ret = make_raw(remat, input_values); + auto ret = make_raw(remat, input_values, constants); Tensors tensors; for (const Tensor& t: std::get<0>(ret)) { auto cp = checkpoint_raw(t); tensors.push_back(cp); res.push_back(get_cpti(cp)->counter_name()); } - DTRLogCall(res, name, args, from_time(std::get<1>(ret))); + DTRLogCall(res, name, args, constant_idx, from_time(std::get<1>(ret))); for (const Tensor& t: tensors) { auto cpti = get_cpti(t); DTRLogMemory(cpti->counter_name(), cpti->ref->value->memory()); @@ -178,16 +208,22 @@ Tensors CheckpointTensorImpl::make(const std::string& name, return tensors; } -void DTRLogMutate(const std::string& name, const std::vector& args, const std::vector& mutate, const std::string& time) { +void DTRLogMutate(const std::string& name, + const std::vector& args, + const std::vector& constants, + const std::vector& mutate, + const std::string& time) { if (log_json) { json j; j[INSTRUCTION] = "MUTATE"; j[NAME] = name; j[ARGS] = args; + j[CONSTANTS] = constants; j["MUTATE"] = mutate; j[TIME] = time; DTRLog(j.dump()); } else { + CHECK(constants.size() == 0); //TODO: implement. std::string log = name; log += "("; for (const auto& s : args) { @@ -222,18 +258,26 @@ void CheckpointTensorImpl::mutate(const std::string& name, return new_input_values; }; strongs input_values; + std::vector> constants; + std::vector constant_idx; std::vector args; - for (const Tensor& t : inputs) { - auto ft = from_tensor(t); - args.push_back(std::get<1>(ft)); - input_values.push_back(std::get<0>(ft)); + for (const Tensor& t: inputs) { + if (auto* cpt = dynamic_cast(t.unsafeGetTensorImpl())) { + input_values.push_back(cpt->ref->value); + args.push_back(cpt->counter_name()); + } + else { + size_t idx = input_values.size() + constants.size(); + constants.push_back({t, idx}); + constant_idx.push_back(idx); + } } - auto ret = make_raw(remat, input_values); + auto ret = make_raw(remat, input_values, constants); const auto& modified = std::get<0>(ret); for (size_t idx: mutate_idx) { cell_from_tensor(inputs[idx])->value = intrusive_ptr::make(modified[idx]); } - DTRLogMutate(name, args, mutate_idx, from_time(std::get<1>(ret))); + DTRLogMutate(name, args, constant_idx, mutate_idx, from_time(std::get<1>(ret))); } void DTRLogRelease(const std::string& counter_name) { diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index e4c32a973df..259bcfe2f23 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -112,15 +112,6 @@ inline CheckpointTensorImpl* get_cpti(const Tensor& t) { return cpti; } -inline std::tuple from_tensor(const Tensor& t) { - auto* cpt = dynamic_cast(t.unsafeGetTensorImpl()); - if(cpt != nullptr) { - return {cpt->ref->value, cpt->counter_name()}; - } else { - return from_tensor(native::checkpoint(t)); - } -} - inline Tensor get(const strong& s) { return s->t; } From 0680ddb81f6f5cdb01cb74b8badde6a14ab22e0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Fri, 17 Apr 2020 16:52:04 -0700 Subject: [PATCH 07/42] TreeLSTM overloads (#21) --- aten/src/ATen/native/Checkpoint.cpp | 178 +++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 47 ++++++ 2 files changed, 225 insertions(+) diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index 92ae62315ce..4a00913fe1e 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -20,6 +20,23 @@ Tensor& checkpoint_add_(Tensor& a, const Tensor& b, Scalar c) { return a; } +Tensor checkpoint_mul(at::Tensor const& a, at::Tensor const& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::mul(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("mul", rt, {a, b})[0]; +} + +Tensor& checkpoint_mul_(at::Tensor& a, at::Tensor const& b) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).mul_(vec.at(1)); + }; + CheckpointTensorImpl::mutate("mul_", mt, {a, b}, {0}); + return a; +} + Tensor checkpoint_abs(const Tensor& a) { rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { @@ -615,4 +632,165 @@ Tensor& checkpoint_zero_(Tensor& a) { return a; } +Tensor& checkpoint_squeeze_(at::Tensor& a, at::Dimname b) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).squeeze_(b); + }; + CheckpointTensorImpl::mutate("squeeze_", mt, {a}, {0}); + return a; +} + +Tensor& checkpoint_squeeze_(at::Tensor& a) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).squeeze_(); + }; + CheckpointTensorImpl::mutate("squeeze_", mt, {a}, {0}); + return a; +} + +Tensor& checkpoint_squeeze_(at::Tensor& a, long b) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).squeeze_(b); + }; + CheckpointTensorImpl::mutate("squeeze_", mt, {a}, {0}); + return a; +} + +Tensor checkpoint_sigmoid_backward(at::Tensor const& a, at::Tensor const& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sigmoid_backward(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("sigmoid_backward", rt, {a, b})[0]; +} + +Tensor& checkpoint_sigmoid_backward_out(at::Tensor& a, at::Tensor const& b, at::Tensor const& c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::sigmoid_backward_out(a_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("sigmoid_backward_out", mt, {a, b, c}, {0}); + return a; +} + +Tensor& checkpoint_sign_out(at::Tensor& a, at::Tensor const& b) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::sign_out(a_, vec.at(1)); + }; + CheckpointTensorImpl::mutate("sign_out", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_sign(const Tensor& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sign(vec.at(0))}; + }; + return CheckpointTensorImpl::make("sign", rt, {a})[0]; +} + +Tensor checkpoint_tanh(const Tensor& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::tanh(vec.at(0))}; + }; + return CheckpointTensorImpl::make("tanh", rt, {a})[0]; +} + +Tensor checkpoint_tanh_backward(at::Tensor const& a, at::Tensor const& b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::tanh_backward(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("tanh_backward", rt, {a, b})[0]; +} + +Tensor& checkpoint_tanh_backward_out(at::Tensor& a, at::Tensor const& b, at::Tensor const& c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::tanh_backward_out(a_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("tanh_backward_out", mt, {a, b, c}, {0}); + return a; +} + +Tensor checkpoint_neg(at::Tensor const& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::neg(vec.at(0))}; + }; + return CheckpointTensorImpl::make("neg", rt, {a})[0]; +} + +Tensor checkpoint_sub(at::Tensor const& a, at::Tensor const& b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sub(vec.at(0), vec.at(1), c)}; + }; + return CheckpointTensorImpl::make("sub", rt, {a, b})[0]; +} + +Tensor checkpoint_repeat(const at::Tensor& a, c10::ArrayRef b) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {vec.at(0).repeat(b_)}; + }; + return CheckpointTensorImpl::make("repeat", rt, {a})[0]; +} + +Tensor checkpoint__cat(c10::ArrayRef a, long b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::cat(vec, b)}; + }; + std::vector s; + for (const Tensor& t : a) { + s.push_back(t); + } + return CheckpointTensorImpl::make("_cat", rt, s)[0]; +} + +Tensor& checkpoint__cat_out(Tensor& a, c10::ArrayRef b, long c) { + std::vector args; + args.push_back(a); + for (const Tensor& t : b) { + args.push_back(t); + } + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor t = vec[0]; + at::cat_out(t, ArrayRef(vec.data() + 1, vec.size() - 1), c); + }; + CheckpointTensorImpl::mutate("_cat_out", mt, args, {0}); + return a; +} + +Tensor checkpoint_kl_div(at::Tensor const& a, at::Tensor const& b, long c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::kl_div(vec.at(0), vec.at(1), c)}; + }; + return CheckpointTensorImpl::make("kl_div", rt, {a, b})[0]; +} + +Tensor checkpoint_kl_div_backward(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, long d) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::kl_div_backward(vec.at(0), vec.at(1), vec.at(2), d)}; + }; + return CheckpointTensorImpl::make("kl_div_backward", rt, {a, b, c})[0]; +} + +Scalar checkpoint__local_scalar_dense(at::Tensor const& a) { + return at::_local_scalar_dense(decheckpoint(a)); +} + }} diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e9b6b23f4a6..04ff2643319 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1579,12 +1579,17 @@ - func: kl_div(Tensor self, Tensor target, int reduction=Mean) -> Tensor use_c10_dispatcher: full + dispatch: + CPU: kl_div + CUDA: kl_div + Checkpoint: checkpoint_kl_div - func: kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int reduction=Mean) -> Tensor use_c10_dispatcher: full dispatch: CPU: kl_div_backward_cpu CUDA: kl_div_backward_cuda + Checkpoint: checkpoint_kl_div_backward - func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) supports_named_tensor: True @@ -2025,6 +2030,7 @@ SparseCPU: mul_sparse SparseCUDA: mul_sparse MkldnnCPU: mkldnn_mul + Checkpoint: checkpoint_mul supports_named_tensor: True - func: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) @@ -2035,6 +2041,7 @@ SparseCPU: mul_sparse_ SparseCUDA: mul_sparse_ MkldnnCPU: mkldnn_mul_ + Checkpoint: checkpoint_mul_ supports_named_tensor: True - func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) @@ -2314,6 +2321,10 @@ use_c10_dispatcher: full supports_named_tensor: True variants: function, method + dispatch: + CPU: neg + CUDA: neg + Checkpoint: checkpoint_neg - func: neg_(Tensor(a!) self) -> Tensor(a!) supports_named_tensor: True @@ -2327,6 +2338,10 @@ - func: repeat(Tensor self, int[] repeats) -> Tensor variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. + dispatch: + CPU: repeat + CUDA: repeat + Checkpoint: checkpoint_repeat - func: repeat_interleave.Tensor(Tensor repeats) -> Tensor use_c10_dispatcher: full @@ -2642,14 +2657,26 @@ - func: squeeze_(Tensor(a!) self) -> Tensor(a!) variants: method device_guard: False + dispatch: + CPU: squeeze_ + CUDA: squeeze_ + Checkpoint: checkpoint_squeeze_ - func: squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!) variants: method device_guard: False + dispatch: + CPU: squeeze_ + CUDA: squeeze_ + Checkpoint: checkpoint_squeeze_ - func: squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!) variants: method device_guard: False + dispatch: + CPU: squeeze_ + CUDA: squeeze_ + Checkpoint: checkpoint_squeeze_ - func: sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor use_c10_dispatcher: full @@ -2820,6 +2847,7 @@ CPU: tanh CUDA: tanh QuantizedCPU: quantized_tanh + Checkpoint: checkpoint_tanh - func: tanh_(Tensor(a!) self) -> Tensor(a!) supports_named_tensor: True @@ -3244,6 +3272,7 @@ CUDA: sub SparseCPU: sub_sparse SparseCUDA: sub_sparse + Checkpoint: checkpoint_sub supports_named_tensor: True - func: sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) @@ -3799,6 +3828,7 @@ dispatch: CPU: _local_scalar_dense_cpu CUDA: _local_scalar_dense_cuda + Checkpoint: checkpoint__local_scalar_dense variants: function supports_named_tensor: True @@ -5007,6 +5037,10 @@ - func: sign(Tensor self) -> Tensor variants: function, method supports_named_tensor: True + dispatch: + CPU: sign + CUDA: sign + Checkpoint: checkpoint_sign - func: sign_(Tensor(a!) self) -> Tensor(a!) variants: method @@ -5017,6 +5051,7 @@ dispatch: CPU: sign_out CUDA: sign_out + Checkpoint: checkpoint_sign_out - func: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor use_c10_dispatcher: full @@ -5375,12 +5410,14 @@ CPU: _cat_cpu CUDA: cat_cuda QuantizedCPU: quantized_cat + Checkpoint: checkpoint__cat - func: _cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: _cat_out_cpu CUDA: cat_out_cuda QuantizedCPU: quantized_cat_out + Checkpoint: checkpoint__cat_out - func: _mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor, Tensor) dispatch: @@ -6400,20 +6437,30 @@ dispatch: CPU: sigmoid_backward_out CUDA: sigmoid_backward_out + Checkpoint: checkpoint_sigmoid_backward_out - func: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor use_c10_dispatcher: full python_module: nn + dispatch: + CPU: sigmoid_backward + CUDA: sigmoid_backward + Checkpoint: checkpoint_sigmoid_backward - func: tanh_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: CPU: tanh_backward_out CUDA: tanh_backward_out + Checkpoint: checkpoint_tanh_backward_out - func: tanh_backward(Tensor grad_output, Tensor output) -> Tensor use_c10_dispatcher: full python_module: nn + dispatch: + CPU: tanh_backward + CUDA: tanh_backward + Checkpoint: checkpoint_tanh_backward # What's a thnn_conv_ versus a slow_conv_? # From 3008ba5c223dcaf9bb997a976a6721baacaf09e5 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 17 Apr 2020 20:09:38 -0700 Subject: [PATCH 08/42] Overloads for mean and mean.dim (needed for densenet) --- aten/src/ATen/native/Checkpoint.cpp | 17 +++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 2 ++ 2 files changed, 19 insertions(+) diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index 4a00913fe1e..cfa7ecfaf68 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -746,6 +746,23 @@ Tensor checkpoint_repeat(const at::Tensor& a, c10::ArrayRef b) { return CheckpointTensorImpl::make("repeat", rt, {a})[0]; } +Tensor checkpoint_mean(const Tensor& self, c10::optional dtype) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::native::mean_cpu_gpu(vec[0], dtype)}; + }; + return CheckpointTensorImpl::make("mean", rt, {self})[0]; +} + +Tensor checkpoint_mean(const Tensor& self, IntArrayRef dim, bool keepdim, c10::optional dtype) { + std::vector dim_ = dim.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::native::mean_cpu_gpu(vec[0], dim_, keepdim, dtype)}; + }; + return CheckpointTensorImpl::make("mean.dim", rt, {self})[0]; +} + Tensor checkpoint__cat(c10::ArrayRef a, long b) { rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 04ff2643319..4dcea3d4e6e 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1848,6 +1848,7 @@ CPU: mean_cpu_gpu CUDA: mean_cpu_gpu QuantizedCPU: quantized_mean_cpu + Checkpoint: checkpoint_mean - func: mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method @@ -1856,6 +1857,7 @@ CPU: mean_cpu_gpu CUDA: mean_cpu_gpu QuantizedCPU: quantized_mean_cpu + Checkpoint: checkpoint_mean - func: mean.out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True From b2cc5f42c41966eb1c8e211c7b6ed18e1f3f3a11 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 30 Apr 2020 13:58:27 -0700 Subject: [PATCH 09/42] Add overloads for U-Net --- aten/src/ATen/native/Checkpoint.cpp | 89 ++++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 15 ++++ 2 files changed, 104 insertions(+) diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index cfa7ecfaf68..5633244b6ac 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -737,6 +737,17 @@ Tensor checkpoint_sub(at::Tensor const& a, at::Tensor const& b, c10::Scalar c) { return CheckpointTensorImpl::make("sub", rt, {a, b})[0]; } +Tensor& checkpoint_sub_(at::Tensor& a, at::Tensor const& b, c10::Scalar c) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + self.sub_(vec.at(1), c); + }; + CheckpointTensorImpl::mutate("sub_", mt, {a, b}, {0}); + return a; +} + + Tensor checkpoint_repeat(const at::Tensor& a, c10::ArrayRef b) { std::vector b_ = b.vec(); rematerialize_function_t rt = @@ -806,6 +817,84 @@ Tensor checkpoint_kl_div_backward(at::Tensor const& a, at::Tensor const& b, at:: return CheckpointTensorImpl::make("kl_div_backward", rt, {a, b, c})[0]; } +Tensor checkpoint_upsample_bilinear2d(at::Tensor const& self, c10::ArrayRef output_size, bool align_corners, c10::optional scales_h, c10::optional scales_w) { + std::vector output_size_ = output_size.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::upsample_bilinear2d(vec.at(0), output_size_, align_corners, scales_h, scales_w)}; + }; + return CheckpointTensorImpl::make("upsample_bilinear2d", rt, {self})[0]; +} + +Tensor& checkpoint_upsample_bilinear2d_out(at::Tensor& out, const at::Tensor& self, c10::ArrayRef output_size, bool align_corners, c10::optional scales_h, c10::optional scales_w) { + std::vector output_size_ = output_size.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out = vec.at(0); + at::upsample_bilinear2d_out(out, vec.at(1), output_size_, align_corners, scales_h, scales_w); + }; + CheckpointTensorImpl::mutate("binary_cross_entropy_out", mt, {out, self}, {0}); + return out; +} + +Tensor& checkpoint_upsample_bilinear2d_backward_out(at::Tensor& grad_input, const at::Tensor& grad_output, c10::ArrayRef output_size, c10::ArrayRef input_size, bool align_corners, c10::optional scales_h, c10::optional scales_w) { + std::vector output_size_ = output_size.vec(); + std::vector input_size_ = input_size.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor grad_input = vec.at(0); + at::upsample_bilinear2d_backward_out(grad_input, vec.at(1), output_size_, input_size_, align_corners, scales_h, scales_w); + }; + CheckpointTensorImpl::mutate("upsample_bilinear2d_backward_out", mt, {grad_input, grad_output}, {0}); + return grad_input; +} + +Tensor checkpoint_upsample_bilinear2d_backward(at::Tensor const& grad_output, c10::ArrayRef output_size, c10::ArrayRef input_size, bool align_corners, c10::optional scales_h, c10::optional scales_w) { + std::vector output_size_ = output_size.vec(); + std::vector input_size_ = input_size.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::upsample_bilinear2d_backward(vec.at(0), output_size_, input_size_, align_corners, scales_h, scales_w)}; + }; + return CheckpointTensorImpl::make("upsample_bilinear2d_backward", rt, {grad_output})[0]; +} + +Tensor& checkpoint_clamp_min_(Tensor& a, Scalar min) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::clamp_min_(self, min); + }; + CheckpointTensorImpl::mutate("clamp_min_", mt, {a}, {0}); + return a; +} + +Tensor& checkpoint_clamp_min__out(Tensor& out, const Tensor& self, Scalar min) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out = vec.at(0); + at::clamp_min_out(out, vec.at(1), min); + }; + CheckpointTensorImpl::mutate("clamp_min__out", mt, {out, self}, {0}); + return out; +} + +Tensor checkpoint_binary_cross_entropy_with_logits(const Tensor& input, const Tensor& target, const Tensor& weight, const Tensor& pos_weight, int64_t reduction) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::binary_cross_entropy_with_logits(vec.at(0), vec.at(1), vec.at(2), vec.at(3), reduction)}; + }; + return CheckpointTensorImpl::make("binary_cross_entropy_with_logits", rt, {input, target, weight, pos_weight})[0]; +} + +Tensor checkpoint_binary_cross_entropy_with_logits_backward(const Tensor& grad, const Tensor& input, const Tensor& target, const Tensor& weight, const Tensor& pos_weight, int64_t reduction) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::binary_cross_entropy_with_logits_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), reduction)}; + }; + return CheckpointTensorImpl::make("binary_cross_entropy_with_logits_backward", rt, {grad, input, target, weight, pos_weight})[0]; +} + Scalar checkpoint__local_scalar_dense(at::Tensor const& a) { return at::_local_scalar_dense(decheckpoint(a)); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 4dcea3d4e6e..c524be5c1ad 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -572,9 +572,17 @@ - func: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor variants: function + dispatch: + CPU: binary_cross_entropy_with_logits + CUDA: binary_cross_entropy_with_logits + Checkpoint: checkpoint_binary_cross_entropy_with_logits - func: binary_cross_entropy_with_logits_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor variants: function + dispatch: + CPU: binary_cross_entropy_with_logits_backward + CUDA: binary_cross_entropy_with_logits_backward + Checkpoint: checkpoint_binary_cross_entropy_with_logits_backward - func: bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor variants: function, method @@ -761,12 +769,14 @@ dispatch: CPU: _clamp_min__cpu CUDA: _clamp_min__cuda + Checkpoint: checkpoint_clamp_min_ - func: clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True dispatch: CPU: _clamp_min_out_cpu CUDA: _clamp_min_out_cuda + Checkpoint: checkpoint_clamp_min__out - func: cudnn_is_acceptable(Tensor self) -> bool use_c10_dispatcher: full @@ -3282,6 +3292,7 @@ dispatch: CPU: sub_ CUDA: sub_ + Checkpoint: checkpoint_sub_ SparseCPU: sub_sparse_ SparseCUDA: sub_sparse_ supports_named_tensor: True @@ -6293,6 +6304,7 @@ dispatch: CPU: upsample_bilinear2d_out_cpu CUDA: upsample_bilinear2d_out_cuda + Checkpoint: checkpoint_upsample_bilinear2d_out - func: upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn @@ -6300,18 +6312,21 @@ CPU: upsample_bilinear2d_cpu CUDA: upsample_bilinear2d_cuda QuantizedCPU: quantized_upsample_bilinear2d_cpu + Checkpoint: checkpoint_upsample_bilinear2d - func: upsample_bilinear2d_backward.grad_input(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: CPU: upsample_bilinear2d_backward_out_cpu CUDA: upsample_bilinear2d_backward_out_cuda + Checkpoint: checkpoint_upsample_bilinear2d_backward_out - func: upsample_bilinear2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn dispatch: CPU: upsample_bilinear2d_backward_cpu CUDA: upsample_bilinear2d_backward_cuda + Checkpoint: checkpoint_upsample_bilinear2d_backward - func: upsample_bicubic2d.out(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn From 5c2311564359bc39fd0e397080308c7bd0e0b2fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Mon, 4 May 2020 20:51:58 -0700 Subject: [PATCH 10/42] Add ability to annotate log and make new log file (#25) --- aten/src/ATen/CheckpointTensorImpl.cpp | 63 ++++++++++++++-------- aten/src/ATen/CheckpointTensorImpl.h | 2 - aten/src/ATen/native/native_functions.yaml | 6 +++ test.py | 11 ++-- 4 files changed, 51 insertions(+), 31 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 03891985460..98f46c79e1c 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -5,30 +5,34 @@ namespace at { struct DTRLogger { + std::string time_prefix; std::ofstream out; - static std::string get_filename() { + static std::string get_time_prefix() { std::time_t t = std::time(nullptr); std::tm* tm = std::localtime(&t); - std::string str = + return std::to_string(1900+tm->tm_year) + "-" + std::to_string(1+tm->tm_mon) + "-" + std::to_string(tm->tm_mday) + "-" + std::to_string(tm->tm_hour) + "-" + std::to_string(tm->tm_min) + "-" + - std::to_string(tm->tm_sec) + ".log"; - return str; + std::to_string(tm->tm_sec); + } + std::string get_filename(const std::string& name) { + return time_prefix + "-" + name + ".log"; + } + DTRLogger() : time_prefix(get_time_prefix()), out(get_filename("default")) { } + void log(const std::string& str) { + out << str << std::endl; } - DTRLogger() : out(get_filename()) { } }; -void DTRLog(const std::string& str) { - static DTRLogger logger; - logger.out << str << std::endl; -} +static DTRLogger logger; -using json = nlohmann::json;using json = nlohmann::json; +using json = nlohmann::json; bool log_json = true; const std::string INSTRUCTION = "INSTRUCTION"; +const std::string ANNOTATION = "ANNOTATION"; const std::string RELEASE = "RELEASE"; const std::string TIME = "TIME"; const std::string ARGS = "ARGS"; @@ -42,9 +46,9 @@ void DTRLogConstant(const std::string& name) { json j; j[INSTRUCTION] = CONSTANT; j[NAME] = name; - DTRLog(j.dump()); + logger.log(j.dump()); } else { - DTRLog(CONSTANT + " " + name); + logger.log(CONSTANT + " " + name); } } @@ -54,9 +58,9 @@ void DTRLogMemory(const std::string& name, size_t memory) { j[INSTRUCTION] = MEMORY; j[NAME] = name; j[MEMORY] = std::to_string(memory); - DTRLog(j.dump()); + logger.log(j.dump()); } else { - DTRLog(name + " " + MEMORY + ": " + std::to_string(memory)); + logger.log(name + " " + MEMORY + ": " + std::to_string(memory)); } } @@ -80,6 +84,21 @@ bool is_checkpoint(const Tensor& t) { return cpti != nullptr; } +void new_log(std::string str) { + logger.out = std::ofstream(logger.get_filename(str)); +} + +void annotate_log(std::string str) { + if (log_json) { + json j; + j[INSTRUCTION] = "ANNOTATE"; + j[ANNOTATION] = str; + logger.log(j.dump()); + } else { + logger.log(str); + } +} + } void DTRLogCopy(const std::string& new_name, const std::string& old_name) { @@ -88,9 +107,9 @@ void DTRLogCopy(const std::string& new_name, const std::string& old_name) { j[INSTRUCTION] = "COPY"; j["DST"] = new_name; j["SRC"] = old_name; - DTRLog(j.dump()); + logger.log(j.dump()); } else { - DTRLog(new_name + " = " + old_name); + logger.log(new_name + " = " + old_name); } } @@ -152,7 +171,7 @@ void DTRLogCall(const std::vector& res, j[ARGS] = args; j[CONSTANTS] = constants; j[TIME] = time; - DTRLog(j.dump()); + logger.log(j.dump()); } else { CHECK(constants.size() == 0); //TODO: implement. std::string arg = name + "("; @@ -170,7 +189,7 @@ void DTRLogCall(const std::vector& res, log += arg; log += " TIME: "; log += time; - DTRLog(log); + logger.log(log); } } @@ -221,7 +240,7 @@ void DTRLogMutate(const std::string& name, j[CONSTANTS] = constants; j["MUTATE"] = mutate; j[TIME] = time; - DTRLog(j.dump()); + logger.log(j.dump()); } else { CHECK(constants.size() == 0); //TODO: implement. std::string log = name; @@ -241,7 +260,7 @@ void DTRLogMutate(const std::string& name, log += TIME; log += ": "; log += time; - DTRLog(log); + logger.log(log); } } @@ -285,9 +304,9 @@ void DTRLogRelease(const std::string& counter_name) { json j; j[INSTRUCTION] = RELEASE; j[NAME] = counter_name; - DTRLog(j.dump()); + logger.log(j.dump()); } else { - DTRLog(RELEASE + ": " + counter_name); + logger.log(RELEASE + ": " + counter_name); } } diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index 259bcfe2f23..9fc960f7c3c 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -23,8 +23,6 @@ namespace at { -void DTRLog(const std::string& str); - struct CAFFE2_API CheckpointTensorCell : intrusive_ptr_target { Tensor t; explicit CheckpointTensorCell(const Tensor& t) : t(t.detach()) { } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index c524be5c1ad..02fdb8fa768 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -9,6 +9,12 @@ - func: decheckpoint(Tensor self) -> Tensor variants: method +- func: new_log(str logname) -> () + variants: function + +- func: annotate_log(str logname) -> () + variants: function + # Temporary type cast operators. These are needed to trace type-casts now since # Type's are not supported in the IR. Instead, we call down to these # specialized operators for each datatype. diff --git a/test.py b/test.py index af2051c39ad..bebb7397265 100644 --- a/test.py +++ b/test.py @@ -1,8 +1,5 @@ import torch -x = torch.Tensor([1]).checkpoint() -y = torch.Tensor([2]).checkpoint() -z = x + y -print(z) -print(z.decheckpoint()) -print(z.is_checkpoint()) -print(z.decheckpoint().is_checkpoint()) +torch.annotate_log("hello") +torch.annotate_log("hello again") +torch.new_log("new") +torch.annotate_log("again") From 690883fd4cc75e965fa6630895cb9a35be7eb7d7 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 6 May 2020 16:26:11 -0700 Subject: [PATCH 11/42] Restore overloads needed for LSTM and GRU --- aten/src/ATen/native/Checkpoint.cpp | 64 ++++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 5 ++ 2 files changed, 69 insertions(+) diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index 5633244b6ac..7e0a850798c 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -895,6 +895,70 @@ Tensor checkpoint_binary_cross_entropy_with_logits_backward(const Tensor& grad, return CheckpointTensorImpl::make("binary_cross_entropy_with_logits_backward", rt, {grad, input, target, weight, pos_weight})[0]; } +std::tuple checkpoint__fused_dropout(const Tensor & self, double p, Generator* g) { + // TODO: Figure out how to properly duplicate the generator; + // note that the commented-out code below results in a segfault! + // Ref> gen; + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + // Generator* cur = gen.t ? gen.t.get() : g; + // auto newG = cur->clone(); + // auto res = at::_fused_dropout(vec.at(0), p, cur); + // gen.t = newG; + auto res = at::_fused_dropout(vec.at(0), p); + return {std::get<0>(res), std::get<1>(res)}; + }; + auto res = CheckpointTensorImpl::make("_fused_droupout_", rt, {self}); + return {res[0], res[1]}; +} + +std::tuple checkpoint__thnn_fused_lstm_cell(const Tensor& input_gates, const Tensor& hidden_gates, const Tensor& cx, const Tensor& input_bias, const Tensor& hidden_bias) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto res = at::_thnn_fused_lstm_cell(vec.at(0), vec.at(1), vec.at(2), + vec.at(3), vec.at(4)); + return {std::get<0>(res), std::get<1>(res), std::get<2>(res)}; + }; + auto res = CheckpointTensorImpl::make("_thnn_fused_lstm_cell", rt, + {input_gates, hidden_gates, cx, input_bias, hidden_bias}); + return {res[0], res[1], res[2]}; +} + +std::tuple checkpoint__thnn_fused_lstm_cell_backward(const Tensor& grad_hy, const Tensor& grad_cy, const Tensor& cx, const Tensor& cy, const Tensor& workspace, bool has_bias) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto res = at::_thnn_fused_lstm_cell_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), has_bias); + return {std::get<0>(res), std::get<1>(res), + std::get<2>(res), std::get<3>(res), std::get<4>(res)}; + }; + auto res = CheckpointTensorImpl::make("_thnn_fused_lstm_cell_backward", rt, + {grad_hy, grad_cy, cx, cy, workspace}); + return {res[0], res[1], res[2], res[3], res[4]}; +} + +std::tuple checkpoint__thnn_fused_gru_cell(const Tensor& input_gates, const Tensor& hidden_gates, const Tensor& hx, const Tensor& input_bias, const Tensor& hidden_bias) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto res = at::_thnn_fused_gru_cell(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4)); + return {std::get<0>(res), std::get<1>(res)}; + }; + auto res = CheckpointTensorImpl::make("_thnn_fused_gru_cell", rt, + {input_gates, hidden_gates, hx, input_bias, hidden_bias}); + return {res[0], res[1]}; +} + +std::tuple checkpoint__thnn_fused_gru_cell_backward(const Tensor& grad_hy, const Tensor& workspace, bool has_bias) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto res = at::_thnn_fused_gru_cell_backward(vec.at(0), vec.at(1), has_bias); + return {std::get<0>(res), std::get<1>(res), + std::get<2>(res), std::get<3>(res), std::get<4>(res)}; + }; + auto res = CheckpointTensorImpl::make("_thnn_fused_gru_cell_backward", rt, + {grad_hy, workspace}); + return {res[0], res[1], res[2], res[3], res[4]}; +} + Scalar checkpoint__local_scalar_dense(at::Tensor const& a) { return at::_local_scalar_dense(decheckpoint(a)); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 02fdb8fa768..b2a9fea4851 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -178,6 +178,7 @@ variants: function dispatch: CUDA: fused_dropout_cuda + Checkpoint: checkpoint__fused_dropout supports_named_tensor: True - func: _masked_scale(Tensor self, Tensor mask, float scale) -> Tensor @@ -3855,20 +3856,24 @@ - func: _thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor) dispatch: CUDA: _thnn_fused_lstm_cell_cuda + Checkpoint: checkpoint__thnn_fused_lstm_cell - func: _thnn_fused_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) dispatch: CUDA: _thnn_fused_lstm_cell_backward_cuda + Checkpoint: checkpoint__thnn_fused_lstm_cell_backward - func: _thnn_differentiable_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor input_gates, Tensor hidden_gates, Tensor? input_bias, Tensor? hidden_bias, Tensor cx, Tensor cy) -> (Tensor, Tensor, Tensor, Tensor, Tensor) - func: _thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor) dispatch: CUDA: _thnn_fused_gru_cell_cuda + Checkpoint: checkpoint__thnn_fused_gru_cell - func: _thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) dispatch: CUDA: _thnn_fused_gru_cell_backward_cuda + Checkpoint: checkpoint__thnn_fused_gru_cell_backward - func: _thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) From d5f5e86cfc1e11b0df30399216ae4e6b4ec6cc3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Wed, 6 May 2020 21:42:38 -0700 Subject: [PATCH 12/42] Implemented unrolled_gan (#23) * save * save * save --- aten/src/ATen/CheckpointTensorImpl.cpp | 37 ++++- aten/src/ATen/CheckpointTensorImpl.h | 1 + aten/src/ATen/native/Checkpoint.cpp | 155 +++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 79 +++++++++++ c10/core/TensorImpl.h | 5 +- test.py | 7 +- torch/csrc/utils/tensor_numpy.cpp | 3 +- 7 files changed, 277 insertions(+), 10 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 98f46c79e1c..0de3ced8b83 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -30,7 +30,7 @@ struct DTRLogger { static DTRLogger logger; using json = nlohmann::json; -bool log_json = true; +constexpr bool log_json = true; const std::string INSTRUCTION = "INSTRUCTION"; const std::string ANNOTATION = "ANNOTATION"; const std::string RELEASE = "RELEASE"; @@ -73,12 +73,17 @@ Tensor checkpoint(const Tensor& t) { return Tensor(cpti); } -Tensor decheckpoint(const Tensor& t) { +Tensor uncheckpoint(const Tensor& t) { auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); CHECK(cpti != nullptr); return cpti->ref->value->t; } +Tensor decheckpoint(const Tensor& t) { + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + return cpti ? cpti->ref->value->t : t; +} + bool is_checkpoint(const Tensor& t) { auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); return cpti != nullptr; @@ -95,12 +100,24 @@ void annotate_log(std::string str) { j[ANNOTATION] = str; logger.log(j.dump()); } else { - logger.log(str); + logger.log("# " + str); } } } +void DTRLogCopyFrom(const std::string& to, const std::string& from) { + if (log_json) { + json j; + j[INSTRUCTION] = "COPY_FROM"; + j["DST"] = to; + j["SRC"] = from; + logger.log(j.dump()); + } else { + logger.log(to + " <- " + from); + } +} + void DTRLogCopy(const std::string& new_name, const std::string& old_name) { if (log_json) { json j; @@ -120,6 +137,14 @@ intrusive_ptr CheckpointTensorImpl::shallow_copy_and_detach(const Va return ret; } +void CheckpointTensorImpl::shallow_copy_from(const c10::intrusive_ptr& impl) { + TORCH_CHECK(impl->key_set().has(DispatchKey::CheckpointTensorId)); + auto* cpti = dynamic_cast(impl.get()); + TORCH_CHECK(cpti != nullptr); + ref->value = cpti->ref->value; + DTRLogCopyFrom(counter_name(), cpti->counter_name()); +} + int CheckpointTensorImpl::counter = 0; Tensor checkpoint_raw(const Tensor& t) { @@ -138,7 +163,9 @@ std::tuple make_raw(const rematerialize_function_t& remat, size_t i = 0, j = 0; while (i != input_values.size() || j != constants.size()) { if (j < constants.size() && std::get<1>(constants[j]) == input.size()) { - input.push_back(std::get<0>(constants[j])); + Tensor t = std::get<0>(constants[j]); + TORCH_CHECK(!t.key_set().has(DispatchKey::CheckpointTensorId)); + input.push_back(t); ++j; } else { @@ -175,7 +202,7 @@ void DTRLogCall(const std::vector& res, } else { CHECK(constants.size() == 0); //TODO: implement. std::string arg = name + "("; - for (const auto& s : arg) { + for (const auto& s : args) { arg += s; arg += ", "; } diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index 9fc960f7c3c..5733127646e 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -84,6 +84,7 @@ struct CAFFE2_API CheckpointTensorImpl : TensorImpl { const std::vector& mutate_idx); intrusive_ptr shallow_copy_and_detach(const VariableVersion& version_counter, bool allow_tensor_metadata_change) const override; + void shallow_copy_from(const c10::intrusive_ptr& impl) override; int64_t dim() const override { return ref->value->t.dim(); } diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index 7e0a850798c..c1056e3be46 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -11,6 +11,22 @@ Tensor checkpoint_add(const Tensor& a, const Tensor& b, c10::Scalar c) { return CheckpointTensorImpl::make("add", rt, {a, b})[0]; } +Tensor checkpoint_t(at::Tensor const& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::t(vec.at(0))}; + }; + return CheckpointTensorImpl::make("t", rt, {a})[0]; +} + +Tensor checkpoint_add(at::Tensor const& a, c10::Scalar b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::add(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("add", rt, {a})[0]; +} + Tensor& checkpoint_add_(Tensor& a, const Tensor& b, Scalar c) { mutate_function_t mt = [=](const Tensors& vec) { @@ -37,6 +53,48 @@ Tensor& checkpoint_mul_(at::Tensor& a, at::Tensor const& b) { return a; } +Tensor& checkpoint_mul_(at::Tensor& a, c10::Scalar b) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).mul_(b); + }; + CheckpointTensorImpl::mutate("mul_", mt, {a}, {0}); + return a; +} + +Tensor checkpoint_zeros_like(at::Tensor const& a, c10::TensorOptions const& b, c10::optional c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::zeros_like(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("zeros_like", rt, {a})[0]; +} + +Tensor checkpoint_ones_like(at::Tensor const& a, c10::TensorOptions const& b, c10::optional c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::ones_like(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("ones_like", rt, {a})[0]; +} + +Tensor checkpoint_addcmul(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, c10::Scalar d) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::addcmul(vec.at(0), vec.at(1), vec.at(2), d)}; + }; + return CheckpointTensorImpl::make("addcmul", rt, {a, b, c})[0]; +} + +Tensor& checkpoint_addcmul_(at::Tensor& a, at::Tensor const& b, at::Tensor const& c, c10::Scalar d) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).addcmul_(vec.at(1), vec.at(2), d); + }; + CheckpointTensorImpl::mutate("addcmul_", mt, {a, b, c}, {0}); + return a; +} + Tensor checkpoint_abs(const Tensor& a) { rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { @@ -45,6 +103,40 @@ Tensor checkpoint_abs(const Tensor& a) { return CheckpointTensorImpl::make("abs", rt, {a})[0]; } +Tensor checkpoint_sqrt(const Tensor& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::sqrt(vec.at(0))}; + }; + return CheckpointTensorImpl::make("sqrt", rt, {a})[0]; +} + +Tensor& checkpoint_addcdiv_(at::Tensor& a, at::Tensor const& b, at::Tensor const& c, c10::Scalar d) { + mutate_function_t mt = + [=](const Tensors& vec) { + vec.at(0).addcdiv_(vec.at(1), vec.at(2), d); + }; + CheckpointTensorImpl::mutate("addcdiv_", mt, {a, b, c}, {0}); + return a; +} + +Tensor checkpoint_addcdiv(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c, c10::Scalar d) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::addcdiv(vec.at(0), vec.at(1), vec.at(2), d)}; + }; + return CheckpointTensorImpl::make("addcdiv", rt, {a, b, c})[0]; +} + +Tensor checkpoint_to(at::Tensor const& a, c10::TensorOptions const& b, bool c, bool d, c10::optional e) { + c10::TensorOptions b_ = b; + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {vec.at(0).to(b_, c, d, e)}; + }; + return CheckpointTensorImpl::make("to", rt, {a})[0]; +} + Tensor checkpoint_div(const Tensor& a, const Tensor& b) { rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { @@ -62,6 +154,27 @@ Tensor& checkpoint_div_(Tensor& a, const Tensor& b) { return a; } +Tensor checkpoint_clone(at::Tensor const& a, c10::optional b) { + if (b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::clone(vec.at(0), b)}; + }; + return CheckpointTensorImpl::make("clone", rt, {a})[0]; + } + else { + return a; + } +} + +Tensor checkpoint_where(at::Tensor const& a, at::Tensor const& b, at::Tensor const& c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::where(vec.at(0), vec.at(1), vec.at(2))}; + }; + return CheckpointTensorImpl::make("where", rt, {a, b, c})[0]; +} + Tensor checkpoint_constant_pad_nd(Tensor const& a, c10::ArrayRef b, c10::Scalar c) { std::vector b_ = b.vec(); rematerialize_function_t rt = @@ -254,6 +367,48 @@ Tensor& checkpoint_relu_(Tensor& a) { return a; } +Tensor checkpoint_log(at::Tensor const& a) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::log(vec.at(0))}; + }; + return CheckpointTensorImpl::make("log", rt, {a})[0]; +} + +Tensor& checkpoint_log_out(at::Tensor& a, at::Tensor const& b) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor a_ = vec.at(0); + at::log_out(a_, vec.at(1)); + }; + CheckpointTensorImpl::mutate("log_out", mt, {a, b}, {0}); + return a; +} + +Tensor checkpoint_rsub(at::Tensor const& a, at::Tensor const& b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::rsub(vec.at(0), vec.at(1), c)}; + }; + return CheckpointTensorImpl::make("rsub", rt, {a, b})[0]; +} + +Tensor checkpoint_rsub(at::Tensor const& a, c10::Scalar b, c10::Scalar c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::rsub(vec.at(0), b, c)}; + }; + return CheckpointTensorImpl::make("rsub", rt, {a})[0]; +} + +Tensor checkpoint_mul(at::Tensor const& a, c10::Scalar b) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::mul(vec.at(0), b)}; + }; + return CheckpointTensorImpl::make("mul", rt, {a})[0]; +} + std::tuple checkpoint_max_pool2d_with_indices_out(Tensor& a, Tensor& b, const Tensor& c, c10::ArrayRef d, c10::ArrayRef e, c10::ArrayRef f, c10::ArrayRef g, bool h) { std::vector d_ = d.vec(), e_ = e.vec(), f_ = f.vec(), g_ = g.vec(); mutate_function_t mt = diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b2a9fea4851..5636c3e49de 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6,6 +6,12 @@ - func: is_checkpoint(Tensor self) -> bool variants: method +# convert checkpointed tensor into normal tensor. +# uncheckpoint assume the input is checkpointed and will fail otherwise. +# decheckpoint return the input if it is not checkpointed. +- func: uncheckpoint(Tensor self) -> Tensor + variants: method + - func: decheckpoint(Tensor self) -> Tensor variants: method @@ -333,6 +339,10 @@ use_c10_dispatcher: full variants: function, method supports_named_tensor: True + dispatch: + CPU: add + CUDA: add + Checkpoint: checkpoint_add - func: add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) variants: method @@ -1681,6 +1691,10 @@ use_c10_dispatcher: full supports_named_tensor: True variants: function, method + dispatch: + CPU: log + CUDA: log + Checkpoint: checkpoint_log - func: log_(Tensor(a!) self) -> Tensor(a!) supports_named_tensor: True @@ -1691,6 +1705,7 @@ dispatch: CPU: log_out CUDA: log_out + Checkpoint: checkpoint_log_out - func: log10(Tensor self) -> Tensor use_c10_dispatcher: full @@ -2076,9 +2091,20 @@ - func: mul.Scalar(Tensor self, Scalar other) -> Tensor use_c10_dispatcher: full variants: function, method + dispatch: + CPU: mul + CUDA: mul + SparseCPU: mul + SparseCUDA: mul + MkldnnCPU: mul + Checkpoint: checkpoint_mul - func: mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) variants: method + dispatch: + CPU: mul_ + CUDA: mul_ + Checkpoint: checkpoint_mul_ - func: mv(Tensor self, Tensor vec) -> Tensor use_c10_dispatcher: full @@ -2188,6 +2214,10 @@ - func: ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor supports_named_tensor: True + dispatch: + CPU: ones_like + CUDA: ones_like + Checkpoint: checkpoint_ones_like - func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor use_c10_dispatcher: full @@ -2764,6 +2794,10 @@ use_c10_dispatcher: full supports_named_tensor: True variants: function, method + dispatch: + CPU: sqrt + CUDA: sqrt + Checkpoint: checkpoint_sqrt - func: sqrt_(Tensor(a!) self) -> Tensor(a!) supports_named_tensor: True @@ -2835,6 +2869,10 @@ device_guard: False variants: function, method supports_named_tensor: True + dispatch: + CPU: t + CUDA: t + Checkpoint: checkpoint_t - func: t_(Tensor(a!) self) -> Tensor(a!) device_guard: False @@ -3086,6 +3124,10 @@ - func: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor use_c10_dispatcher: full variants: function, method + dispatch: + CPU: where + CUDA: where + Checkpoint: checkpoint_where - func: where(Tensor condition) -> Tensor[] variants: function @@ -3125,6 +3167,10 @@ - func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor supports_named_tensor: True + dispatch: + CPU: zeros_like + CUDA: zeros_like + Checkpoint: checkpoint_zeros_like - func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor use_c10_dispatcher: full @@ -3239,6 +3285,7 @@ SparseCUDA: clone_sparse MkldnnCPU: mkldnn_clone QuantizedCPU: quantized_clone + Checkpoint: checkpoint_clone supports_named_tensor: True - func: resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!) @@ -3318,12 +3365,24 @@ use_c10_dispatcher: full variants: function supports_named_tensor: True + dispatch: + CPU: rsub + CUDA: rsub + SparseCPU: rsub + SparseCUDA: rsub + Checkpoint: checkpoint_rsub # For C++ only, until we have conversion from C++ numbers to Tensor - func: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor use_c10_dispatcher: full variants: function supports_named_tensor: True + dispatch: + CPU: rsub + CUDA: rsub + SparseCPU: rsub + SparseCUDA: rsub + Checkpoint: checkpoint_rsub # Functionally the same as addmm, but we give it a different derivative formula # that doesn't propagate gradients to non-present entries on sparse. @@ -3796,6 +3855,10 @@ variants: method device_guard: False supports_named_tensor: True + dispatch: + CPU: to + CUDA: to + Checkpoint: checkpoint_to - func: to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor variants: method @@ -4427,6 +4490,10 @@ - func: addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) variants: method supports_named_tensor: True + dispatch: + CPU: addcdiv_ + CUDA: addcdiv_ + Checkpoint: checkpoint_addcdiv_ - func: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!) variants: method @@ -4805,10 +4872,18 @@ use_c10_dispatcher: full variants: method, function supports_named_tensor: True + dispatch: + CPU: addcmul + CUDA: addcmul + Checkpoint: checkpoint_addcmul - func: addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) variants: method supports_named_tensor: True + dispatch: + CPU: addcmul_ + CUDA: addcmul_ + Checkpoint: checkpoint_addcmul_ - func: addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True @@ -4817,6 +4892,10 @@ use_c10_dispatcher: full variants: method, function supports_named_tensor: True + dispatch: + CPU: addcdiv + CUDA: addcdiv + Checkpoint: checkpoint_addcdiv - func: lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR) dispatch: diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index a024be2e4ab..58e6263d861 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -868,6 +868,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * compatible with SparseCUDATensorId. */ inline bool has_compatible_shallow_copy_type(DispatchKeySet from) { + if (key_set_ == from) { + return true; + } if (key_set_.has(DispatchKey::CheckpointTensorId) || from.has(DispatchKey::CheckpointTensorId)) { return false; } @@ -881,7 +884,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { ts.has(DispatchKey::SparseCUDATensorId) || ts.has(DispatchKey::SparseHIPTensorId); }; - return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) || (is_sparse(key_set_) && is_sparse(from)); + return (is_dense(key_set_) && is_dense(from)) || (is_sparse(key_set_) && is_sparse(from)); } /** diff --git a/test.py b/test.py index bebb7397265..cdea37872be 100644 --- a/test.py +++ b/test.py @@ -1,5 +1,6 @@ import torch torch.annotate_log("hello") -torch.annotate_log("hello again") -torch.new_log("new") -torch.annotate_log("again") +x = torch.Tensor([1]).checkpoint() +y = torch.Tensor([2]).checkpoint() +z = x + y +x.data = z diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index 092d0a37da4..7e56c232786 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -73,7 +73,8 @@ static std::vector seq_to_aten_shape(PyObject *py_seq) { return result; } -PyObject* tensor_to_numpy(const at::Tensor& tensor) { +PyObject* tensor_to_numpy(const at::Tensor& tensor_) { + Tensor tensor = tensor_.decheckpoint(); if (tensor.device().type() != DeviceType::CPU) { throw TypeError( "can't convert %s device type tensor to numpy. Use Tensor.cpu() to " From d8a4aea88e939a9e5b1c107ff5a601983bd5349e Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Wed, 6 May 2020 21:44:06 -0700 Subject: [PATCH 13/42] Overload bitwise and for ParityACT --- aten/src/ATen/native/Checkpoint.cpp | 20 ++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 2 ++ 2 files changed, 22 insertions(+) diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index c1056e3be46..82f62f746d0 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -1114,6 +1114,26 @@ std::tuple checkpoint__thnn_fused_gru_ce return {res[0], res[1], res[2], res[3], res[4]}; } +Tensor& checkpoint_bitwise_and_out(Tensor& self, const Tensor& other, const Tensor& out) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::bitwise_and_out(self, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("bitwise_and_out", mt, {self, other, out}, {0}); + return self; +} + +Tensor& checkpoint_bitwise_and_out(Tensor& self, const Tensor& out, Scalar other) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::bitwise_and_out(self, vec.at(1), other); + }; + CheckpointTensorImpl::mutate("bitwise_and_out", mt, {self, out}, {0}); + return self; +} + Scalar checkpoint__local_scalar_dense(at::Tensor const& a) { return at::_local_scalar_dense(decheckpoint(a)); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 5636c3e49de..2619c1660cb 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4221,12 +4221,14 @@ dispatch: CPU: bitwise_and_out CUDA: bitwise_and_out + Checkpoint: checkpoint_bitwise_and_out - func: bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: CPU: bitwise_and_out CUDA: bitwise_and_out + Checkpoint: checkpoint_bitwise_and_out - func: bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor variants: method, function From 97fb92d48a9b382f2f16c9689f69c5d93867dab6 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 12 May 2020 19:39:31 -0700 Subject: [PATCH 14/42] Add various overloads for ACT model --- aten/src/ATen/native/Checkpoint.cpp | 69 ++++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 18 ++++++ 2 files changed, 87 insertions(+) diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index 82f62f746d0..9ca3ffb5678 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -1134,6 +1134,75 @@ Tensor& checkpoint_bitwise_and_out(Tensor& self, const Tensor& out, Scalar other return self; } +Tensor& checkpoint_fill_(Tensor& self, const Tensor& value) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::fill_(self, vec.at(1)); + }; + CheckpointTensorImpl::mutate("fill_", mt, {self, value}, {0}); + return self; +} + +Tensor& checkpoint_fill_(Tensor& self, Scalar value) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::fill_(self, value); + }; + CheckpointTensorImpl::mutate("fill_", mt, {self}, {0}); + return self; +} + +Tensor& checkpoint_masked_select_out(Tensor& self, const Tensor& mask, const Tensor& out) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + at::masked_select_out(self, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("masked_select_out", mt, {self, mask, out}, {0}); + return self; +} + +Tensor checkpoint_masked_select(const Tensor& self, const Tensor& mask) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::masked_select(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("masked_select", rt, {self, mask})[0]; +} + +Tensor checkpoint_index(const Tensor& self, ArrayRef indices) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto self = vec.at(0); + auto indices = std::vector(vec.begin() + 1, vec.end()); + return {at::index(self, indices)}; + }; + + std::vector s = {self}; + for (const Tensor& t: indices) { + s.push_back(t); + } + return CheckpointTensorImpl::make("index", rt, s)[0]; +} + +Tensor& checkpoint_index_put_(Tensor& self, ArrayRef indices, const Tensor& values, const bool accumulate) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self = vec.at(0); + auto values = vec.at(1); + auto indices = std::vector(vec.begin() + 2, vec.end()); + at::index_put_(self, indices, values, accumulate); + }; + std::vector s = {self, values}; + for (const Tensor& t: indices) { + s.push_back(t); + } + CheckpointTensorImpl::mutate("index_put_", mt, s, {0}); + return self; +} + Scalar checkpoint__local_scalar_dense(at::Tensor const& a) { return at::_local_scalar_dense(decheckpoint(a)); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 2619c1660cb..3e119e38b0a 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1358,10 +1358,18 @@ - func: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) supports_named_tensor: True variants: function, method + dispatch: + CPU: fill_ + CUDA: fill_ + Checkpoint: checkpoint_fill_ - func: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) supports_named_tensor: True variants: function, method + dispatch: + CPU: fill_ + CUDA: fill_ + Checkpoint: checkpoint_fill_ - func: floor(Tensor self) -> Tensor use_c10_dispatcher: full @@ -1506,6 +1514,10 @@ - func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor variants: function, method + dispatch: + CPU: index + CUDA: index + Checkpoint: checkpoint_index # NB: This function is special-cased in tools/autograd/gen_variable_type.py # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp: # - Tensor Tensor::index(ArrayRef indices) @@ -1526,6 +1538,10 @@ - func: index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!) variants: function, method + dispatch: + CPU: index_put_ + CUDA: index_put_ + Checkpoint: checkpoint_index_put_ # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp: # - Tensor & Tensor::index_put_(ArrayRef indices, Tensor const & rhs) # - Tensor & Tensor::index_put_(ArrayRef indices, Scalar v) @@ -4822,6 +4838,7 @@ dispatch: CPU: masked_select_out_cpu CUDA: masked_select_out_cuda + Checkpoint: checkpoint_masked_select_out supports_named_tensor: True - func: masked_select(Tensor self, Tensor mask) -> Tensor @@ -4830,6 +4847,7 @@ dispatch: CPU: masked_select_cpu CUDA: masked_select_cuda + Checkpoint: checkpoint_masked_select supports_named_tensor: True - func: nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) From 74ce175d42528149d44dafcefd581f8fe02d5d65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Wed, 13 May 2020 14:13:00 -0700 Subject: [PATCH 15/42] Log aliases of operator outputs --- aten/src/ATen/CheckpointTensorImpl.cpp | 30 ++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 0de3ced8b83..3e3f61ddcdf 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -37,6 +37,7 @@ const std::string RELEASE = "RELEASE"; const std::string TIME = "TIME"; const std::string ARGS = "ARGS"; const std::string MEMORY = "MEMORY"; +const std::string ALIAS = "ALIAS"; const std::string NAME = "NAME"; const std::string CONSTANT = "CONSTANT"; const std::string CONSTANTS = "CONSTANTS"; @@ -106,6 +107,18 @@ void annotate_log(std::string str) { } +void DTRLogAlias(const std::string& name, int index) { + if (log_json) { + json j; + j[INSTRUCTION] = ALIAS; + j[NAME] = name; + j[ALIAS] = std::to_string(index); + logger.log(j.dump()); + } else { + logger.log(name + " " + ALIAS + ": " + std::to_string(index)); + } +} + void DTRLogCopyFrom(const std::string& to, const std::string& from) { if (log_json) { json j; @@ -220,6 +233,22 @@ void DTRLogCall(const std::vector& res, } } +// return an index for alias. +// we dont care which one because they all lead to the same alias pool. +// return -1 for no alias. +// may god forgive my sin. +int get_alias(const Tensors& ts, const Tensor& t) { + if (t.defined()) { + for (size_t i = 0; i < ts.size(); ++i) { + Tensor tsd = ts[i].decheckpoint(); + if (tsd.defined() && t.is_alias_of(tsd)) { + return i; + } + } + } + return -1; +} + Tensors CheckpointTensorImpl::make(const std::string& name, const rematerialize_function_t& remat, const Tensors& input) { @@ -250,6 +279,7 @@ Tensors CheckpointTensorImpl::make(const std::string& name, for (const Tensor& t: tensors) { auto cpti = get_cpti(t); DTRLogMemory(cpti->counter_name(), cpti->ref->value->memory()); + DTRLogAlias(cpti->counter_name(), get_alias(input, cpti->ref->value->t)); } return tensors; } From 3545dfba0f420ed445da2573478b05d3a1da1e94 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 14 May 2020 22:32:28 -0700 Subject: [PATCH 16/42] Overloads and changes needed for transformer --- aten/src/ATen/native/Checkpoint.cpp | 52 ++++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 6 +++ torch/nn/functional.py | 8 ++-- 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index 9ca3ffb5678..184dcfcb813 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -1203,6 +1203,58 @@ Tensor& checkpoint_index_put_(Tensor& self, ArrayRef indices, const Tens return self; } +Tensor checkpoint_bmm(const Tensor& self, const Tensor& mat2) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::bmm(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("bmm", rt, {self, mat2})[0]; +} + +Tensor checkpoint__softmax(const Tensor& self, long dim, bool half_to_float) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::_softmax(vec.at(0), dim, half_to_float)}; + }; + return CheckpointTensorImpl::make("_softmax", rt, {self})[0]; +} + +Tensor checkpoint__softmax_backward_data(const Tensor& grad_output, const Tensor& output, long dim, const Tensor& self) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::_softmax_backward_data(vec.at(0), vec.at(1), dim, vec.at(2))}; + }; + return CheckpointTensorImpl::make("_softmax_backward_data", rt, {grad_output, output, self})[0]; +} + +std::tuple +checkpoint_layer_norm(const Tensor& input, const Tensor& weight, const Tensor& bias, long M, long N, double eps) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::native_layer_norm(vec.at(0), vec.at(1), vec.at(2), M, N, eps); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("native_layer_norm", rt, {input, weight, bias}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple +checkpoint_layer_norm_backward(const Tensor& grad_out, const Tensor& input, const Tensor& mean, const Tensor& rstd, const Tensor& weight, long M, long N, std::array output_mask) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::native_layer_norm_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), M, N, output_mask); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("native_layer_norm_backward", rt, {grad_out, input, mean, rstd, weight}); + return {ret[0], ret[1], ret[2]}; +} + +bool checkpoint_equal(const Tensor& self, const Tensor& other) { + // there can't possibly be a reason to rematerialize + // a single bool so we'll just compute it now + return at::equal(decheckpoint(self), decheckpoint(other)); +} + Scalar checkpoint__local_scalar_dense(at::Tensor const& a) { return at::_local_scalar_dense(decheckpoint(a)); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 3e119e38b0a..89013ef1a94 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -688,6 +688,7 @@ dispatch: CPU: bmm_cpu CUDA: bmm_cuda + Checkpoint: checkpoint_bmm supports_named_tensor: True - func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) @@ -1657,11 +1658,13 @@ dispatch: CPU: layer_norm_cpu CUDA: layer_norm_cuda + Checkpoint: checkpoint_layer_norm - func: native_layer_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, int M, int N, bool[3] output_mask) -> (Tensor, Tensor, Tensor) dispatch: CPU: layer_norm_backward_cpu CUDA: layer_norm_backward_cuda + Checkpoint: checkpoint_layer_norm_backward - func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor python_module: nn @@ -2687,12 +2690,14 @@ CPU: softmax_cpu CUDA: softmax_cuda MkldnnCPU: mkldnn_softmax + Checkpoint: checkpoint__softmax - func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor use_c10_dispatcher: full dispatch: CPU: softmax_backward_cpu CUDA: softmax_backward_cuda + Checkpoint: checkpoint__softmax_backward_data - func: split.Tensor(Tensor(a) self, int split_size, int dim=0) -> Tensor(a)[] variants: function, method @@ -5387,6 +5392,7 @@ CPU: legacy::cpu::_th_equal CUDA: legacy::cuda::_th_equal QuantizedCPU: quantized_equal + Checkpoint: checkpoint_equal supports_named_tensor: True - func: pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 4777206b90b..13037b8cbfe 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -3776,9 +3776,10 @@ def multi_head_attention_forward(query, # type: Tensor q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) if k is not None: - k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + # extremely bizarre, but got errors of dimension mistmatch here and below unless I changed view -> reshape + k = k.contiguous().reshape(-1, bsz * num_heads, head_dim).transpose(0, 1) if v is not None: - v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + v = v.contiguous().reshape(-1, bsz * num_heads, head_dim).transpose(0, 1) if static_k is not None: assert static_k.size(0) == bsz * num_heads @@ -3825,7 +3826,8 @@ def multi_head_attention_forward(query, # type: Tensor attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] - attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + # also had to change view to reshape + attn_output = attn_output.transpose(0, 1).contiguous().reshape(tgt_len, bsz, embed_dim) attn_output = linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: From 79578bb78ea18b9e785464c3b650774f4b37e8f2 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 15 May 2020 19:41:43 -0700 Subject: [PATCH 17/42] Overloads for topk functions --- aten/src/ATen/native/Checkpoint.cpp | 23 ++++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 2 ++ 2 files changed, 25 insertions(+) diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index 184dcfcb813..be787fefa4c 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -1249,6 +1249,29 @@ checkpoint_layer_norm_backward(const Tensor& grad_out, const Tensor& input, cons return {ret[0], ret[1], ret[2]}; } +std::tuple +checkpoint_topk(const Tensor& self, long k, long dim, bool largest, bool sorted) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::topk(vec.at(0), k, dim, largest, sorted); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("topk", rt, {self}); + return {ret[0], ret[1]}; +} + +std::tuple +checkpoint_topk_values(Tensor& values, Tensor& indices, const Tensor& self, long k, long dim, bool largest, bool sorted) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor values_ = vec.at(0); + Tensor indices_ = vec.at(1); + at::topk_out(values_, indices_, vec.at(2), k, dim, largest, sorted); + }; + CheckpointTensorImpl::mutate("topk_values", mt, {values, indices, self}, {0, 1}); + return {values, indices}; +} + bool checkpoint_equal(const Tensor& self, const Tensor& other) { // there can't possibly be a reason to rematerialize // a single bool so we'll just compute it now diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 89013ef1a94..90032615686 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5343,6 +5343,7 @@ dispatch: CPU: topk_out_cpu CUDA: legacy::cuda::_th_topk_out + Checkpoint: checkpoint_topk_values - func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) variants: method, function @@ -5350,6 +5351,7 @@ CPU: topk CUDA: topk QuantizedCPU: quantized_topk_cpu + Checkpoint: checkpoint_topk - func: all(Tensor self) -> Tensor use_c10_dispatcher: full From e144342eba092b129e575778f7161df14a5ffa18 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 19 May 2020 16:06:38 -0700 Subject: [PATCH 18/42] Add overloads for deepspeech --- aten/src/ATen/native/Checkpoint.cpp | 192 +++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 21 +++ 2 files changed, 213 insertions(+) diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index be787fefa4c..5fc55eb4c59 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -1272,6 +1272,198 @@ checkpoint_topk_values(Tensor& values, Tensor& indices, const Tensor& self, long return {values, indices}; } +Tensor& checkpoint_masked_fill_(Tensor& self, const Tensor& mask, Scalar value) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self_ = vec.at(0); + self_.masked_fill_(vec.at(1), value); + }; + CheckpointTensorImpl::mutate("masked_fill_Scalar", mt, {self, mask}, {0}); + return {self}; +} + +Tensor& checkpoint_masked_fill_(Tensor& self, const Tensor& mask, const Tensor& value) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self_ = vec.at(0); + self_.masked_fill_(vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("masked_fill_Tensor", mt, {self, mask, value}, {0}); + return {self}; +} + +Tensor checkpoint_clamp(const Tensor& self, c10::optional min, c10::optional max) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::clamp(vec.at(0), min, max)}; + }; + return CheckpointTensorImpl::make("clamp", rt, {self})[0]; +} + +Tensor& checkpoint_clamp_(Tensor& self, c10::optional min, c10::optional max) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor self_ = vec.at(0); + at::clamp_(self_, min, max); + }; + CheckpointTensorImpl::mutate("clamp_", mt, {self}, {0}); + return {self}; +} + +Tensor& checkpoint_clamp_out(Tensor& out, const Tensor& self, c10::optional min, c10::optional max) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::clamp_out(out_, vec.at(1), min, max); + }; + CheckpointTensorImpl::mutate("clamp_out", mt, {out, self}, {0}); + return {out}; +} + +std::tuple checkpoint_thnn_conv2d_forward_out(Tensor& output, Tensor& finput, Tensor& fgrad_input, const Tensor& self, const Tensor& weight, c10::ArrayRef kernel_size, const Tensor& bias, c10::ArrayRef stride, c10::ArrayRef padding) { + auto kernel_size_ = kernel_size.vec(); + auto stride_ = stride.vec(); + auto padding_ = padding.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor output_ = vec.at(0); + Tensor finput_ = vec.at(1); + Tensor fgrad_input_ = vec.at(2); + at::thnn_conv2d_forward_out(output_, finput_, fgrad_input_, vec.at(3), vec.at(4), kernel_size_, vec.at(5), stride_, padding_); + }; + CheckpointTensorImpl::mutate("thnn_conv2d_forward_out", mt, {output, finput, fgrad_input, self, weight, bias}, {0, 1, 2}); + return {output, finput, fgrad_input}; +} + +std::tuple checkpoint_thnn_conv2d_forward(const Tensor& self, const Tensor& weight, c10::ArrayRef kernel_size, const Tensor& bias, c10::ArrayRef stride, c10::ArrayRef padding) { + auto kernel_size_ = kernel_size.vec(); + auto stride_ = stride.vec(); + auto padding_ = padding.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::thnn_conv2d_forward(vec.at(0), vec.at(1), kernel_size_, vec.at(2), stride_, padding_); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("thnn_conv2d_forward", rt, {self, weight, bias}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple checkpoint_thnn_conv2d_backward_out(Tensor& grad_input, Tensor& grad_weight, Tensor& grad_bias, const Tensor& grad_output, const Tensor& self, const Tensor& weight, c10::ArrayRef kernel_size, c10::ArrayRef stride, c10::ArrayRef padding, const Tensor& finput, const Tensor& fgrad_input) { + auto kernel_size_ = kernel_size.vec(); + auto stride_ = stride.vec(); + auto padding_ = padding.vec(); + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor grad_input_ = vec.at(0); + Tensor grad_weight_ = vec.at(1); + Tensor grad_bias_ = vec.at(2); + at::thnn_conv2d_backward_out(grad_input_, grad_weight_, grad_bias_, vec.at(3), vec.at(4), vec.at(5), kernel_size_, stride_, padding_, vec.at(6), vec.at(7)); + }; + CheckpointTensorImpl::mutate("thnn_conv2d_backward_out", mt, {grad_input, grad_weight, grad_bias, grad_output, self, weight, finput, fgrad_input}, {0, 1, 2}); + return {grad_input, grad_weight, grad_bias}; +} + +std::tuple checkpoint_thnn_conv2d_backward(const Tensor& grad_output, const Tensor& self, const Tensor& weight, c10::ArrayRef kernel_size, c10::ArrayRef stride, c10::ArrayRef padding, const Tensor& finput, const Tensor& fgrad_input, std::array output_mask) { + auto kernel_size_ = kernel_size.vec(); + auto stride_ = stride.vec(); + auto padding_ = padding.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::thnn_conv2d_backward(vec.at(0), vec.at(1), vec.at(2), kernel_size_, stride_, padding_, vec.at(3), vec.at(4), output_mask); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("thnn_conv2d_backward", rt, {grad_output, self, weight, finput, fgrad_input}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple checkpoint_native_batch_norm(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double momentum, double eps) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::native_batch_norm(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), training, momentum, eps); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("native_batch_norm", rt, {input, weight, bias, running_mean, running_var}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple checkpoint_native_batch_norm_out(Tensor& out, Tensor& save_mean, Tensor& save_invstd, const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double momentum, double eps) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + Tensor save_mean_ = vec.at(1); + Tensor save_invstd_ = vec.at(2); + at::native_batch_norm_out(out_, save_mean_, save_invstd_, vec.at(3), vec.at(4), vec.at(5), vec.at(6), vec.at(7), training, momentum, eps); + }; + CheckpointTensorImpl::mutate("native_batch_norm_out", mt, {out, save_mean, save_invstd, input, weight, bias, running_mean, running_var}, {0, 1, 2}); + return {out, save_mean, save_invstd}; +} + +std::tuple checkpoint_native_batch_norm_backward(const Tensor& grad_out, const Tensor& input, const Tensor& weight, const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd, bool train, double eps, std::array output_mask) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::native_batch_norm_backward(vec.at(0), vec.at(1), vec.at(2), vec.at(3), vec.at(4), vec.at(5), vec.at(6), train, eps, output_mask); + return {std::get<0>(ret), std::get<1>(ret), std::get<2>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("native_batch_norm_backward", rt, {grad_out, input, weight, running_mean, running_var, save_mean, save_invstd}); + return {ret[0], ret[1], ret[2]}; +} + +std::tuple checkpoint__cudnn_ctc_loss(const Tensor& log_probs, const Tensor& targets, ArrayRef input_lengths, ArrayRef target_lengths, long blank, bool deterministic, bool zero_infinity) { + auto input_lengths_ = input_lengths.vec(); + auto target_lengths_ = target_lengths.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::_cudnn_ctc_loss(vec.at(0), vec.at(1), input_lengths_, target_lengths_, blank, deterministic, zero_infinity); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("_cudnn_ctc_loss", rt, {log_probs, targets}); + return {ret[0], ret[1]}; +} + +std::tuple checkpoint__ctc_loss(const Tensor& log_probs, const Tensor& targets, ArrayRef input_lengths, ArrayRef target_lengths, long blank, bool zero_infinity) { + auto input_lengths_ = input_lengths.vec(); + auto target_lengths_ = target_lengths.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + auto ret = at::_ctc_loss(vec.at(0), vec.at(1), input_lengths_, target_lengths_, blank, zero_infinity); + return {std::get<0>(ret), std::get<1>(ret)}; + }; + auto ret = CheckpointTensorImpl::make("_ctc_loss", rt, {log_probs, targets}); + return {ret[0], ret[1]}; +} + +Tensor checkpoint__ctc_loss_backward(const Tensor& grad, const Tensor& log_probs, const Tensor& targets, ArrayRef input_lengths, ArrayRef target_lengths, const Tensor& neg_log_likelihood, const Tensor& log_alpha, long blank, bool zero_infinity) { + auto input_lengths_ = input_lengths.vec(); + auto target_lengths_ = target_lengths.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::_ctc_loss_backward(vec.at(0), vec.at(1), vec.at(2), input_lengths_, target_lengths_, vec.at(3), vec.at(4), blank, zero_infinity)}; + }; + return CheckpointTensorImpl::make("_ctc_loss_backward", rt, {grad, log_probs, targets, neg_log_likelihood, log_alpha})[0]; +} + +Tensor& checkpoint_hardtanh_backward_out(Tensor& grad_input, const Tensor& grad_output, const Tensor& self, Scalar min_val, Scalar max_val) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor grad_input_ = vec.at(0); + at::hardtanh_backward_out(grad_input_, vec.at(1), vec.at(2), min_val, max_val); + }; + CheckpointTensorImpl::mutate("hardtanh_backward_out", mt, {grad_input, grad_output, self}, {0}); + return {grad_input}; +} + +Tensor checkpoint_hardtanh_backward(const Tensor& grad_output, const Tensor& self, Scalar min_val, Scalar max_val) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::hardtanh_backward(vec.at(0), vec.at(1), min_val, max_val)}; + }; + return CheckpointTensorImpl::make("hardtanh_backward", rt, {grad_output, self})[0]; +} + +bool checkpoint__use_cudnn_ctc_loss(const Tensor& log_probs, const Tensor& targets, ArrayRef input_lengths, ArrayRef target_lengths, long blank) { + return at::_use_cudnn_ctc_loss(decheckpoint(log_probs), decheckpoint(targets), input_lengths, target_lengths, blank); +} + bool checkpoint_equal(const Tensor& self, const Tensor& other) { // there can't possibly be a reason to rematerialize // a single bool so we'll just compute it now diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 90032615686..f0b8f8e56b0 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -155,10 +155,12 @@ - func: _use_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank) -> bool dispatch: CUDA: _use_cudnn_ctc_loss + Checkpoint: checkpoint__use_cudnn_ctc_loss - func: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) dispatch: CUDA: _cudnn_ctc_loss + Checkpoint: checkpoint__cudnn_ctc_loss - func: _cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, int input_size, int mode, int hidden_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor dispatch: @@ -744,6 +746,7 @@ CPU: clamp CUDA: clamp QuantizedCPU: quantized_clamp + Checkpoint: checkpoint_clamp - func: clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) supports_named_tensor: True @@ -751,12 +754,14 @@ dispatch: CPU: _clamp__cpu CUDA: _clamp__cuda + Checkpoint: checkpoint_clamp_ - func: clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True dispatch: CPU: _clamp_out_cpu CUDA: _clamp_out_cuda + Checkpoint: checkpoint_clamp_out - func: clamp_max(Tensor self, Scalar max) -> Tensor use_c10_dispatcher: full @@ -1054,11 +1059,13 @@ dispatch: CPU: ctc_loss_cpu CUDA: ctc_loss_gpu + Checkpoint: checkpoint__ctc_loss - func: _ctc_loss_backward(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor dispatch: CPU: ctc_loss_backward_cpu CUDA: ctc_loss_backward_gpu + Checkpoint: checkpoint__ctc_loss_backward - func: det(Tensor self) -> Tensor use_c10_dispatcher: full @@ -2165,10 +2172,12 @@ CPU: batch_norm_cpu CUDA: batch_norm_cuda MkldnnCPU: mkldnn_batch_norm + Checkpoint: checkpoint_native_batch_norm - func: native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) dispatch: CUDA: batch_norm_cuda_out + Checkpoint: checkpoint_native_batch_norm_out - func: batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor) dispatch: @@ -2195,6 +2204,7 @@ dispatch: CPU: batch_norm_backward_cpu CUDA: batch_norm_backward_cuda + Checkpoint: checkpoint_native_batch_norm_backward - func: batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor) dispatch: @@ -4064,6 +4074,7 @@ dispatch: CPU: masked_fill__cpu CUDA: masked_fill__cuda + Checkpoint: checkpoint_masked_fill_ supports_named_tensor: True - func: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor @@ -4076,6 +4087,7 @@ dispatch: CPU: masked_fill__cpu CUDA: masked_fill__cuda + Checkpoint: checkpoint_masked_fill_ supports_named_tensor: True - func: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor @@ -5831,10 +5843,15 @@ dispatch: CPU: hardtanh_backward_out CUDA: hardtanh_backward_out + Checkpoint: checkpoint_hardtanh_backward_out - func: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor use_c10_dispatcher: full python_module: nn + dispatch: + CPU: hardtanh_backward + CUDA: hardtanh_backward + Checkpoint: checkpoint_hardtanh_backward - func: hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> Tensor(a!) python_module: nn @@ -6674,24 +6691,28 @@ dispatch: CPU: slow_conv2d_forward_out_cpu CUDA: legacy::cuda::_thnn_conv2d_forward_out + Checkpoint: checkpoint_thnn_conv2d_forward_out - func: thnn_conv2d_forward(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias, int[2] stride, int[2] padding) -> (Tensor output, Tensor finput, Tensor fgrad_input) python_module: nn dispatch: CPU: slow_conv2d_forward_cpu CUDA: legacy::cuda::_thnn_conv2d_forward + Checkpoint: checkpoint_thnn_conv2d_forward - func: thnn_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight, Tensor(c!)? grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) python_module: nn dispatch: CPU: slow_conv2d_backward_out_cpu CUDA: legacy::cuda::_thnn_conv2d_backward_out + Checkpoint: checkpoint_thnn_conv2d_backward_out - func: thnn_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) python_module: nn dispatch: CPU: slow_conv2d_backward_cpu CUDA: legacy::cuda::_thnn_conv2d_backward + Checkpoint: checkpoint_thnn_conv2d_backward - func: thnn_conv_depthwise2d.out(Tensor self, Tensor weight, int[2] kernel_size, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) python_module: nn From 5476fcc836a64ac6bd34d5d515388f3d72b5e188 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 19 May 2020 18:18:04 -0700 Subject: [PATCH 19/42] More overloads for ACT --- aten/src/ATen/native/Checkpoint.cpp | 62 ++++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 7 +++ 2 files changed, 69 insertions(+) diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index 5fc55eb4c59..6d8f0331601 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -1460,6 +1460,68 @@ Tensor checkpoint_hardtanh_backward(const Tensor& grad_output, const Tensor& sel return CheckpointTensorImpl::make("hardtanh_backward", rt, {grad_output, self})[0]; } +Tensor checkpoint_nonzero(const Tensor& self) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::nonzero(vec.at(0))}; + }; + return CheckpointTensorImpl::make("nonzero", rt, {self})[0]; +} + +Tensor& checkpoint_nonzero_out(Tensor& out, const Tensor& self) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::nonzero_out(out_, vec.at(1)); + }; + CheckpointTensorImpl::mutate("nonzero_out", mt, {out, self}, {0}); + return {out}; +} + +Tensor checkpoint_lt(const Tensor& self, Scalar other) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::lt(vec.at(0), other)}; + }; + return CheckpointTensorImpl::make("lt_Scalar", rt, {self})[0]; +} + +Tensor& checkpoint_lt_out(Tensor& out, const Tensor& self, Scalar other) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::lt_out(out_, vec.at(1), other); + }; + CheckpointTensorImpl::mutate("lt_Scalar_out", mt, {out, self}, {0}); + return {out}; +} + +Tensor checkpoint_lt(const Tensor& self, const Tensor& other) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::lt(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("lt_Tensor", rt, {self, other})[0]; +} + +Tensor& checkpoint_lt_out(Tensor& out, const Tensor& self, const Tensor& other) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::lt_out(out_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("lt_Tensor_out", mt, {out, self, other}, {0}); + return {out}; +} + +Tensor checkpoint_any(const Tensor& self) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::any(vec.at(0))}; + }; + return CheckpointTensorImpl::make("any", rt, {self})[0]; +} + bool checkpoint__use_cudnn_ctc_loss(const Tensor& log_probs, const Tensor& targets, ArrayRef input_lengths, ArrayRef target_lengths, long blank) { return at::_use_cudnn_ctc_loss(decheckpoint(log_probs), decheckpoint(targets), input_lengths, target_lengths, blank); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index f0b8f8e56b0..bc33f2a6ba4 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4794,6 +4794,7 @@ CPU: lt_out CUDA: lt_out QuantizedCPU: lt_out_quantized_cpu + Checkpoint: checkpoint_lt_out - func: lt.Scalar(Tensor self, Scalar other) -> Tensor supports_named_tensor: True @@ -4803,6 +4804,7 @@ CPU: lt CUDA: lt QuantizedCPU: lt_quantized_cpu + Checkpoint: checkpoint_lt - func: lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True @@ -4810,6 +4812,7 @@ CPU: lt_out CUDA: lt_out QuantizedCPU: lt_out_quantized_cpu + Checkpoint: checkpoint_lt_out - func: lt.Tensor(Tensor self, Tensor other) -> Tensor supports_named_tensor: True @@ -4819,6 +4822,7 @@ CPU: lt CUDA: lt QuantizedCPU: lt_quantized_cpu + Checkpoint: checkpoint_lt - func: take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -4871,6 +4875,7 @@ dispatch: CPU: legacy::cpu::_th_nonzero_out CUDA: legacy::cuda::_th_nonzero_out + Checkpoint: checkpoint_nonzero_out - func: nonzero(Tensor self) -> Tensor use_c10_dispatcher: full @@ -4878,6 +4883,7 @@ dispatch: CPU: legacy::cpu::_th_nonzero CUDA: legacy::cuda::_th_nonzero + Checkpoint: checkpoint_nonzero - func: nonzero_numpy(Tensor self) -> Tensor[] variants: method, function @@ -5379,6 +5385,7 @@ CUDA: any SparseCPU: any_sparse SparseCUDA: any_sparse + Checkpoint: checkpoint_any - func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!) dispatch: From 0b770f3f4dd9cb3cbb5dd3b531b52bda40d7c7d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Fri, 22 May 2020 18:06:56 -0700 Subject: [PATCH 20/42] DTR implementation first pass (#37) save before change save before change add aliaspool design, and add weak/strong pointer discussion add more code rebase add allocator hook save metadata to prepare for eviction save refactor: move log to a seprate file add file raii save save comment save save save save find error, bisecting code save save address review comment address comment address comment fix segfault save save pin save save save save save --- aten/src/ATen/CheckpointTensorImpl.cpp | 364 ++++++++------------- aten/src/ATen/CheckpointTensorImpl.h | 313 ++++++++++++++++-- aten/src/ATen/Logger.h | 192 +++++++++++ aten/src/ATen/native/Activation.cpp | 78 +++++ aten/src/ATen/native/Checkpoint.cpp | 47 ++- aten/src/ATen/native/TensorShape.cpp | 46 --- aten/src/ATen/native/native_functions.yaml | 32 ++ c10/cuda/CUDACachingAllocator.cpp | 22 +- c10/cuda/CUDACachingAllocator.h | 4 + test.py | 6 +- tools/autograd/templates/Functions.cpp | 32 -- 11 files changed, 792 insertions(+), 344 deletions(-) create mode 100644 aten/src/ATen/Logger.h diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 3e3f61ddcdf..d7f1e1ba9dd 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -1,67 +1,58 @@ #include +#include #include -#include <../../../third_party/json/single_include/nlohmann/json.hpp> namespace at { -struct DTRLogger { - std::string time_prefix; - std::ofstream out; - static std::string get_time_prefix() { - std::time_t t = std::time(nullptr); - std::tm* tm = std::localtime(&t); - return - std::to_string(1900+tm->tm_year) + "-" + - std::to_string(1+tm->tm_mon) + "-" + - std::to_string(tm->tm_mday) + "-" + - std::to_string(tm->tm_hour) + "-" + - std::to_string(tm->tm_min) + "-" + - std::to_string(tm->tm_sec); - } - std::string get_filename(const std::string& name) { - return time_prefix + "-" + name + ".log"; - } - DTRLogger() : time_prefix(get_time_prefix()), out(get_filename("default")) { } - void log(const std::string& str) { - out << str << std::endl; +void AliasPool::evict() { + TORCH_CHECK(lock_count == 0); + for (const weak& w : tensors) { + if (auto cell = w.lock()) { + cell->evict(); + } } -}; - -static DTRLogger logger; +} -using json = nlohmann::json; -constexpr bool log_json = true; -const std::string INSTRUCTION = "INSTRUCTION"; -const std::string ANNOTATION = "ANNOTATION"; -const std::string RELEASE = "RELEASE"; -const std::string TIME = "TIME"; -const std::string ARGS = "ARGS"; -const std::string MEMORY = "MEMORY"; -const std::string ALIAS = "ALIAS"; -const std::string NAME = "NAME"; -const std::string CONSTANT = "CONSTANT"; -const std::string CONSTANTS = "CONSTANTS"; +void External::release_resources() { + value->evict(); + value.reset(); +} -void DTRLogConstant(const std::string& name) { - if (log_json) { - json j; - j[INSTRUCTION] = CONSTANT; - j[NAME] = name; - logger.log(j.dump()); - } else { - logger.log(CONSTANT + " " + name); +Tensors stitch(const strongs& input_values, + const std::vector>& constants) { + Tensors input; + size_t i = 0, j = 0; + while (i != input_values.size() || j != constants.size()) { + if (j < constants.size() && std::get<1>(constants[j]) == input.size()) { + Tensor t = std::get<0>(constants[j]); + TORCH_CHECK(!t.key_set().has(DispatchKey::CheckpointTensorId)); + input.push_back(t); + ++j; + } + else { + CHECK(i < input_values.size()); + input.push_back(input_values[i]->get()); + ++i; + } } + return input; } -void DTRLogMemory(const std::string& name, size_t memory) { - if (log_json) { - json j; - j[INSTRUCTION] = MEMORY; - j[NAME] = name; - j[MEMORY] = std::to_string(memory); - logger.log(j.dump()); - } else { - logger.log(name + " " + MEMORY + ": " + std::to_string(memory)); +void Rematerializer::remat() { + // TODO: refactor using RAII for exception safety. + for (const strong& s : input_values) { + ++(s->pool->lock_count); + } + Tensors ts = stitch(input_values, constants); + auto ret = func(ts); + TORCH_CHECK(ret.size() == outputs.size()); + for (size_t i = 0; i < outputs.size(); ++i) { + if (auto output_cell = outputs[i].lock()) { + output_cell->fill(ret[i]); + } + } + for (const strong& s : input_values) { + --(s->pool->lock_count); } } @@ -70,19 +61,25 @@ namespace native { Tensor checkpoint(const Tensor& t) { auto cpti = intrusive_ptr::make(t.detach()); DTRLogConstant(cpti->counter_name()); - DTRLogMemory(cpti->counter_name(), cpti->ref->value->memory()); + DTRLogMemory(cpti->counter_name(), cpti->ref->value->value->memory()); return Tensor(cpti); } Tensor uncheckpoint(const Tensor& t) { auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); CHECK(cpti != nullptr); - return cpti->ref->value->t; + return cpti->ref->value->value->get(); +} + +void pin(const Tensor& t) { + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + CHECK(cpti != nullptr); + cpti->ref->value->value->pin(); } Tensor decheckpoint(const Tensor& t) { auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); - return cpti ? cpti->ref->value->t : t; + return cpti ? cpti->ref->value->value->get() : t; } bool is_checkpoint(const Tensor& t) { @@ -91,7 +88,7 @@ bool is_checkpoint(const Tensor& t) { } void new_log(std::string str) { - logger.out = std::ofstream(logger.get_filename(str)); + DTRLogger::logger().out = std::ofstream(DTRLogger::logger().get_filename(str)); } void annotate_log(std::string str) { @@ -99,48 +96,16 @@ void annotate_log(std::string str) { json j; j[INSTRUCTION] = "ANNOTATE"; j[ANNOTATION] = str; - logger.log(j.dump()); - } else { - logger.log("# " + str); - } -} - -} - -void DTRLogAlias(const std::string& name, int index) { - if (log_json) { - json j; - j[INSTRUCTION] = ALIAS; - j[NAME] = name; - j[ALIAS] = std::to_string(index); - logger.log(j.dump()); + DTRLogger::logger().log(j.dump()); } else { - logger.log(name + " " + ALIAS + ": " + std::to_string(index)); + DTRLogger::logger().log("# " + str); } } -void DTRLogCopyFrom(const std::string& to, const std::string& from) { - if (log_json) { - json j; - j[INSTRUCTION] = "COPY_FROM"; - j["DST"] = to; - j["SRC"] = from; - logger.log(j.dump()); - } else { - logger.log(to + " <- " + from); - } +void clear_checkpointpool() { + // not implemented yet. } -void DTRLogCopy(const std::string& new_name, const std::string& old_name) { - if (log_json) { - json j; - j[INSTRUCTION] = "COPY"; - j["DST"] = new_name; - j["SRC"] = old_name; - logger.log(j.dump()); - } else { - logger.log(new_name + " = " + old_name); - } } intrusive_ptr CheckpointTensorImpl::shallow_copy_and_detach(const VariableVersion& version_counter, @@ -160,106 +125,83 @@ void CheckpointTensorImpl::shallow_copy_from(const c10::intrusive_ptr::make(t.detach())); +// return an index for alias. +// we dont care which one because they all lead to the same alias pool. +// return -1 for no alias. +// may god forgive my sin. +int get_alias(const Tensors& ts, const Tensor& t) { + if (t.defined()) { + for (size_t i = 0; i < ts.size(); ++i) { + if (ts[i].defined() && t.is_alias_of(ts[i])) { + return i; + } + } + } + return -1; } +struct MakeRawResult { + std::vector> outputs; + std::vector aliases; + duration_t time; + intrusive_ptr rematerializer; +}; + // remat take a single vector of tensors, // while there are two vector, one storing nonconstants and one storing constants. // the constants are small and they will not be considered for eviction. // however, we have to stitch the two vectors together to pass it in remat. // the size_t in constants decide the location to stitch them in, while input_values fill in the rest. -std::tuple make_raw(const rematerialize_function_t& remat, - const strongs& input_values, - const std::vector>& constants) { - std::vector input; - size_t i = 0, j = 0; - while (i != input_values.size() || j != constants.size()) { - if (j < constants.size() && std::get<1>(constants[j]) == input.size()) { - Tensor t = std::get<0>(constants[j]); - TORCH_CHECK(!t.key_set().has(DispatchKey::CheckpointTensorId)); - input.push_back(t); - ++j; - } - else { - CHECK(i < input_values.size()); - CHECK(!input_values[i]->t.key_set().has(DispatchKey::CheckpointTensorId)); - input.push_back(input_values[i]->t); - ++i; - } - } +MakeRawResult make_raw(const rematerialize_function_t& remat_f, + // We need this to assign alias pool. + // This is ugly as fuck but after refactoring we dont even need stitching anymore. + const Tensors& raw_input, + const strongs& input_values, + const std::vector>& constants) { + Tensors inputs = stitch(input_values, constants); time_t pre = std::chrono::system_clock::now(); - auto output = remat(input); + auto outputs_raw = remat_f(inputs); time_t post = std::chrono::system_clock::now(); - return {output, post - pre}; -} - -std::string from_time(duration_t t) { - return std::to_string(std::chrono::nanoseconds(t).count()); -} - -void DTRLogCall(const std::vector& res, - const std::string& name, - const std::vector& args, - const std::vector& constants, - const std::string& time) { - if (log_json) { - json j; - j[INSTRUCTION] = "CALL"; - j[NAME] = name; - j["RESULT"] = res; - j[ARGS] = args; - j[CONSTANTS] = constants; - j[TIME] = time; - logger.log(j.dump()); - } else { - CHECK(constants.size() == 0); //TODO: implement. - std::string arg = name + "("; - for (const auto& s : args) { - arg += s; - arg += ", "; + std::vector> outputs; + std::vector aliases; + weaks weak_outputs; + auto remat = intrusive_ptr::make(Unsafe(), input_values, constants, remat_f); + for (const Tensor& t : outputs_raw) { + int alias = get_alias(inputs, t); + intrusive_ptr pool; + if (alias == -1) { + pool = intrusive_ptr::make(Unsafe(), true, memory(t)); } - arg += ")"; - std::string log = "("; - for (const auto& s: res) { - log += s; - log += ", "; + else if (auto* cpti = dynamic_cast(raw_input[alias].unsafeGetTensorImpl())) { + pool = cpti->ref->value->value->pool; + } else { // alias to an constant. unevictable. + pool = intrusive_ptr::make(Unsafe(), false, memory(t)); } - log += ") = "; - log += arg; - log += " TIME: "; - log += time; - logger.log(log); + auto e = intrusive_ptr::make(t, pool, remat); + pool->tensors.push_back(weak(e->value)); + outputs.push_back(e); + aliases.push_back(alias); + weak_outputs.push_back(weak(outputs.back()->value)); } + remat->outputs = weak_outputs; + return {outputs, aliases, post - pre, remat}; } -// return an index for alias. -// we dont care which one because they all lead to the same alias pool. -// return -1 for no alias. -// may god forgive my sin. -int get_alias(const Tensors& ts, const Tensor& t) { - if (t.defined()) { - for (size_t i = 0; i < ts.size(); ++i) { - Tensor tsd = ts[i].decheckpoint(); - if (tsd.defined() && t.is_alias_of(tsd)) { - return i; - } - } - } - return -1; +std::string from_time(duration_t t) { + return std::to_string(std::chrono::nanoseconds(t).count()); } Tensors CheckpointTensorImpl::make(const std::string& name, const rematerialize_function_t& remat, - const Tensors& input) { + const Tensors& inputs) { strongs input_values; std::vector> constants; std::vector constant_idx; std::vector args; - for (const Tensor& t: input) { - if (auto* cpt = dynamic_cast(t.unsafeGetTensorImpl())) { - input_values.push_back(cpt->ref->value); - args.push_back(cpt->counter_name()); + for (const Tensor& t: inputs) { + if (auto* cpti = dynamic_cast(t.unsafeGetTensorImpl())) { + input_values.push_back(cpti->ref->value->value); + args.push_back(cpti->counter_name()); } else { size_t idx = input_values.size() + constants.size(); @@ -268,59 +210,24 @@ Tensors CheckpointTensorImpl::make(const std::string& name, } } std::vector res; - auto ret = make_raw(remat, input_values, constants); + auto ret = make_raw(remat, inputs, input_values, constants); Tensors tensors; - for (const Tensor& t: std::get<0>(ret)) { - auto cp = checkpoint_raw(t); + for (const auto& t: ret.outputs) { + auto cp = Tensor(intrusive_ptr::make(t)); tensors.push_back(cp); res.push_back(get_cpti(cp)->counter_name()); } - DTRLogCall(res, name, args, constant_idx, from_time(std::get<1>(ret))); - for (const Tensor& t: tensors) { + DTRLogCall(res, name, args, constant_idx, from_time(ret.time)); + for (size_t i = 0; i < tensors.size(); ++i) { + Tensor t = tensors[i]; auto cpti = get_cpti(t); - DTRLogMemory(cpti->counter_name(), cpti->ref->value->memory()); - DTRLogAlias(cpti->counter_name(), get_alias(input, cpti->ref->value->t)); + DTRLogMemory(cpti->counter_name(), cpti->ref->value->value->memory()); + DTRLogAlias(cpti->counter_name(), ret.aliases[i]); } return tensors; } -void DTRLogMutate(const std::string& name, - const std::vector& args, - const std::vector& constants, - const std::vector& mutate, - const std::string& time) { - if (log_json) { - json j; - j[INSTRUCTION] = "MUTATE"; - j[NAME] = name; - j[ARGS] = args; - j[CONSTANTS] = constants; - j["MUTATE"] = mutate; - j[TIME] = time; - logger.log(j.dump()); - } else { - CHECK(constants.size() == 0); //TODO: implement. - std::string log = name; - log += "("; - for (const auto& s : args) { - log += s; - log += ", "; - } - log += ") "; - log += " MUTATING: "; - log += "("; - for (const size_t i : mutate) { - log += std::to_string(i); - log += ", "; - } - log += ") "; - log += TIME; - log += ": "; - log += time; - logger.log(log); - } -} - +// TODO: check that mutated value does not have alias. void CheckpointTensorImpl::mutate(const std::string& name, const mutate_function_t& mutate, const Tensors& inputs, @@ -338,9 +245,9 @@ void CheckpointTensorImpl::mutate(const std::string& name, std::vector constant_idx; std::vector args; for (const Tensor& t: inputs) { - if (auto* cpt = dynamic_cast(t.unsafeGetTensorImpl())) { - input_values.push_back(cpt->ref->value); - args.push_back(cpt->counter_name()); + if (auto* cpti = dynamic_cast(t.unsafeGetTensorImpl())) { + input_values.push_back(cpti->ref->value->value); + args.push_back(cpti->counter_name()); } else { size_t idx = input_values.size() + constants.size(); @@ -348,28 +255,17 @@ void CheckpointTensorImpl::mutate(const std::string& name, constant_idx.push_back(idx); } } - auto ret = make_raw(remat, input_values, constants); - const auto& modified = std::get<0>(ret); + auto ret = make_raw(remat, inputs, input_values, constants); + const auto& modified = ret.outputs; for (size_t idx: mutate_idx) { - cell_from_tensor(inputs[idx])->value = intrusive_ptr::make(modified[idx]); - } - DTRLogMutate(name, args, constant_idx, mutate_idx, from_time(std::get<1>(ret))); -} - -void DTRLogRelease(const std::string& counter_name) { - if (log_json) { - json j; - j[INSTRUCTION] = RELEASE; - j[NAME] = counter_name; - logger.log(j.dump()); - } else { - logger.log(RELEASE + ": " + counter_name); + cell_from_tensor(inputs[idx])->value = modified[idx]; } + DTRLogMutate(name, args, constant_idx, mutate_idx, from_time(ret.time)); } void CheckpointTensorImpl::release_resources() { DTRLogRelease(counter_name()); - ref.reset(); + ref.reset(); } } diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index 5733127646e..6166f4a44e9 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -21,25 +21,81 @@ #include #include +// System Description: +// Every Tensor is managed by a CheckPointTensor, +// that describe how it is computed, (the function and the inputs) +// And might optionally hold the tensor value. +// The tensor value might be dropped, and when requested later, recomputed and cached again. + +// Corner Cases: +// An CheckPointedTensor might be constant. +// In this case it is unevictable. +// An input might be uncheckpointed. +// In this case it is treated as a small constant and omitted from the system - it will be unevictable. +// An operator might return multiple output. +// In this case the computation info (rematerializer) is shared between all of them, +// And when the function get computed again all value get cached. +// An operator might not return value, but only mutate input value. +// To combat this, we COW the operator, and wrap CheckPopintTensor with a Ref. +// By doing this the inner CheckPointTensor is kept purely functional. +// An operator might try to mutate uncheckpointed tensor. +// We do not support this and will error. +// An operator might create aliases. +// We track alias in AliasPool. +// Each AliasPool hold a set of tensor that is alias to eachother. +// An operator might try to create Alias to an unevictable tensor. +// In such a case the output tensor is unevictable. +// An operator might try to mutate Tensor with Alias. +// We do not support this case an will error if a Tensor has any alive Alias. +// However it could be done without a major redesign of the system - +// Each AliasPool will hold weak pointers to the External Reference. +// When alias mutation occur, +// we make a rematerialize_function that take in the base tensor (other tensor alias from) +// and output all the new value of the aliases, then update the Ref. +// Of course, the cleaner way is to not support this. +// Shame on those who use this feature. + +// Memory Safety: +// The objects here will have lots of backedges. +// In order to collect memory when computation is completed, +// We require that all strong pointer is of the form of value -> input. +// This ensure that everything will be released if there is no external ref whatsoever. + +// Optimization: +// We treat tensor that has no external reference differently - +// They will never be externally used again so we assume their next use time is infinite +// so, if it doesnt has any evicted neighbor it will get evicted immediately. + +// Note: to code fast I do not use RAII and just assume the code will not try to recover from exception. +// It should be easy to fix though. + namespace at { -struct CAFFE2_API CheckpointTensorCell : intrusive_ptr_target { - Tensor t; - explicit CheckpointTensorCell(const Tensor& t) : t(t.detach()) { } - size_t memory() { - return t.defined() ? t.numel() * t.itemsize() : 0; +inline size_t memory(const Tensor& t) { + if (! t.has_storage()) { + return 0; } -}; + auto& storage = t.storage(); + return storage.numel() * storage.itemsize(); +} -struct CAFFE2_API CheckpointTensorImplCell : intrusive_ptr_target { - mutable intrusive_ptr value; - explicit CheckpointTensorImplCell(const intrusive_ptr& value) : value(value) { } - explicit CheckpointTensorImplCell(const Tensor& t) : value(intrusive_ptr::make(t)) { } +template +struct RefCell final : intrusive_ptr_target { + mutable T value; void release_resources() final { - value.reset(); + static_release_resources(value); } + RefCell(const T& t) : value(t) { } }; +template +using Ref = intrusive_ptr>; + +template +void static_release_resources(intrusive_ptr& ptr) { + ptr.reset(); +} + class CheckpointTensorCell; using strong = intrusive_ptr; using strongs = std::vector; @@ -52,6 +108,207 @@ using mutate_function_t = std::function; using time_t = std::chrono::time_point; using duration_t = std::chrono::system_clock::duration; +struct Unsafe { }; + +// Track all Tensor that share the same Storage. +// This is the atomic level of eviction - when evicting, everything here will get evicted. +// When an AliasPool is evicted, the Storage of the underlying tensor must be freed. +// Additionally, the AliasPool contain weak pointer to all children of tensors, +// in order to compute the score of evicting a Storage. +struct AliasPool : intrusive_ptr_target { + weaks tensors; + // get() might hold some raw Tensor, rendering them unevictable. + // it is likely that get() will run out of memory, and when it does so, it will try to evict. + // so, it is crucial that we dont try to evict those tensors - doing so will not evict anything. + // lock_count count how many time a tensor is referenced by get. + size_t lock_count; + bool evictable; + size_t memory; + AliasPool(const Unsafe&, bool evictable, size_t memory) : + lock_count(0), evictable(evictable), memory(memory) { + } + void evict(); + void release_resources() final { + tensors.clear(); + } +}; + +// The rematerializer could be called to reinvoke an operator. +// Tensor point to remat which point to Tensor. +// To build the cycle remat support a default constructor, +// And allow you to fill in the member later. +struct Rematerializer : intrusive_ptr_target { + // I am trying to represent a list of either checkpointedtensor or rawtensor. + // Is stitch the best way to do this? + // Maybe another approach is to use a list of tensor, and do dynamic downcasting? + // WHY DONT WE SIMPLY MAKE ALL CONSTANTS CHECKPOINTED TENSORS AS IS IN THE PREVIOUS VERSION? + // Oh, I remember, we are afraid that small tensors will get banished + // and make the big tensors unevictable. + // It sounds like a shitty reason - we can simply have an unbanishable flag + // as we do not rely on weak pointers anymore. + // And if we choose infinite staleness, then there is no need to deal with them specially - + // because they dont have rematerializer it will never get evicted. + // We should probably refactor and fix this, but it will take some nontrivial effort. + strongs input_values; + std::vector> constants; + weaks outputs; + rematerialize_function_t func; + Rematerializer(const Unsafe&, + const strongs& input_values, + const std::vector>& constants, + const rematerialize_function_t& func) : + input_values(input_values), + constants(constants), + func(func) { + } + void release_resources() final { + input_values.clear(); + constants.clear(); + outputs.clear(); + func = rematerialize_function_t(); + } + void remat(); +}; + +struct CAFFE2_API CheckpointTensorCell : intrusive_ptr_target { + std::unique_ptr t; + bool defined = false; + bool is_undefined_tensor; + DispatchKeySet key_set_; + DispatchKeySet key_set() const { + TORCH_CHECK(defined); + return key_set_; + } + caffe2::TypeMeta dtype_; + caffe2::TypeMeta dtype() const { + TORCH_CHECK(defined); + return dtype_; + } + c10::optional optional_device_; + c10::optional optional_device() const { + TORCH_CHECK(defined); + return optional_device_; + } + int64_t dim_, numel_; + size_t itemsize_; + std::vector sizes_, strides_; + // A Tensor is evictable iff it's AliasPool is evictable. + // A evictable tensor must have Rematerializer. + intrusive_ptr pool; + intrusive_ptr remat; + int64_t dim() const { + TORCH_CHECK(defined && !is_undefined_tensor); + return dim_; + } + int64_t numel() const { + TORCH_CHECK(defined && !is_undefined_tensor); + return numel_; + } + IntArrayRef sizes() const { + TORCH_CHECK(defined && !is_undefined_tensor); + return sizes_; + } + int64_t size(int64_t d) const { + TORCH_CHECK(defined && !is_undefined_tensor); + return sizes_[d]; + } + IntArrayRef strides() const { + TORCH_CHECK(defined && !is_undefined_tensor); + return strides_; + } + int64_t stride(int64_t d) const { + TORCH_CHECK(defined && !is_undefined_tensor); + return strides_[d]; + } + void evict() { + t.reset(); + } + void fill(const Tensor& t) { + if (!(this->t)) { + this->t = std::make_unique(t.detach()); + if (!defined) { + defined = true; + is_undefined_tensor = !t.defined(); + key_set_ = t.key_set(); + dtype_ = t.dtype(); + optional_device_ = t.optional_device(); + if (! is_undefined_tensor) { + dim_ = t.dim(); + numel_ = t.numel(); + itemsize_ = t.itemsize(); + sizes_ = t.sizes().vec(); + strides_ = t.strides().vec(); + } + } + } + } + explicit CheckpointTensorCell(const Tensor& t, const intrusive_ptr& pool) : pool(pool) { + fill(t); + } + explicit CheckpointTensorCell(const Tensor& t, + const intrusive_ptr& pool, + const intrusive_ptr& remat) : + pool(pool), remat(remat) { + fill(t); + } + size_t itemsize() { + return itemsize_; + } + size_t memory() { + TORCH_CHECK(defined); + return pool->memory; + } + Tensor get() { + if (! t) { + TORCH_CHECK(remat); + remat->remat(); + } + TORCH_CHECK(t); + TORCH_CHECK(! t->key_set().has(DispatchKey::CheckpointTensorId)) + return *t; + } + void pin() { + pool->evictable = false; + get(); + remat.reset(); + } + void release_resources() final { + t.reset(); + pool.reset(); + remat.reset(); + } +}; + +// CheckpointPool keep a list of AliasPool, and search over them to choose the best one to evict. +struct CheckpointPool { + static CheckpointPool& singleton() { + static CheckpointPool cpp; + return cpp; + } +}; + +// An external reference. +// Each strong will have at most one external reference. +// By keeping such an invariant, whenever an external reference die, +// We know that the underlying strong is only used internally. +// Thus, when it die we can apply optimization like banishing/infinite staleness. +// We keep this invariant by only allowing CheckpointTensorImpl to make new External, +// When new CheckpointTensorImpl is constructed. +struct External : intrusive_ptr_target { + External(const strong& value) : value(value) { } + External(const Tensor& value) : + value(intrusive_ptr::make(value, + intrusive_ptr::make(Unsafe(), + false, + memory(value)))) { } + External(const Tensor& value, + const intrusive_ptr& pool, + const intrusive_ptr& remat) : + value(intrusive_ptr::make(value, pool, remat)) { } + strong value; + void release_resources() override; +}; + inline DispatchKeySet convert_key_set(const DispatchKeySet& t) { CHECK(!t.has(DispatchKey::CheckpointTensorId)); auto ret = t.add(DispatchKey::CheckpointTensorId); @@ -68,12 +325,17 @@ struct CAFFE2_API CheckpointTensorImpl : TensorImpl { std::string counter_name() const { return std::string("x") + std::to_string(id); } - intrusive_ptr ref; + Ref> ref; void release_resources() final; - explicit CheckpointTensorImpl(const intrusive_ptr& ref) : TensorImpl(convert_key_set(ref->value->t.key_set()), - ref->value->t.dtype(), - ref->value->t.optional_device()), ref(ref) { } - explicit CheckpointTensorImpl(const Tensor& t) : CheckpointTensorImpl(intrusive_ptr::make(t)) { } + explicit CheckpointTensorImpl(const Ref>& ref) : + TensorImpl(convert_key_set(ref->value->value->key_set()), + ref->value->value->dtype(), + ref->value->value->optional_device()), + ref(ref) { } + explicit CheckpointTensorImpl(const intrusive_ptr& e) : + CheckpointTensorImpl(Ref>::make(e)) { } + explicit CheckpointTensorImpl(const Tensor& t) : + CheckpointTensorImpl(intrusive_ptr::make(t)) { } static Tensors make(const std::string& name, const rematerialize_function_t& remat, const Tensors& inputs); @@ -86,19 +348,22 @@ struct CAFFE2_API CheckpointTensorImpl : TensorImpl { bool allow_tensor_metadata_change) const override; void shallow_copy_from(const c10::intrusive_ptr& impl) override; int64_t dim() const override { - return ref->value->t.dim(); + return ref->value->value->dim(); } int64_t numel() const override { - return ref->value->t.numel(); + return ref->value->value->numel(); } IntArrayRef sizes() const override { - return ref->value->t.sizes(); + return ref->value->value->sizes(); } int64_t size(int64_t d) const override { - return ref->value->t.size(d); + return ref->value->value->size(d); } IntArrayRef strides() const override { - return ref->value->t.strides(); + return ref->value->value->strides(); + } + int64_t stride(int64_t d) const override { + return ref->value->value->stride(d); } bool has_storage() const override { return false; @@ -111,11 +376,7 @@ inline CheckpointTensorImpl* get_cpti(const Tensor& t) { return cpti; } -inline Tensor get(const strong& s) { - return s->t; -} - -inline intrusive_ptr cell_from_tensor(const Tensor& t) { +inline Ref> cell_from_tensor(const Tensor& t) { return get_cpti(t)->ref; } diff --git a/aten/src/ATen/Logger.h b/aten/src/ATen/Logger.h new file mode 100644 index 00000000000..0868fb38050 --- /dev/null +++ b/aten/src/ATen/Logger.h @@ -0,0 +1,192 @@ +#pragma once + +#include +#include +#include <../../../third_party/json/single_include/nlohmann/json.hpp> + +namespace at { + +struct DTRLogger { + std::string time_prefix; + std::ofstream out; + static std::string get_time_prefix() { + std::time_t t = std::time(nullptr); + std::tm* tm = std::localtime(&t); + return + std::to_string(1900+tm->tm_year) + "-" + + std::to_string(1+tm->tm_mon) + "-" + + std::to_string(tm->tm_mday) + "-" + + std::to_string(tm->tm_hour) + "-" + + std::to_string(tm->tm_min) + "-" + + std::to_string(tm->tm_sec); + } + std::string get_filename(const std::string& name) { + return time_prefix + "-" + name + ".log"; + } + DTRLogger() : time_prefix(get_time_prefix()), out(get_filename("default")) { } + void log(const std::string& str) { + out << str << std::endl; + } + static DTRLogger& logger() { + static DTRLogger ret; + return ret; + } + +}; + +using json = nlohmann::json; +constexpr bool log_json = true; +const std::string INSTRUCTION = "INSTRUCTION"; +const std::string ANNOTATION = "ANNOTATION"; +const std::string RELEASE = "RELEASE"; +const std::string TIME = "TIME"; +const std::string ARGS = "ARGS"; +const std::string MEMORY = "MEMORY"; +const std::string ALIAS = "ALIAS"; +const std::string NAME = "NAME"; +const std::string CONSTANT = "CONSTANT"; +const std::string CONSTANTS = "CONSTANTS"; + +void DTRLogConstant(const std::string& name) { + if (log_json) { + json j; + j[INSTRUCTION] = CONSTANT; + j[NAME] = name; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(CONSTANT + " " + name); + } +} + +void DTRLogMemory(const std::string& name, size_t memory) { + if (log_json) { + json j; + j[INSTRUCTION] = MEMORY; + j[NAME] = name; + j[MEMORY] = std::to_string(memory); + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(name + " " + MEMORY + ": " + std::to_string(memory)); + } +} + +void DTRLogAlias(const std::string& name, int index) { + if (log_json) { + json j; + j[INSTRUCTION] = ALIAS; + j[NAME] = name; + j[ALIAS] = std::to_string(index); + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(name + " " + ALIAS + ": " + std::to_string(index)); + } +} + +void DTRLogCopyFrom(const std::string& to, const std::string& from) { + if (log_json) { + json j; + j[INSTRUCTION] = "COPY_FROM"; + j["DST"] = to; + j["SRC"] = from; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(to + " <- " + from); + } +} + +void DTRLogCopy(const std::string& new_name, const std::string& old_name) { + if (log_json) { + json j; + j[INSTRUCTION] = "COPY"; + j["DST"] = new_name; + j["SRC"] = old_name; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(new_name + " = " + old_name); + } +} + +void DTRLogMutate(const std::string& name, + const std::vector& args, + const std::vector& constants, + const std::vector& mutate, + const std::string& time) { + if (log_json) { + json j; + j[INSTRUCTION] = "MUTATE"; + j[NAME] = name; + j[ARGS] = args; + j[CONSTANTS] = constants; + j["MUTATE"] = mutate; + j[TIME] = time; + DTRLogger::logger().log(j.dump()); + } else { + CHECK(constants.size() == 0); //TODO: implement. + std::string log = name; + log += "("; + for (const auto& s : args) { + log += s; + log += ", "; + } + log += ") "; + log += " MUTATING: "; + log += "("; + for (const size_t i : mutate) { + log += std::to_string(i); + log += ", "; + } + log += ") "; + log += TIME; + log += ": "; + log += time; + DTRLogger::logger().log(log); + } +} + +void DTRLogRelease(const std::string& counter_name) { + if (log_json) { + json j; + j[INSTRUCTION] = RELEASE; + j[NAME] = counter_name; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(RELEASE + ": " + counter_name); + } +} + +void DTRLogCall(const std::vector& res, + const std::string& name, + const std::vector& args, + const std::vector& constants, + const std::string& time) { + if (log_json) { + json j; + j[INSTRUCTION] = "CALL"; + j[NAME] = name; + j["RESULT"] = res; + j[ARGS] = args; + j[CONSTANTS] = constants; + j[TIME] = time; + DTRLogger::logger().log(j.dump()); + } else { + CHECK(constants.size() == 0); //TODO: implement. + std::string arg = name + "("; + for (const auto& s : args) { + arg += s; + arg += ", "; + } + arg += ")"; + std::string log = "("; + for (const auto& s: res) { + log += s; + log += ", "; + } + log += ") = "; + log += arg; + log += " TIME: "; + log += time; + DTRLogger::logger().log(log); + } +} + +} diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp index a5e64344946..b31c46ec8fd 100644 --- a/aten/src/ATen/native/Activation.cpp +++ b/aten/src/ATen/native/Activation.cpp @@ -724,4 +724,82 @@ Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, return grad_input; } +std::vector split(const Tensor& self, int64_t split_size, int64_t dim) { + TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor"); + TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size); + int64_t dim_size = self.size(dim); + TORCH_CHECK(split_size > 0 || self.size(dim) == 0, + "split_size can only be 0 if dimension size is 0, " + "but got dimension size of ", dim_size); + // if split_size is 0 and dimension size is 0, there is 1 split. + int64_t num_splits = 1; + if (split_size != 0) { + // ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size + // (returns a single split). We might want to error here, but keep it for BC. + num_splits = std::max((dim_size + split_size - 1) / split_size, 1); + } + std::vector splits(num_splits); + int64_t last_split_size = split_size - (split_size * num_splits - dim_size); + + for (int64_t i = 0; i < num_splits; ++i) { + auto length = i < num_splits - 1 ? split_size : last_split_size; + splits[i] = self.narrow(dim, i * split_size, length); + } + return splits; +} + +std::vector split_with_sizes(const Tensor& self, IntArrayRef split_sizes, int64_t dim) { + TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor"); + int64_t dim_size = self.size(dim); + int64_t num_splits = split_sizes.size(); + std::vector splits(num_splits); + int64_t start_idx = 0; + int64_t i; + + for (i = 0; i < num_splits; ++i) { + auto length = split_sizes[i]; + TORCH_CHECK(length >= 0, + "split_with_sizes expects split_sizes have only non-negative ", + "entries, but got split_sizes=", split_sizes); + splits[i] = self.narrow(dim, start_idx, length); + start_idx += length; + } + TORCH_CHECK(start_idx == dim_size, + "split_with_sizes expects split_sizes to sum exactly to ", dim_size, + " (input tensor's size at dimension ", dim, "), ", "but got split_sizes=", split_sizes); + return splits; +} + +Tensor split_backward(c10::ArrayRef grads, + int64_t split_size, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { + dim = at::maybe_wrap_dim(dim, sizes.size()); + int64_t dim_size = sizes[dim]; + int64_t num_splits = grads.size(); + std::vector split_sizes(num_splits, split_size); + split_sizes[num_splits - 1] = split_size - (split_size * num_splits - dim_size); + return at::native::split_with_sizes_backward(grads, split_sizes, dim, sizes, options); +} + +Tensor split_with_sizes_backward(c10::ArrayRef grads, + IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { + dim = at::maybe_wrap_dim(dim, sizes.size()); + + // it's possible some of the grads are not defined (represents tensors of all 0s). + // Since at::cat can't handle those, let's define them + std::vector grads_all_defined(grads.size()); + for (size_t j = 0; j < grads.size(); ++j) { + if (grads[j].defined()) { + grads_all_defined[j] = grads[j]; + } else { + auto length = split_sizes[j]; + auto grad_size = sizes.vec(); + grad_size[dim] = length; + grads_all_defined[j] = at::zeros(grad_size, options); + } + } + + auto ret = at::cat(grads_all_defined, dim); + return ret; +} + }} // namespace at::native diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index 6d8f0331601..8615f7f0ccf 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -902,7 +902,6 @@ Tensor& checkpoint_sub_(at::Tensor& a, at::Tensor const& b, c10::Scalar c) { return a; } - Tensor checkpoint_repeat(const at::Tensor& a, c10::ArrayRef b) { std::vector b_ = b.vec(); rematerialize_function_t rt = @@ -1536,4 +1535,50 @@ Scalar checkpoint__local_scalar_dense(at::Tensor const& a) { return at::_local_scalar_dense(decheckpoint(a)); } +Tensor checkpoint_split_with_sizes_backward(c10::ArrayRef a, c10::ArrayRef b, long c, c10::ArrayRef d, c10::TensorOptions const& e) { + std::vector a_ = a.vec(); + std::vector d_ = d.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::split_with_sizes_backward(vec, b, c, d_, e)}; + }; + return CheckpointTensorImpl::make("split_with_sizes_backward", rt, a_)[0]; +} + +std::vector checkpoint_split_with_sizes(at::Tensor const& a, c10::ArrayRef b, long c) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return at::split_with_sizes(vec.at(0), b_, c); + }; + return CheckpointTensorImpl::make("split_with_sizes", rt, {a}); +} + +Tensor checkpoint_split_backward(c10::ArrayRef a, long b, long c, c10::ArrayRef d, const c10::TensorOptions& e) { + std::vector a_ = a.vec(); + std::vector d_ = d.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::split_backward(vec, b, c, d_, e)}; + }; + return CheckpointTensorImpl::make("split_backward", rt, a_)[0]; +} + +std::vector checkpoint_split(const at::Tensor& a, long b, long c) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return at::split(vec.at(0), b, c); + }; + return CheckpointTensorImpl::make("split", rt, {a}); +} + +Tensor checkpoint_expand(at::Tensor const& a, c10::ArrayRef b, bool c) { + std::vector b_ = b.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {vec.at(0).expand(b_, c)}; + }; + return CheckpointTensorImpl::make("expand", rt, {a})[0]; +} + }} diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 14e3a877634..49102d33233 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -888,52 +888,6 @@ Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_ return result; } -std::vector split(const Tensor& self, int64_t split_size, int64_t dim) { - TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor"); - TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size); - int64_t dim_size = self.size(dim); - TORCH_CHECK(split_size > 0 || self.size(dim) == 0, - "split_size can only be 0 if dimension size is 0, " - "but got dimension size of ", dim_size); - // if split_size is 0 and dimension size is 0, there is 1 split. - int64_t num_splits = 1; - if (split_size != 0) { - // ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size - // (returns a single split). We might want to error here, but keep it for BC. - num_splits = std::max((dim_size + split_size - 1) / split_size, 1); - } - std::vector splits(num_splits); - int64_t last_split_size = split_size - (split_size * num_splits - dim_size); - - for (int64_t i = 0; i < num_splits; ++i) { - auto length = i < num_splits - 1 ? split_size : last_split_size; - splits[i] = self.narrow(dim, i * split_size, length); - } - return splits; -} - -std::vector split_with_sizes(const Tensor& self, IntArrayRef split_sizes, int64_t dim) { - TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor"); - int64_t dim_size = self.size(dim); - int64_t num_splits = split_sizes.size(); - std::vector splits(num_splits); - int64_t start_idx = 0; - int64_t i; - - for (i = 0; i < num_splits; ++i) { - auto length = split_sizes[i]; - TORCH_CHECK(length >= 0, - "split_with_sizes expects split_sizes have only non-negative ", - "entries, but got split_sizes=", split_sizes); - splits[i] = self.narrow(dim, start_idx, length); - start_idx += length; - } - TORCH_CHECK(start_idx == dim_size, - "split_with_sizes expects split_sizes to sum exactly to ", dim_size, - " (input tensor's size at dimension ", dim, "), ", "but got split_sizes=", split_sizes); - return splits; -} - // Precondition: tensors is non-empty static inline std::vector get_stack_inputs(TensorList tensors, int64_t dim) { std::vector inputs(tensors.size()); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index bc33f2a6ba4..667250ce86e 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -15,12 +15,18 @@ - func: decheckpoint(Tensor self) -> Tensor variants: method +- func: pin(Tensor(a!) self) -> () + variants: method + - func: new_log(str logname) -> () variants: function - func: annotate_log(str logname) -> () variants: function +- func: clear_checkpointpool() -> () + variants: function + # Temporary type cast operators. These are needed to trace type-casts now since # Type's are not supported in the IR. Instead, we call down to these # specialized operators for each datatype. @@ -1326,6 +1332,10 @@ variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. device_guard: False supports_named_tensor: True + dispatch: + CPU: expand + CUDA: expand + Checkpoint: checkpoint_expand - func: expand_as(Tensor self, Tensor other) -> Tensor use_c10_dispatcher: full @@ -2713,11 +2723,33 @@ variants: function, method device_guard: False supports_named_tensor: True + dispatch: + CPU: split + CUDA: split + Checkpoint: checkpoint_split + +- func: split_backward(Tensor[] grads, int split_size, int dim, int[] sizes, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor + variants: function + dispatch: + CPU: split_backward + CUDA: split_backward + Checkpoint: checkpoint_split_backward - func: split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[] variants: function, method device_guard: False supports_named_tensor: True + dispatch: + CPU: split_with_sizes + CUDA: split_with_sizes + Checkpoint: checkpoint_split_with_sizes + +- func: split_with_sizes_backward(Tensor[] grads, int[] split_sizes, int dim, int[] sizes, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor + variants: function + dispatch: + CPU: split_with_sizes_backward + CUDA: split_with_sizes_backward + Checkpoint: checkpoint_split_with_sizes_backward - func: squeeze(Tensor(a) self) -> Tensor(a) supports_named_tensor: True diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 8c272c2ce5f..f8f7fb421b7 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -22,6 +22,15 @@ namespace c10 { C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); +evict_func_t evict_func = nullptr; +void set_evict_func(evict_func_t ef) { + evict_func = ef; +} + +evict_func_t get_evict_func() { + return evict_func; +} + namespace cuda { namespace CUDACachingAllocator { @@ -201,10 +210,21 @@ class THCCachingAllocator { // Thus, do not call a public method from another public method. /** allocates a block which is safe to use from the provided stream */ + // Technically speaking, it is still allocating more memory then it should, + // But it doesn't do anything with it until more memory are found, so it is morally ok - no experimental result will be changed. + // TODO: fix it and make it to be always below limit, so ppl can set the limit to be GPU max memory and it will still work. void malloc(void** devPtr, size_t size, cudaStream_t stream) { - std::lock_guard lock(mutex); + malloc_inner(devPtr, size, stream); + auto evict_func = get_evict_func(); + if (evict_func) { + (*evict_func)(); + } + } + void malloc_inner(void** devPtr, size_t size, cudaStream_t stream) + { + std::lock_guard lock(mutex); int device; C10_CUDA_CHECK(cudaGetDevice(&device)); diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index fecd12179e2..e119f259b96 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -27,6 +27,10 @@ C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); #define REGISTER_FREE_MEMORY_CALLBACK(name, ...) \ C10_REGISTER_CLASS(FreeCudaMemoryCallbacksRegistry, name, __VA_ARGS__); +typedef void (*evict_func_t)(); +C10_CUDA_API void set_evict_func(evict_func_t); +C10_CUDA_API evict_func_t get_evict_func(); + namespace cuda { // TODO: Turn this into an honest to goodness class. I briefly attempted to do diff --git a/test.py b/test.py index cdea37872be..9a6ee104db0 100644 --- a/test.py +++ b/test.py @@ -1,6 +1,4 @@ import torch -torch.annotate_log("hello") x = torch.Tensor([1]).checkpoint() -y = torch.Tensor([2]).checkpoint() -z = x + y -x.data = z +y = x +z = y diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp index 03b6fc5b4cc..f8a3a8351b3 100644 --- a/tools/autograd/templates/Functions.cpp +++ b/tools/autograd/templates/Functions.cpp @@ -790,38 +790,6 @@ Tensor cholesky_inverse_backward(Tensor grad, Tensor L, bool upper, Tensor inver return grad_L; } -Tensor split_with_sizes_backward(const std::vector &grads, - IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { - dim = at::maybe_wrap_dim(dim, sizes.size()); - - // it's possible some of the grads are not defined (represents tensors of all 0s). - // Since at::cat can't handle those, let's define them - std::vector grads_all_defined(grads.size()); - for (size_t j = 0; j < grads.size(); ++j) { - if (grads[j].defined()) { - grads_all_defined[j] = grads[j]; - } else { - auto length = split_sizes[j]; - auto grad_size = sizes.vec(); - grad_size[dim] = length; - grads_all_defined[j] = at::zeros(grad_size, options); - } - } - - auto ret = at::cat(grads_all_defined, dim); - return ret; -} - -Tensor split_backward(const std::vector &grads, - int64_t split_size, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options) { - dim = at::maybe_wrap_dim(dim, sizes.size()); - int64_t dim_size = sizes[dim]; - int64_t num_splits = grads.size(); - std::vector split_sizes(num_splits, split_size); - split_sizes[num_splits - 1] = split_size - (split_size * num_splits - dim_size); - return split_with_sizes_backward(grads, split_sizes, dim, sizes, options); -} - Tensor max_pool_double_backward(const Tensor & grad, const Tensor & indices, int dim) { AT_ASSERT(indices.dim() >= dim); auto size = indices.sizes().slice(0, indices.dim() - dim).vec(); From 73f681c09c9555b9d2367840053545b4e5ddb8f6 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Sat, 23 May 2020 08:31:33 -0700 Subject: [PATCH 21/42] refactor - remove stitch --- aten/src/ATen/CheckpointTensorImpl.cpp | 172 ++++++++++++------------- aten/src/ATen/CheckpointTensorImpl.h | 30 +---- aten/src/ATen/Logger.h | 7 - 3 files changed, 86 insertions(+), 123 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index d7f1e1ba9dd..1a593e9e252 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -4,58 +4,6 @@ namespace at { -void AliasPool::evict() { - TORCH_CHECK(lock_count == 0); - for (const weak& w : tensors) { - if (auto cell = w.lock()) { - cell->evict(); - } - } -} - -void External::release_resources() { - value->evict(); - value.reset(); -} - -Tensors stitch(const strongs& input_values, - const std::vector>& constants) { - Tensors input; - size_t i = 0, j = 0; - while (i != input_values.size() || j != constants.size()) { - if (j < constants.size() && std::get<1>(constants[j]) == input.size()) { - Tensor t = std::get<0>(constants[j]); - TORCH_CHECK(!t.key_set().has(DispatchKey::CheckpointTensorId)); - input.push_back(t); - ++j; - } - else { - CHECK(i < input_values.size()); - input.push_back(input_values[i]->get()); - ++i; - } - } - return input; -} - -void Rematerializer::remat() { - // TODO: refactor using RAII for exception safety. - for (const strong& s : input_values) { - ++(s->pool->lock_count); - } - Tensors ts = stitch(input_values, constants); - auto ret = func(ts); - TORCH_CHECK(ret.size() == outputs.size()); - for (size_t i = 0; i < outputs.size(); ++i) { - if (auto output_cell = outputs[i].lock()) { - output_cell->fill(ret[i]); - } - } - for (const strong& s : input_values) { - --(s->pool->lock_count); - } -} - namespace native { Tensor checkpoint(const Tensor& t) { @@ -87,6 +35,10 @@ bool is_checkpoint(const Tensor& t) { return cpti != nullptr; } +Tensor try_checkpoint(const Tensor& t) { + return is_checkpoint(t) ? t : checkpoint(t); +} + void new_log(std::string str) { DTRLogger::logger().out = std::ofstream(DTRLogger::logger().get_filename(str)); } @@ -108,6 +60,58 @@ void clear_checkpointpool() { } +Tensor uncheckpoint(const strong& input) { + return input->get(); +} + +Tensors uncheckpoint(const strongs& inputs) { + Tensors ret; + for (const strong& input : inputs) { + ret.push_back(uncheckpoint(input)); + } + return ret; +}; + +Tensors try_checkpoint(const Tensors& inputs) { + Tensors ret; + for (const Tensor& input : inputs) { + ret.push_back(at::native::try_checkpoint(input)); + } + return ret; +} + +void AliasPool::evict() { + TORCH_CHECK(lock_count == 0); + for (const weak& w : tensors) { + if (auto cell = w.lock()) { + cell->evict(); + } + } +} + +void External::release_resources() { + value->evict(); + value.reset(); +} + +void Rematerializer::remat() { + // TODO: refactor using RAII for exception safety. + for (const strong& s : inputs) { + ++(s->pool->lock_count); + } + Tensors ts = uncheckpoint(inputs); + auto ret = func(ts); + TORCH_CHECK(ret.size() == outputs.size()); + for (size_t i = 0; i < outputs.size(); ++i) { + if (auto output_cell = outputs[i].lock()) { + output_cell->fill(ret[i]); + } + } + for (const strong& s : inputs) { + --(s->pool->lock_count); + } +} + intrusive_ptr CheckpointTensorImpl::shallow_copy_and_detach(const VariableVersion& version_counter, bool allow_tensor_metadata_change) const { auto ret = intrusive_ptr::make(ref); @@ -153,29 +157,23 @@ struct MakeRawResult { // however, we have to stitch the two vectors together to pass it in remat. // the size_t in constants decide the location to stitch them in, while input_values fill in the rest. MakeRawResult make_raw(const rematerialize_function_t& remat_f, - // We need this to assign alias pool. - // This is ugly as fuck but after refactoring we dont even need stitching anymore. - const Tensors& raw_input, - const strongs& input_values, - const std::vector>& constants) { - Tensors inputs = stitch(input_values, constants); + const strongs& inputs) { + Tensors raw_inputs = uncheckpoint(inputs); time_t pre = std::chrono::system_clock::now(); - auto outputs_raw = remat_f(inputs); + auto outputs_raw = remat_f(raw_inputs); time_t post = std::chrono::system_clock::now(); std::vector> outputs; std::vector aliases; weaks weak_outputs; - auto remat = intrusive_ptr::make(Unsafe(), input_values, constants, remat_f); + auto remat = intrusive_ptr::make(Unsafe(), remat_f, inputs); for (const Tensor& t : outputs_raw) { - int alias = get_alias(inputs, t); + int alias = get_alias(raw_inputs, t); intrusive_ptr pool; if (alias == -1) { pool = intrusive_ptr::make(Unsafe(), true, memory(t)); } - else if (auto* cpti = dynamic_cast(raw_input[alias].unsafeGetTensorImpl())) { - pool = cpti->ref->value->value->pool; - } else { // alias to an constant. unevictable. - pool = intrusive_ptr::make(Unsafe(), false, memory(t)); + else { + pool = inputs[alias]->pool; } auto e = intrusive_ptr::make(t, pool, remat); pool->tensors.push_back(weak(e->value)); @@ -194,30 +192,24 @@ std::string from_time(duration_t t) { Tensors CheckpointTensorImpl::make(const std::string& name, const rematerialize_function_t& remat, const Tensors& inputs) { + Tensors checkpointed_inputs = try_checkpoint(inputs); strongs input_values; - std::vector> constants; - std::vector constant_idx; std::vector args; - for (const Tensor& t: inputs) { - if (auto* cpti = dynamic_cast(t.unsafeGetTensorImpl())) { - input_values.push_back(cpti->ref->value->value); - args.push_back(cpti->counter_name()); - } - else { - size_t idx = input_values.size() + constants.size(); - constants.push_back({t, idx}); - constant_idx.push_back(idx); - } + for (const Tensor& t: checkpointed_inputs) { + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + TORCH_CHECK(cpti); + input_values.push_back(cpti->ref->value->value); + args.push_back(cpti->counter_name()); } std::vector res; - auto ret = make_raw(remat, inputs, input_values, constants); + auto ret = make_raw(remat, input_values); Tensors tensors; for (const auto& t: ret.outputs) { auto cp = Tensor(intrusive_ptr::make(t)); tensors.push_back(cp); res.push_back(get_cpti(cp)->counter_name()); } - DTRLogCall(res, name, args, constant_idx, from_time(ret.time)); + DTRLogCall(res, name, args, from_time(ret.time)); for (size_t i = 0; i < tensors.size(); ++i) { Tensor t = tensors[i]; auto cpti = get_cpti(t); @@ -240,27 +232,21 @@ void CheckpointTensorImpl::mutate(const std::string& name, mutate(new_input_values); return new_input_values; }; + Tensors checkpointed_inputs = try_checkpoint(inputs); strongs input_values; - std::vector> constants; - std::vector constant_idx; std::vector args; - for (const Tensor& t: inputs) { - if (auto* cpti = dynamic_cast(t.unsafeGetTensorImpl())) { - input_values.push_back(cpti->ref->value->value); - args.push_back(cpti->counter_name()); - } - else { - size_t idx = input_values.size() + constants.size(); - constants.push_back({t, idx}); - constant_idx.push_back(idx); - } + for (const Tensor& t: checkpointed_inputs) { + auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); + TORCH_CHECK(cpti); + input_values.push_back(cpti->ref->value->value); + args.push_back(cpti->counter_name()); } - auto ret = make_raw(remat, inputs, input_values, constants); + auto ret = make_raw(remat, input_values); const auto& modified = ret.outputs; for (size_t idx: mutate_idx) { cell_from_tensor(inputs[idx])->value = modified[idx]; } - DTRLogMutate(name, args, constant_idx, mutate_idx, from_time(ret.time)); + DTRLogMutate(name, args, mutate_idx, from_time(ret.time)); } void CheckpointTensorImpl::release_resources() { diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index 6166f4a44e9..7aa9530f124 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -138,34 +138,18 @@ struct AliasPool : intrusive_ptr_target { // To build the cycle remat support a default constructor, // And allow you to fill in the member later. struct Rematerializer : intrusive_ptr_target { - // I am trying to represent a list of either checkpointedtensor or rawtensor. - // Is stitch the best way to do this? - // Maybe another approach is to use a list of tensor, and do dynamic downcasting? - // WHY DONT WE SIMPLY MAKE ALL CONSTANTS CHECKPOINTED TENSORS AS IS IN THE PREVIOUS VERSION? - // Oh, I remember, we are afraid that small tensors will get banished - // and make the big tensors unevictable. - // It sounds like a shitty reason - we can simply have an unbanishable flag - // as we do not rely on weak pointers anymore. - // And if we choose infinite staleness, then there is no need to deal with them specially - - // because they dont have rematerializer it will never get evicted. - // We should probably refactor and fix this, but it will take some nontrivial effort. - strongs input_values; - std::vector> constants; - weaks outputs; rematerialize_function_t func; + strongs inputs; + weaks outputs; Rematerializer(const Unsafe&, - const strongs& input_values, - const std::vector>& constants, - const rematerialize_function_t& func) : - input_values(input_values), - constants(constants), - func(func) { + const rematerialize_function_t& func, + const strongs& inputs) : + func(func), inputs(inputs) { } void release_resources() final { - input_values.clear(); - constants.clear(); - outputs.clear(); func = rematerialize_function_t(); + inputs.clear(); + outputs.clear(); } void remat(); }; diff --git a/aten/src/ATen/Logger.h b/aten/src/ATen/Logger.h index 0868fb38050..190c2f28433 100644 --- a/aten/src/ATen/Logger.h +++ b/aten/src/ATen/Logger.h @@ -45,7 +45,6 @@ const std::string MEMORY = "MEMORY"; const std::string ALIAS = "ALIAS"; const std::string NAME = "NAME"; const std::string CONSTANT = "CONSTANT"; -const std::string CONSTANTS = "CONSTANTS"; void DTRLogConstant(const std::string& name) { if (log_json) { @@ -108,7 +107,6 @@ void DTRLogCopy(const std::string& new_name, const std::string& old_name) { void DTRLogMutate(const std::string& name, const std::vector& args, - const std::vector& constants, const std::vector& mutate, const std::string& time) { if (log_json) { @@ -116,12 +114,10 @@ void DTRLogMutate(const std::string& name, j[INSTRUCTION] = "MUTATE"; j[NAME] = name; j[ARGS] = args; - j[CONSTANTS] = constants; j["MUTATE"] = mutate; j[TIME] = time; DTRLogger::logger().log(j.dump()); } else { - CHECK(constants.size() == 0); //TODO: implement. std::string log = name; log += "("; for (const auto& s : args) { @@ -157,7 +153,6 @@ void DTRLogRelease(const std::string& counter_name) { void DTRLogCall(const std::vector& res, const std::string& name, const std::vector& args, - const std::vector& constants, const std::string& time) { if (log_json) { json j; @@ -165,11 +160,9 @@ void DTRLogCall(const std::vector& res, j[NAME] = name; j["RESULT"] = res; j[ARGS] = args; - j[CONSTANTS] = constants; j[TIME] = time; DTRLogger::logger().log(j.dump()); } else { - CHECK(constants.size() == 0); //TODO: implement. std::string arg = name + "("; for (const auto& s : args) { arg += s; From ec0fe999ca39f6279e15ba8492a24798a9195fa0 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Sat, 23 May 2020 09:54:46 -0700 Subject: [PATCH 22/42] save --- aten/src/ATen/native/native_functions.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 667250ce86e..87cd7cb7165 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3,6 +3,9 @@ - func: checkpoint(Tensor self) -> Tensor variants: method +- func: try_checkpoint(Tensor self) -> Tensor + variants: method + - func: is_checkpoint(Tensor self) -> bool variants: method From 41f694b79fafaced45b82be151ddb9e6aca9889a Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Sun, 24 May 2020 05:05:54 -0700 Subject: [PATCH 23/42] save --- aten/src/ATen/CheckpointTensorImpl.cpp | 51 ++++++++++++++++------ aten/src/ATen/native/native_functions.yaml | 7 +++ 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 1a593e9e252..c21dbb893dd 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -4,12 +4,16 @@ namespace at { +bool use_log = true; + namespace native { Tensor checkpoint(const Tensor& t) { auto cpti = intrusive_ptr::make(t.detach()); - DTRLogConstant(cpti->counter_name()); - DTRLogMemory(cpti->counter_name(), cpti->ref->value->value->memory()); + if (use_log) { + DTRLogConstant(cpti->counter_name()); + DTRLogMemory(cpti->counter_name(), cpti->ref->value->value->memory()); + } return Tensor(cpti); } @@ -44,6 +48,7 @@ void new_log(std::string str) { } void annotate_log(std::string str) { + if (!use_log) { return; } if (log_json) { json j; j[INSTRUCTION] = "ANNOTATE"; @@ -54,6 +59,10 @@ void annotate_log(std::string str) { } } +void toggle_log(bool b) { + use_log = b; +} + void clear_checkpointpool() { // not implemented yet. } @@ -115,7 +124,9 @@ void Rematerializer::remat() { intrusive_ptr CheckpointTensorImpl::shallow_copy_and_detach(const VariableVersion& version_counter, bool allow_tensor_metadata_change) const { auto ret = intrusive_ptr::make(ref); - DTRLogCopy(ret->counter_name(), counter_name()); + if (use_log) { + DTRLogCopy(ret->counter_name(), counter_name()); + } return ret; } @@ -124,7 +135,9 @@ void CheckpointTensorImpl::shallow_copy_from(const c10::intrusive_ptr(impl.get()); TORCH_CHECK(cpti != nullptr); ref->value = cpti->ref->value; - DTRLogCopyFrom(counter_name(), cpti->counter_name()); + if (use_log) { + DTRLogCopyFrom(counter_name(), cpti->counter_name()); + } } int CheckpointTensorImpl::counter = 0; @@ -199,7 +212,9 @@ Tensors CheckpointTensorImpl::make(const std::string& name, auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); TORCH_CHECK(cpti); input_values.push_back(cpti->ref->value->value); - args.push_back(cpti->counter_name()); + if (use_log) { + args.push_back(cpti->counter_name()); + } } std::vector res; auto ret = make_raw(remat, input_values); @@ -209,12 +224,14 @@ Tensors CheckpointTensorImpl::make(const std::string& name, tensors.push_back(cp); res.push_back(get_cpti(cp)->counter_name()); } - DTRLogCall(res, name, args, from_time(ret.time)); - for (size_t i = 0; i < tensors.size(); ++i) { - Tensor t = tensors[i]; - auto cpti = get_cpti(t); - DTRLogMemory(cpti->counter_name(), cpti->ref->value->value->memory()); - DTRLogAlias(cpti->counter_name(), ret.aliases[i]); + if (use_log) { + DTRLogCall(res, name, args, from_time(ret.time)); + for (size_t i = 0; i < tensors.size(); ++i) { + Tensor t = tensors[i]; + auto cpti = get_cpti(t); + DTRLogMemory(cpti->counter_name(), cpti->ref->value->value->memory()); + DTRLogAlias(cpti->counter_name(), ret.aliases[i]); + } } return tensors; } @@ -239,18 +256,24 @@ void CheckpointTensorImpl::mutate(const std::string& name, auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); TORCH_CHECK(cpti); input_values.push_back(cpti->ref->value->value); - args.push_back(cpti->counter_name()); + if (use_log) { + args.push_back(cpti->counter_name()); + } } auto ret = make_raw(remat, input_values); const auto& modified = ret.outputs; for (size_t idx: mutate_idx) { cell_from_tensor(inputs[idx])->value = modified[idx]; } - DTRLogMutate(name, args, mutate_idx, from_time(ret.time)); + if (use_log) { + DTRLogMutate(name, args, mutate_idx, from_time(ret.time)); + } } void CheckpointTensorImpl::release_resources() { - DTRLogRelease(counter_name()); + if (use_log) { + DTRLogRelease(counter_name()); + } ref.reset(); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 87cd7cb7165..8c700198ea9 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3,6 +3,10 @@ - func: checkpoint(Tensor self) -> Tensor variants: method +# convert a tensor to a checkpoint tensor. +# if the input is already a checkpoint tensor, +# checkpoint will fail while try_checkpoint will +# simply return the input. - func: try_checkpoint(Tensor self) -> Tensor variants: method @@ -27,6 +31,9 @@ - func: annotate_log(str logname) -> () variants: function +- func: toggle_log(bool use_log) -> () + variants: function + - func: clear_checkpointpool() -> () variants: function From f17f292b8537f1b2f7d8acc65229d853bed18fd6 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 27 May 2020 23:20:23 -0700 Subject: [PATCH 24/42] [ impl ] overload diag & mv --- aten/src/ATen/native/Checkpoint.cpp | 38 +++++++++++++++++++++- aten/src/ATen/native/native_functions.yaml | 4 +++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index 8615f7f0ccf..61a21569534 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -1581,4 +1581,40 @@ Tensor checkpoint_expand(at::Tensor const& a, c10::ArrayRef b, bool c) { return CheckpointTensorImpl::make("expand", rt, {a})[0]; } -}} +Tensor checkpoint_diag(at::Tensor const& self, long diagonal) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::diag(vec.at(0), diagonal)}; + }; + return CheckpointTensorImpl::make("diag", rt, {self})[0]; +} + +Tensor& checkpoint_diag_out(at::Tensor& out, const Tensor& self, long diagonal) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::diag_out(out_, vec.at(1), diagonal); + }; + CheckpointTensorImpl::mutate("diag_out", mt, {out, self}, {0}); + return {out}; +} + +Tensor checkpoint_mv(at::Tensor const& self, at::Tensor const& vec) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::mv(vec.at(0), vec.at(1))}; + }; + return CheckpointTensorImpl::make("mv", rt, {self, vec})[0]; +} + +Tensor& checkpoint_mv_out(at::Tensor& out, const Tensor& self, const Tensor& vec) { + mutate_function_t mt = + [=](const Tensors& vec) { + Tensor out_ = vec.at(0); + at::mv_out(out_, vec.at(1), vec.at(2)); + }; + CheckpointTensorImpl::mutate("mv_out", mt, {out, self, vec}, {0}); + return {out}; +} + +}} \ No newline at end of file diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 8c700198ea9..ff2970a5126 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2158,12 +2158,14 @@ dispatch: CPU: mv_cpu CUDA: legacy::cuda::_th_mv + Checkpoint: checkpoint_mv supports_named_tensor: True - func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: mv_cpu_out CUDA: legacy::cuda::_th_mv_out + Checkpoint: checkpoint_mv_out supports_named_tensor: True - func: mvlgamma(Tensor self, int p) -> Tensor @@ -4613,6 +4615,7 @@ dispatch: CPU: legacy::cpu::_th_diag_out CUDA: legacy::cuda::_th_diag_out + Checkpoint: checkpoint_diag_out - func: diag(Tensor self, int diagonal=0) -> Tensor use_c10_dispatcher: full @@ -4620,6 +4623,7 @@ dispatch: CPU: legacy::cpu::_th_diag CUDA: legacy::cuda::_th_diag + Checkpoint: checkpoint_diag - func: cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) From 7c5f36dfad76bc6a7a0cb34f262cae41059f342e Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Sat, 23 May 2020 08:31:33 -0700 Subject: [PATCH 25/42] refactor - remove stitch save save restore equivalentclassnode save save save 50% resnet here we go 50% resnet here we go save save save --- aten/src/ATen/CheckpointTensorImpl.cpp | 263 ++++++++++++++++++++- aten/src/ATen/CheckpointTensorImpl.h | 222 ++++++++++++----- aten/src/ATen/native/native_functions.yaml | 6 + 3 files changed, 420 insertions(+), 71 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index c21dbb893dd..8d0424e1354 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -4,6 +4,71 @@ namespace at { +CheckpointPool pool; + +long current_memory() { + auto device_stat = c10::cuda::CUDACachingAllocator::getDeviceStats(0); + return device_stat.allocated_bytes[0].current; +} + +void checkpoint_auto_evict() { + pool.auto_evict(); +} + +void CheckpointPool::auto_evict() { + if (has_memory_budget) { + while (current_memory() > memory_budget) { + evict(); + } + } +} + +void CheckpointPool::evict() { + TORCH_CHECK(aps.size() > 0); + bool shrinked = false; + int evict_idx = -1; + double evict_score = INFINITY; + time_t current_time = std::chrono::system_clock::now(); + auto remove_from_aps = [&](size_t i) { + aps[i] = aps[aps.size() - 1]; + aps.pop_back(); + }; + for (size_t i = 0; i < aps.size();) { + auto cannot_evict = [&]() { + shrinked = true; + remove_from_aps(i); + }; + auto ap_strong = aps[i].lock(); + if (!ap_strong.defined()) { + cannot_evict(); + } else { + if (ap_strong->evictable()) { + double score = ap_strong->score(current_time); + if (score < evict_score) { + evict_score = score; + evict_idx = i; + } + } + ++i; + } + } + if (evict_idx == -1) { + TORCH_CHECK(shrinked); + } else { + auto evict_from_idx = [&](size_t idx) { + auto ap_strong = aps[idx].lock(); + TORCH_CHECK(ap_strong.defined()); + ap_strong->evict(); + remove_from_aps(evict_idx); + }; + evict_from_idx(evict_idx); + } +} + +CheckpointPool::CheckpointPool() { + c10::set_evict_func(checkpoint_auto_evict); +} + bool use_log = true; namespace native { @@ -67,6 +132,160 @@ void clear_checkpointpool() { // not implemented yet. } +void unset_memory_budget() { + pool.has_memory_budget = false; +} + +void set_memory_budget(long budget) { + pool.memory_budget = budget; + pool.has_memory_budget = true; +} + +} + +Tensor uncheckpoint(const strong& input) { + return input->get(); +} + +Tensors uncheckpoint(const strongs& inputs) { + Tensors ret; + for (const strong& input : inputs) { + ret.push_back(uncheckpoint(input)); + } + return ret; +}; + +Tensors try_checkpoint(const Tensors& inputs) { + Tensors ret; + for (const Tensor& input : inputs) { + ret.push_back(at::native::try_checkpoint(input)); + } + return ret; +} + +CheckpointInfo merge_cpi(CheckpointInfo l, CheckpointInfo r) { + return CheckpointInfo(l.compute_cost + r.compute_cost, + std::max(l.last_used_time, r.last_used_time)); +} + +void AliasPool::evict() { + TORCH_CHECK(!ecn); + ecn = head_remat->get_ecn(last_used_time); + auto ecns = neighbor_ecn(); + for (const auto& necn : ecns) { + merge(merge_cpi, ecn, necn); + } + auto b4 = current_memory(); + TORCH_CHECK(memory > 0); + TORCH_CHECK(lock_count == 0); + TORCH_CHECK(!is_evicted); + is_evicted = true; + for (const weak& w : tensors) { + if (auto cell = w.lock()) { + cell->evict(); + } + } + // TORCH_CHECK(current_memory() < b4); + // somehow it is still evicting unevictable stuff. +} + +double AliasPool::score(time_t current_time) { + auto cpi = head_remat->get_cpi(last_used_time); + auto ecns = neighbor_ecn(); + for (const auto& necn : ecns) { + cpi = merge_cpi(cpi, get_t(necn)); + } + return cpi.score(memory, current_time); +} + +void External::release_resources() { + value->evict(); + value.reset(); +} + +void Rematerializer::remat() { + // TODO: refactor using RAII for exception safety. + for (const strong& s : inputs) { + s->pool->lock(); + } + Tensors ts = uncheckpoint(inputs); + auto ret = func(ts); + TORCH_CHECK(ret.size() == outputs.size()); + for (size_t i = 0; i < outputs.size(); ++i) { + if (auto output_cell = outputs[i].lock()) { + output_cell->fill(ret[i]); + } + } + ecn.reset(); + for (const strong& s : inputs) { + s->pool->unlock(); + } +} + +ecn_ptr Rematerializer::get_ecn(time_t last_used_time) { + if (ecn) { + auto cpi = get_t(ecn); + update_t(ecn, CheckpointInfo(cpi.compute_cost, std::max(last_used_time, cpi.last_used_time))); + } else { + ecn = ecn_ptr::make(CheckpointInfo(compute_cost, last_used_time)); + } + return ecn; +} + +CheckpointInfo Rematerializer::get_cpi(time_t last_used_time) { + return CheckpointInfo(ecn ? duration_t(0) : compute_cost, last_used_time); +} + +std::vector AliasPool::neighbor_ecn() { + std::vector ret; + for (size_t i = 0; i < neighbors.size();) { + if (auto cptc = neighbors[i].lock()) { + if (cptc->pool->ecn) { + ret.push_back(cptc->pool->ecn); + } + ++i; + } else { + neighbors[i] = neighbors[neighbors.size() - 1]; + neighbors.pop_back(); + } + } + std::sort(ret.begin(), ret.end()); + ret.erase(std::unique(ret.begin(), ret.end()), ret.end()); + return ret; +} + +void AliasPool::set_not_evicted(const intrusive_ptr& self) { + if (is_evicted) { + is_evicted = false; + if (ecn) { + TORCH_CHECK(head_remat); + auto cpi = get_t(ecn); + update_t(ecn, CheckpointInfo(cpi.compute_cost - head_remat->compute_cost, cpi.last_used_time)); + ecn.reset(); + } + pool.aps.push_back(weak_intrusive_ptr(self)); + } +} + +void CheckpointTensorCell::fill(const Tensor& t) { + if (!(this->t)) { + this->t = std::make_unique(t.detach()); + pool->set_not_evicted(pool); + if (!defined) { + defined = true; + is_undefined_tensor = !t.defined(); + key_set_ = t.key_set(); + dtype_ = t.dtype(); + optional_device_ = t.optional_device(); + if (! is_undefined_tensor) { + dim_ = t.dim(); + numel_ = t.numel(); + itemsize_ = t.itemsize(); + sizes_ = t.sizes().vec(); + strides_ = t.strides().vec(); + } + } + } } Tensor uncheckpoint(const strong& input) { @@ -142,6 +361,10 @@ void CheckpointTensorImpl::shallow_copy_from(const c10::intrusive_ptr rematerializer; }; +void add_neighbor(const strong& l, const strong& r) { + l->pool->neighbors.push_back(weak(r)); +} + // remat take a single vector of tensors, // while there are two vector, one storing nonconstants and one storing constants. // the constants are small and they will not be considered for eviction. @@ -171,30 +398,50 @@ struct MakeRawResult { // the size_t in constants decide the location to stitch them in, while input_values fill in the rest. MakeRawResult make_raw(const rematerialize_function_t& remat_f, const strongs& inputs) { + for (const strong& s : inputs) { + s->pool->lock(); + } Tensors raw_inputs = uncheckpoint(inputs); time_t pre = std::chrono::system_clock::now(); - auto outputs_raw = remat_f(raw_inputs); + auto raw_outputs = remat_f(raw_inputs); time_t post = std::chrono::system_clock::now(); std::vector> outputs; std::vector aliases; weaks weak_outputs; - auto remat = intrusive_ptr::make(Unsafe(), remat_f, inputs); - for (const Tensor& t : outputs_raw) { + auto remat = intrusive_ptr::make(Unsafe(), remat_f, inputs, post - pre); + for (const Tensor& t : raw_outputs) { + intrusive_ptr alias_pool; int alias = get_alias(raw_inputs, t); - intrusive_ptr pool; if (alias == -1) { - pool = intrusive_ptr::make(Unsafe(), true, memory(t)); + auto m = memory(t); + alias_pool = intrusive_ptr::make(Unsafe(), remat, m); + if (m > 0) { + pool.aps.push_back(weak_intrusive_ptr(alias_pool)); + } } else { - pool = inputs[alias]->pool; + alias_pool = inputs[alias]->pool; + if (alias_pool->head_remat) { + alias_pool->head_remat->compute_cost += (post - pre); + } } - auto e = intrusive_ptr::make(t, pool, remat); - pool->tensors.push_back(weak(e->value)); + auto e = intrusive_ptr::make(t, alias_pool, remat); + alias_pool->tensors.push_back(weak(e->value)); outputs.push_back(e); aliases.push_back(alias); weak_outputs.push_back(weak(outputs.back()->value)); } remat->outputs = weak_outputs; + for (size_t i = 0; i < inputs.size(); ++i) { + for (size_t j = 0; j < outputs.size(); ++j) { + if (is_alias(raw_inputs[i], raw_outputs[j])) { + add_neighbor(inputs[i], outputs[j]->value); + } + } + } + for (const strong& s : inputs) { + s->pool->unlock(); + } return {outputs, aliases, post - pre, remat}; } diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index 7aa9530f124..1a2eec70b70 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -22,22 +22,20 @@ #include // System Description: -// Every Tensor is managed by a CheckPointTensor, +// Every Tensor is managed by a CheckpointTensor, // that describe how it is computed, (the function and the inputs) // And might optionally hold the tensor value. // The tensor value might be dropped, and when requested later, recomputed and cached again. // Corner Cases: -// An CheckPointedTensor might be constant. +// An CheckpointedTensor might be constant. // In this case it is unevictable. -// An input might be uncheckpointed. -// In this case it is treated as a small constant and omitted from the system - it will be unevictable. // An operator might return multiple output. // In this case the computation info (rematerializer) is shared between all of them, // And when the function get computed again all value get cached. // An operator might not return value, but only mutate input value. -// To combat this, we COW the operator, and wrap CheckPopintTensor with a Ref. -// By doing this the inner CheckPointTensor is kept purely functional. +// To combat this, we COW the operator, and wrap CheckpopintTensor with a Ref. +// By doing this the inner CheckpointTensor is kept purely functional. // An operator might try to mutate uncheckpointed tensor. // We do not support this and will error. // An operator might create aliases. @@ -71,6 +69,54 @@ namespace at { +// TODO: using a pool allocator might make more sense - no need to allocate and delete each pointer individually. +template +struct EquivalentClassNode : intrusive_ptr_target { + explicit EquivalentClassNode(const T& t) : t_unsafe(t) { } + mutable intrusive_ptr parent; + bool is_root() { + return !parent; + } + void release_resources() override { + parent.reset(); + } + T t_unsafe; +}; + +template +T& get_t(const intrusive_ptr>& n) { + return find_root(n)->t_unsafe; +} + +template +static void update_t(const intrusive_ptr>& n, const T& t) { + find_root(n)->t_unsafe = t; +} + +template +intrusive_ptr> find_root(const intrusive_ptr>& n) { + if (n->is_root()) { + return n; + } else { + n->parent = find_root(n->parent); + return n->parent; + } +} + +template +intrusive_ptr> merge(const std::function& merge_t, + const intrusive_ptr>& lhs, + const intrusive_ptr>& rhs) { + auto l = find_root(lhs); + auto r = find_root(rhs); + if (l == r) { + return l; + } + l->parent = r; + r->t_unsafe = merge_t(l->t_unsafe, r->t_unsafe); + return r; +} + inline size_t memory(const Tensor& t) { if (! t.has_storage()) { return 0; @@ -107,32 +153,36 @@ using mutate_function_t = std::function; using time_t = std::chrono::time_point; using duration_t = std::chrono::system_clock::duration; - -struct Unsafe { }; - -// Track all Tensor that share the same Storage. -// This is the atomic level of eviction - when evicting, everything here will get evicted. -// When an AliasPool is evicted, the Storage of the underlying tensor must be freed. -// Additionally, the AliasPool contain weak pointer to all children of tensors, -// in order to compute the score of evicting a Storage. -struct AliasPool : intrusive_ptr_target { - weaks tensors; - // get() might hold some raw Tensor, rendering them unevictable. - // it is likely that get() will run out of memory, and when it does so, it will try to evict. - // so, it is crucial that we dont try to evict those tensors - doing so will not evict anything. - // lock_count count how many time a tensor is referenced by get. - size_t lock_count; - bool evictable; - size_t memory; - AliasPool(const Unsafe&, bool evictable, size_t memory) : - lock_count(0), evictable(evictable), memory(memory) { - } - void evict(); - void release_resources() final { - tensors.clear(); +struct CheckpointInfo { + duration_t compute_cost; + time_t last_used_time; + // @ZACH: Floating Point instability? + double score(size_t memory, time_t current_time) const { + TORCH_CHECK(memory > 0); + auto staleness = (current_time - last_used_time).count(); + TORCH_CHECK(staleness > 0); + return compute_cost.count() / static_cast(memory * staleness); + } + CheckpointInfo(duration_t compute_cost, time_t last_used_time) : + compute_cost(compute_cost), + last_used_time(last_used_time) { } }; +// ecn represent a evicted tensor group. +// it is a set of tensor that are evicted, and if two evicted tensor are input -> output to each other, +// they must be in an ecn. +// note: we try to support removal from ecn by subtracting compute_cost and memory. +// this will create suprious connection but that should be fine empircally. +// below is an example of a suprious connection: +// a -> b, a -> c +// a, b, c got evicted so belong to a single ecn. +// a got rematerialized. +// b, c still belong to a single ecn although there is no connection. +using ecn_ptr = intrusive_ptr>; + +struct Unsafe { }; + // The rematerializer could be called to reinvoke an operator. // Tensor point to remat which point to Tensor. // To build the cycle remat support a default constructor, @@ -141,10 +191,19 @@ struct Rematerializer : intrusive_ptr_target { rematerialize_function_t func; strongs inputs; weaks outputs; + duration_t compute_cost; + // when some output in here get evicted, they should belong to this ecn. + // a rematerializer have to track this, + // because when multiple output of a rematerializer get evicted, + // we only want to count the compute cost once. + ecn_ptr ecn; Rematerializer(const Unsafe&, const rematerialize_function_t& func, - const strongs& inputs) : - func(func), inputs(inputs) { + const strongs& inputs, + duration_t compute_cost) : + func(func), + inputs(inputs), + compute_cost(compute_cost) { } void release_resources() final { func = rematerialize_function_t(); @@ -152,9 +211,60 @@ struct Rematerializer : intrusive_ptr_target { outputs.clear(); } void remat(); + ecn_ptr get_ecn(time_t last_used_time); + CheckpointInfo get_cpi(time_t last_used_time); +}; + +// Track all Tensor that share the same Storage. +// This is the atomic level of eviction - when evicting, everything here will get evicted. +// When an AliasPool is evicted, the Storage of the underlying tensor must be freed. +// Additionally, the AliasPool contain weak pointer to all children of tensors, +// in order to compute the score of evicting a Storage. +struct AliasPool : intrusive_ptr_target { + weaks tensors; + weaks neighbors; + std::vector neighbor_ecn(); + // get() might hold some raw Tensor, rendering them unevictable. + // it is likely that get() will run out of memory, and when it does so, it will try to evict. + // so, it is crucial that we dont try to evict those tensors - doing so will not evict anything. + // lock_count count how many time a tensor is referenced by get. + size_t lock_count = 0; + void lock() { + ++lock_count; + } + void unlock() { + --lock_count; + } + intrusive_ptr head_remat; + bool evictable() const { + return lock_count == 0 && head_remat; + } + // if it is not evictable it must not be evicted. + bool is_evicted = false; + size_t memory; + time_t last_used_time; + // An aliaspool cant register itself to the checkpointpool - you have to do it yourself. + AliasPool(const Unsafe&, intrusive_ptr head_remat, size_t memory) : + head_remat(head_remat), + memory(memory), + last_used_time(std::chrono::system_clock::now()) { + } + // if it is evicted, then hold the evicted tensor group. + ecn_ptr ecn; + double score(time_t current_time); + void evict(); + // if it was evicted, refresh it. otherwise do nothing. + // have to check so, because when we rematerialize a single tensor in an aliaspool, + // we will set it to non-evicted, and when we rematerialize it's tensor they will also reset this. + void set_not_evicted(const intrusive_ptr& self); + void release_resources() final { + tensors.clear(); + neighbors.clear(); + head_remat.reset(); + } }; -struct CAFFE2_API CheckpointTensorCell : intrusive_ptr_target { +struct CheckpointTensorCell : intrusive_ptr_target { std::unique_ptr t; bool defined = false; bool is_undefined_tensor; @@ -207,25 +317,7 @@ struct CAFFE2_API CheckpointTensorCell : intrusive_ptr_target { void evict() { t.reset(); } - void fill(const Tensor& t) { - if (!(this->t)) { - this->t = std::make_unique(t.detach()); - if (!defined) { - defined = true; - is_undefined_tensor = !t.defined(); - key_set_ = t.key_set(); - dtype_ = t.dtype(); - optional_device_ = t.optional_device(); - if (! is_undefined_tensor) { - dim_ = t.dim(); - numel_ = t.numel(); - itemsize_ = t.itemsize(); - sizes_ = t.sizes().vec(); - strides_ = t.strides().vec(); - } - } - } - } + void fill(const Tensor& t); explicit CheckpointTensorCell(const Tensor& t, const intrusive_ptr& pool) : pool(pool) { fill(t); } @@ -248,12 +340,13 @@ struct CAFFE2_API CheckpointTensorCell : intrusive_ptr_target { remat->remat(); } TORCH_CHECK(t); - TORCH_CHECK(! t->key_set().has(DispatchKey::CheckpointTensorId)) + TORCH_CHECK(! t->key_set().has(DispatchKey::CheckpointTensorId)); + pool->last_used_time = std::chrono::system_clock::now(); return *t; } void pin() { - pool->evictable = false; get(); + pool->head_remat.reset(); remat.reset(); } void release_resources() final { @@ -263,14 +356,6 @@ struct CAFFE2_API CheckpointTensorCell : intrusive_ptr_target { } }; -// CheckpointPool keep a list of AliasPool, and search over them to choose the best one to evict. -struct CheckpointPool { - static CheckpointPool& singleton() { - static CheckpointPool cpp; - return cpp; - } -}; - // An external reference. // Each strong will have at most one external reference. // By keeping such an invariant, whenever an external reference die, @@ -283,7 +368,7 @@ struct External : intrusive_ptr_target { External(const Tensor& value) : value(intrusive_ptr::make(value, intrusive_ptr::make(Unsafe(), - false, + intrusive_ptr(), memory(value)))) { } External(const Tensor& value, const intrusive_ptr& pool, @@ -300,7 +385,7 @@ inline DispatchKeySet convert_key_set(const DispatchKeySet& t) { return ret; } -struct CAFFE2_API CheckpointTensorImpl : TensorImpl { +struct CheckpointTensorImpl : TensorImpl { int id = gen_counter(); static int counter; static int gen_counter() { @@ -354,6 +439,17 @@ struct CAFFE2_API CheckpointTensorImpl : TensorImpl { } }; +// CheckpointPool keep a list of AliasPool, and search over them to choose the best one to evict. +struct CheckpointPool { + std::vector> aps; + bool has_memory_budget = false; + long memory_budget; + void evict(); + void auto_evict(); + void clear_checkpointpool(); + CheckpointPool(); +}; + inline CheckpointTensorImpl* get_cpti(const Tensor& t) { auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); TORCH_CHECK(cpti != nullptr); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index ff2970a5126..d9eb137705d 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -37,6 +37,12 @@ - func: clear_checkpointpool() -> () variants: function +- func: set_memory_budget(int budget) -> () + variants: function + +- func: unset_memory_budget() -> () + variants: function + # Temporary type cast operators. These are needed to trace type-casts now since # Type's are not supported in the IR. Instead, we call down to these # specialized operators for each datatype. From c9ef0976bbca7df7954c64a57f92b48550222c4a Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Sun, 31 May 2020 10:08:52 -0700 Subject: [PATCH 26/42] save --- aten/src/ATen/CheckpointTensorImpl.cpp | 52 -------------------------- 1 file changed, 52 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 8d0424e1354..9193185ae11 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -288,58 +288,6 @@ void CheckpointTensorCell::fill(const Tensor& t) { } } -Tensor uncheckpoint(const strong& input) { - return input->get(); -} - -Tensors uncheckpoint(const strongs& inputs) { - Tensors ret; - for (const strong& input : inputs) { - ret.push_back(uncheckpoint(input)); - } - return ret; -}; - -Tensors try_checkpoint(const Tensors& inputs) { - Tensors ret; - for (const Tensor& input : inputs) { - ret.push_back(at::native::try_checkpoint(input)); - } - return ret; -} - -void AliasPool::evict() { - TORCH_CHECK(lock_count == 0); - for (const weak& w : tensors) { - if (auto cell = w.lock()) { - cell->evict(); - } - } -} - -void External::release_resources() { - value->evict(); - value.reset(); -} - -void Rematerializer::remat() { - // TODO: refactor using RAII for exception safety. - for (const strong& s : inputs) { - ++(s->pool->lock_count); - } - Tensors ts = uncheckpoint(inputs); - auto ret = func(ts); - TORCH_CHECK(ret.size() == outputs.size()); - for (size_t i = 0; i < outputs.size(); ++i) { - if (auto output_cell = outputs[i].lock()) { - output_cell->fill(ret[i]); - } - } - for (const strong& s : inputs) { - --(s->pool->lock_count); - } -} - intrusive_ptr CheckpointTensorImpl::shallow_copy_and_detach(const VariableVersion& version_counter, bool allow_tensor_metadata_change) const { auto ret = intrusive_ptr::make(ref); From 0a08ff79a3de1d97b144ba68f7e1d5240aca85a0 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Sun, 31 May 2020 23:06:18 -0700 Subject: [PATCH 27/42] save --- aten/src/ATen/CheckpointTensorImpl.cpp | 24 ++++++++++-------------- aten/src/ATen/CheckpointTensorImpl.h | 13 +++++-------- 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 9193185ae11..7b7669845c6 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -164,13 +164,12 @@ Tensors try_checkpoint(const Tensors& inputs) { } CheckpointInfo merge_cpi(CheckpointInfo l, CheckpointInfo r) { - return CheckpointInfo(l.compute_cost + r.compute_cost, - std::max(l.last_used_time, r.last_used_time)); + return CheckpointInfo(l.compute_cost + r.compute_cost); } void AliasPool::evict() { TORCH_CHECK(!ecn); - ecn = head_remat->get_ecn(last_used_time); + ecn = head_remat->get_ecn(); auto ecns = neighbor_ecn(); for (const auto& necn : ecns) { merge(merge_cpi, ecn, necn); @@ -190,12 +189,12 @@ void AliasPool::evict() { } double AliasPool::score(time_t current_time) { - auto cpi = head_remat->get_cpi(last_used_time); + auto cpi = head_remat->get_cpi(); auto ecns = neighbor_ecn(); for (const auto& necn : ecns) { cpi = merge_cpi(cpi, get_t(necn)); } - return cpi.score(memory, current_time); + return cpi.score(memory, (current_time - last_used_time).count()); } void External::release_resources() { @@ -222,18 +221,15 @@ void Rematerializer::remat() { } } -ecn_ptr Rematerializer::get_ecn(time_t last_used_time) { - if (ecn) { - auto cpi = get_t(ecn); - update_t(ecn, CheckpointInfo(cpi.compute_cost, std::max(last_used_time, cpi.last_used_time))); - } else { - ecn = ecn_ptr::make(CheckpointInfo(compute_cost, last_used_time)); +ecn_ptr Rematerializer::get_ecn() { + if (!ecn) { + ecn = ecn_ptr::make(CheckpointInfo(compute_cost)); } return ecn; } -CheckpointInfo Rematerializer::get_cpi(time_t last_used_time) { - return CheckpointInfo(ecn ? duration_t(0) : compute_cost, last_used_time); +CheckpointInfo Rematerializer::get_cpi() { + return CheckpointInfo(ecn ? duration_t(0) : compute_cost); } std::vector AliasPool::neighbor_ecn() { @@ -260,7 +256,7 @@ void AliasPool::set_not_evicted(const intrusive_ptr& self) { if (ecn) { TORCH_CHECK(head_remat); auto cpi = get_t(ecn); - update_t(ecn, CheckpointInfo(cpi.compute_cost - head_remat->compute_cost, cpi.last_used_time)); + update_t(ecn, CheckpointInfo(cpi.compute_cost - head_remat->compute_cost)); ecn.reset(); } pool.aps.push_back(weak_intrusive_ptr(self)); diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index 1a2eec70b70..2a0a70e6ed2 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -155,17 +155,14 @@ using time_t = std::chrono::time_point; using duration_t = std::chrono::system_clock::duration; struct CheckpointInfo { duration_t compute_cost; - time_t last_used_time; // @ZACH: Floating Point instability? - double score(size_t memory, time_t current_time) const { + double score(size_t memory, size_t staleness) const { TORCH_CHECK(memory > 0); - auto staleness = (current_time - last_used_time).count(); TORCH_CHECK(staleness > 0); return compute_cost.count() / static_cast(memory * staleness); } - CheckpointInfo(duration_t compute_cost, time_t last_used_time) : - compute_cost(compute_cost), - last_used_time(last_used_time) { + CheckpointInfo(duration_t compute_cost) : + compute_cost(compute_cost) { } }; @@ -211,8 +208,8 @@ struct Rematerializer : intrusive_ptr_target { outputs.clear(); } void remat(); - ecn_ptr get_ecn(time_t last_used_time); - CheckpointInfo get_cpi(time_t last_used_time); + ecn_ptr get_ecn(); + CheckpointInfo get_cpi(); }; // Track all Tensor that share the same Storage. From fe395f41e17cc6bbcf677364e0068b82d3faca88 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 1 Jun 2020 13:20:30 -0700 Subject: [PATCH 28/42] save --- aten/src/ATen/CheckpointTensorImpl.cpp | 4 +++- aten/src/ATen/CheckpointTensorImpl.h | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 7b7669845c6..2f6135b8ede 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -198,7 +198,9 @@ double AliasPool::score(time_t current_time) { } void External::release_resources() { - value->evict(); + if (value->remat) { + value->evict(); + } value.reset(); } diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index 2a0a70e6ed2..aebe9ebeca7 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -312,6 +312,7 @@ struct CheckpointTensorCell : intrusive_ptr_target { return strides_[d]; } void evict() { + TORCH_CHECK(remat); t.reset(); } void fill(const Tensor& t); From 7efe2029123ba7bf53cf8e377d3c7fcd5dc491b2 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 1 Jun 2020 20:59:23 -0700 Subject: [PATCH 29/42] More opts --- aten/src/ATen/CheckpointTensorImpl.cpp | 238 +++++++++++++++++++++++-- aten/src/ATen/CheckpointTensorImpl.h | 9 +- 2 files changed, 231 insertions(+), 16 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 2f6135b8ede..13f60b654ad 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -2,11 +2,108 @@ #include #include +#include +#include + namespace at { +using Clock = std::chrono::high_resolution_clock; +using Time = Clock::time_point; +using Duration = Clock::duration; +using FinalTime = std::chrono::nanoseconds; + +struct PerfStats; + +struct Timer { + std::string name; + Time start; + PerfStats& stats; + Timer(std::string name, Time start, PerfStats& stats) : name(name), start(start), stats(stats) {} + ~Timer(); +}; + +bool stats = true; + +struct PerfStats { + using TimerStats = std::tuple; + Time start; + std::unordered_map calls; + std::vector timers; + + PerfStats() : start(Clock::now()), calls(0), timers() {} + + Timer track(std::string name) { + if (stats) { + auto it = this->calls.find(name); + if (it != this->calls.end()) { + it->second += 1; + } else { + this->calls.insert({name, 0}); + } + + return Timer(name, Clock::now(), *this); + } + } + + ~PerfStats() { + if (!stats) { return; } + auto start = std::get<1>(this->timers[0]); + auto now = Clock::now(); + std::cout << "All done. Here are some perf stats fresh off the preses." << std::endl; + std::unordered_map durations; + + Duration total = now - this->start; + + // For now simple strategy, count up all the time taken + // by each "tagged call site". + for (auto timer : timers) { + auto name = std::get<0>(timer); + Duration duration = std::get<3>(timer); + auto it = durations.find(name); + if (it != durations.end()) { + it->second += duration; + } else { + durations.insert({name, duration}); + } + } + + std::vector> data; + + // Convert the durations + for (auto d : durations) { + // auto duration = std::chrono::duration_cast(d.second); + data.push_back(d); + } + + std::sort(data.begin(), data.end(), + [](const std::pair & a, const std::pair & b) -> bool { + return a.second > b.second; + }); + + for (auto d : data) { + auto duration = std::chrono::duration_cast(d.second); + auto total_duration = std::chrono::duration_cast(total); + double percentage = ((double)duration.count())/((double)total_duration.count()) * 100; + auto call_count = this->calls.find(d.first); + TORCH_CHECK(call_count != this->calls.end()); + std::cout << "CallSite: " << d.first << " CallCount: " << call_count->second << " Cost: " << duration.count() << "ns" << " (%" << percentage << ")" << std::endl; + } + } +}; + +Timer::~Timer() { + Time now = Clock::now(); + Duration elapsed = now - start; + PerfStats::TimerStats stats = { name , start, now, elapsed }; + this->stats.timers.push_back(stats); +} + +static PerfStats STATS = PerfStats(); + CheckpointPool pool; long current_memory() { + STATS.track("current_memory"); auto device_stat = c10::cuda::CUDACachingAllocator::getDeviceStats(0); return device_stat.allocated_bytes[0].current; } @@ -16,6 +113,7 @@ void checkpoint_auto_evict() { } void CheckpointPool::auto_evict() { + STATS.track("CheckpointPool::auto_evict"); if (has_memory_budget) { while (current_memory() > memory_budget) { evict(); @@ -24,6 +122,7 @@ void CheckpointPool::auto_evict() { } void CheckpointPool::evict() { + STATS.track("CheckpointPool::evict"); TORCH_CHECK(aps.size() > 0); bool shrinked = false; int evict_idx = -1; @@ -66,14 +165,16 @@ void CheckpointPool::evict() { } CheckpointPool::CheckpointPool() { + STATS.track("CheckpointPool::CheckpointPool"); c10::set_evict_func(checkpoint_auto_evict); } -bool use_log = true; +bool use_log = false; namespace native { Tensor checkpoint(const Tensor& t) { + STATS.track("checkpoint"); auto cpti = intrusive_ptr::make(t.detach()); if (use_log) { DTRLogConstant(cpti->counter_name()); @@ -83,28 +184,33 @@ Tensor checkpoint(const Tensor& t) { } Tensor uncheckpoint(const Tensor& t) { + STATS.track("uncheckpoint"); auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); - CHECK(cpti != nullptr); + // CHECK(cpti != nullptr); return cpti->ref->value->value->get(); } void pin(const Tensor& t) { + STATS.track("pin"); auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); - CHECK(cpti != nullptr); + // CHECK(cpti != nullptr); cpti->ref->value->value->pin(); } Tensor decheckpoint(const Tensor& t) { + STATS.track("decheckpoint"); auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); return cpti ? cpti->ref->value->value->get() : t; } bool is_checkpoint(const Tensor& t) { + STATS.track("is_checkpoint"); auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); return cpti != nullptr; } Tensor try_checkpoint(const Tensor& t) { + STATS.track("try_checkpiont"); return is_checkpoint(t) ? t : checkpoint(t); } @@ -143,20 +249,26 @@ void set_memory_budget(long budget) { } +[[inline]] Tensor uncheckpoint(const strong& input) { return input->get(); } Tensors uncheckpoint(const strongs& inputs) { + STATS.track("uncheckpoint"); Tensors ret; + ret.reserve(inputs.size()); for (const strong& input : inputs) { - ret.push_back(uncheckpoint(input)); + // Jared: inlined manually + ret.push_back(input->get()); } return ret; }; Tensors try_checkpoint(const Tensors& inputs) { + STATS.track("try_checkpoint"); Tensors ret; + ret.reserve(inputs.size()); for (const Tensor& input : inputs) { ret.push_back(at::native::try_checkpoint(input)); } @@ -164,10 +276,12 @@ Tensors try_checkpoint(const Tensors& inputs) { } CheckpointInfo merge_cpi(CheckpointInfo l, CheckpointInfo r) { + STATS.track("merge_cpi"); return CheckpointInfo(l.compute_cost + r.compute_cost); } void AliasPool::evict() { + STATS.track("AliasPool::evict"); TORCH_CHECK(!ecn); ecn = head_remat->get_ecn(); auto ecns = neighbor_ecn(); @@ -175,9 +289,9 @@ void AliasPool::evict() { merge(merge_cpi, ecn, necn); } auto b4 = current_memory(); - TORCH_CHECK(memory > 0); - TORCH_CHECK(lock_count == 0); - TORCH_CHECK(!is_evicted); + // cTORCH_CHECK(memory > 0); + // TORCH_CHECK(lock_count == 0); + // TORCH_CHECK(!is_evicted); is_evicted = true; for (const weak& w : tensors) { if (auto cell = w.lock()) { @@ -234,25 +348,96 @@ CheckpointInfo Rematerializer::get_cpi() { return CheckpointInfo(ecn ? duration_t(0) : compute_cost); } -std::vector AliasPool::neighbor_ecn() { - std::vector ret; +// A utility function to swap two elements +void swap(int* a, int* b) +{ + int t = *a; + *a = *b; + *b = t; +} + +// /* This function is same in both iterative and recursive*/ +// template +// int partition(std::vector& arr, int l, int h) +// { +// int x = arr[h]; +// int i = (l - 1); + +// for (int j = l; j <= h - 1; j++) { +// if (arr[j] <= x) { +// i++; +// swap(&arr[i], &arr[j]); +// } +// } +// swap(&arr[i + 1], &arr[h]); +// return (i + 1); +// } + +// template +// void qsort_dedup(std::vector& arr, int l, int h) +// { +// // Create an auxiliary stack +// std::vector stack; +// stack.reserve(h - l + 1); + +// // initialize top of stack +// int top = -1; + +// // push initial values of l and h to stack +// stack[++top] = l; +// stack[++top] = h; + +// // Keep popping from stack while is not empty +// while (top >= 0) { +// // Pop h and l +// h = stack[top--]; +// l = stack[top--]; + +// // Set pivot element at its correct position +// // in sorted array +// int p = partition(arr, l, h); + +// // If there are elements on left side of pivot, +// // then push left side to stack +// if (p - 1 > l) { +// stack[++top] = l; +// stack[++top] = p - 1; +// } + +// // If there are elements on right side of pivot, +// // then push right side to stack +// if (p + 1 < h) { +// stack[++top] = p + 1; +// stack[++top] = h; +// } +// } +// } + +std::set AliasPool::neighbor_ecn() { + STATS.track("AliasPool::neighbor_ecn"); + std::set ptr_set; + + STATS.track("AliasPool::neighbor_ecn(process nodes)"); for (size_t i = 0; i < neighbors.size();) { if (auto cptc = neighbors[i].lock()) { + STATS.track("AliasPool::neighbor_ecn(true-branch)"); if (cptc->pool->ecn) { - ret.push_back(cptc->pool->ecn); + ptr_set.insert(cptc->pool->ecn); } ++i; } else { + STATS.track("AliasPool::neighbor_ecn(false-branch)"); neighbors[i] = neighbors[neighbors.size() - 1]; neighbors.pop_back(); } } - std::sort(ret.begin(), ret.end()); - ret.erase(std::unique(ret.begin(), ret.end()), ret.end()); - return ret; + + return ptr_set; } void AliasPool::set_not_evicted(const intrusive_ptr& self) { + STATS.track("AliasPool::set_not_evicted"); + std::vector ret; if (is_evicted) { is_evicted = false; if (ecn) { @@ -266,8 +451,10 @@ void AliasPool::set_not_evicted(const intrusive_ptr& self) { } void CheckpointTensorCell::fill(const Tensor& t) { + STATS.track("CheckpointTensorCell::fill"); if (!(this->t)) { this->t = std::make_unique(t.detach()); + STATS.track("CheckpointTensorCell::fill(after_alloc)"); pool->set_not_evicted(pool); if (!defined) { defined = true; @@ -296,6 +483,7 @@ intrusive_ptr CheckpointTensorImpl::shallow_copy_and_detach(const Va } void CheckpointTensorImpl::shallow_copy_from(const c10::intrusive_ptr& impl) { + STATS.track("CheckpointTensorCell::shallow_copy_from"); TORCH_CHECK(impl->key_set().has(DispatchKey::CheckpointTensorId)); auto* cpti = dynamic_cast(impl.get()); TORCH_CHECK(cpti != nullptr); @@ -344,6 +532,7 @@ void add_neighbor(const strong& l, const strong& r) { // the size_t in constants decide the location to stitch them in, while input_values fill in the rest. MakeRawResult make_raw(const rematerialize_function_t& remat_f, const strongs& inputs) { + STATS.track("make_raw"); for (const strong& s : inputs) { s->pool->lock(); } @@ -355,6 +544,7 @@ MakeRawResult make_raw(const rematerialize_function_t& remat_f, std::vector aliases; weaks weak_outputs; auto remat = intrusive_ptr::make(Unsafe(), remat_f, inputs, post - pre); + for (const Tensor& t : raw_outputs) { intrusive_ptr alias_pool; int alias = get_alias(raw_inputs, t); @@ -398,9 +588,16 @@ std::string from_time(duration_t t) { Tensors CheckpointTensorImpl::make(const std::string& name, const rematerialize_function_t& remat, const Tensors& inputs) { + STATS.track("CheckPointTensorImpl::make"); Tensors checkpointed_inputs = try_checkpoint(inputs); + auto input_size = checkpointed_inputs.size(); + strongs input_values; + input_values.reserve(input_size); + std::vector args; + args.reserve(input_size); + for (const Tensor& t: checkpointed_inputs) { auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); TORCH_CHECK(cpti); @@ -409,15 +606,25 @@ Tensors CheckpointTensorImpl::make(const std::string& name, args.push_back(cpti->counter_name()); } } - std::vector res; + auto ret = make_raw(remat, input_values); + Tensors tensors; + tensors.reserve(ret.outputs.size()); + for (const auto& t: ret.outputs) { auto cp = Tensor(intrusive_ptr::make(t)); tensors.push_back(cp); - res.push_back(get_cpti(cp)->counter_name()); } + if (use_log) { + std::vector res; + res.reserve(ret.outputs.size()); + + for (const auto& tensor : tensors) { + res.push_back(get_cpti(tensor)->counter_name()); + } + DTRLogCall(res, name, args, from_time(ret.time)); for (size_t i = 0; i < tensors.size(); ++i) { Tensor t = tensors[i]; @@ -426,6 +633,7 @@ Tensors CheckpointTensorImpl::make(const std::string& name, DTRLogAlias(cpti->counter_name(), ret.aliases[i]); } } + return tensors; } diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index aebe9ebeca7..6a712c048b4 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -220,7 +220,7 @@ struct Rematerializer : intrusive_ptr_target { struct AliasPool : intrusive_ptr_target { weaks tensors; weaks neighbors; - std::vector neighbor_ecn(); + std::set neighbor_ecn(); // get() might hold some raw Tensor, rendering them unevictable. // it is likely that get() will run out of memory, and when it does so, it will try to evict. // so, it is crucial that we dont try to evict those tensors - doing so will not evict anything. @@ -392,20 +392,27 @@ struct CheckpointTensorImpl : TensorImpl { std::string counter_name() const { return std::string("x") + std::to_string(id); } + Ref> ref; + void release_resources() final; + explicit CheckpointTensorImpl(const Ref>& ref) : TensorImpl(convert_key_set(ref->value->value->key_set()), ref->value->value->dtype(), ref->value->value->optional_device()), ref(ref) { } + explicit CheckpointTensorImpl(const intrusive_ptr& e) : CheckpointTensorImpl(Ref>::make(e)) { } + explicit CheckpointTensorImpl(const Tensor& t) : CheckpointTensorImpl(intrusive_ptr::make(t)) { } + static Tensors make(const std::string& name, const rematerialize_function_t& remat, const Tensors& inputs); + // mutate_idx indicate which of the inputs will get mutated. static void mutate(const std::string& name, const mutate_function_t& mutate, From ed31496250eab8d00fa8d01dc9568c9973cef857 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 1 Jun 2020 21:20:43 -0700 Subject: [PATCH 30/42] Try removing pop-ing from loop --- aten/src/ATen/CheckpointTensorImpl.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 13f60b654ad..9f52de5f549 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -418,7 +418,8 @@ std::set AliasPool::neighbor_ecn() { std::set ptr_set; STATS.track("AliasPool::neighbor_ecn(process nodes)"); - for (size_t i = 0; i < neighbors.size();) { + int back = neighbors.size() - 1; + for (size_t i = 0; i < back;) { if (auto cptc = neighbors[i].lock()) { STATS.track("AliasPool::neighbor_ecn(true-branch)"); if (cptc->pool->ecn) { @@ -427,11 +428,14 @@ std::set AliasPool::neighbor_ecn() { ++i; } else { STATS.track("AliasPool::neighbor_ecn(false-branch)"); - neighbors[i] = neighbors[neighbors.size() - 1]; - neighbors.pop_back(); + neighbors[i] = neighbors[back]; + back = back - 1; } } + CheckpointTensorCell* ptr = nullptr; + neighbors.resize(back, weak_intrusive_ptr(ptr)); + return ptr_set; } From a9878a5ebe5648cdb1ec7687dbd4360dd5e3b80f Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 1 Jun 2020 21:57:45 -0700 Subject: [PATCH 31/42] save --- aten/src/ATen/CheckpointTensorImpl.cpp | 38 ++++++++++---------- aten/src/ATen/CheckpointTensorImpl.h | 49 ++++++-------------------- 2 files changed, 31 insertions(+), 56 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 13f60b654ad..d3df0840273 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -17,8 +17,8 @@ struct PerfStats; struct Timer { std::string name; Time start; - PerfStats& stats; - Timer(std::string name, Time start, PerfStats& stats) : name(name), start(start), stats(stats) {} + Timer(std::string name, Time start) : name(name), start(start) {} + Timer() {} ~Timer(); }; @@ -41,8 +41,9 @@ struct PerfStats { this->calls.insert({name, 0}); } - return Timer(name, Clock::now(), *this); + return Timer(name, Clock::now()); } + return Timer(); } ~PerfStats() { @@ -91,15 +92,15 @@ struct PerfStats { } }; +static PerfStats STATS = PerfStats(); + Timer::~Timer() { Time now = Clock::now(); Duration elapsed = now - start; PerfStats::TimerStats stats = { name , start, now, elapsed }; - this->stats.timers.push_back(stats); + STATS.timers.push_back(stats); } -static PerfStats STATS = PerfStats(); - CheckpointPool pool; long current_memory() { @@ -235,7 +236,12 @@ void toggle_log(bool b) { } void clear_checkpointpool() { - // not implemented yet. + while (likely(!pool.exts.empty())) { + if (auto e = pool.exts.back().lock()) { + e->value->pin(); + } + pool.exts.pop_back(); + } } void unset_memory_budget() { @@ -436,9 +442,8 @@ std::set AliasPool::neighbor_ecn() { } void AliasPool::set_not_evicted(const intrusive_ptr& self) { - STATS.track("AliasPool::set_not_evicted"); - std::vector ret; - if (is_evicted) { + if (unlikely(is_evicted)) { + STATS.track("AliasPool::set_not_evicted(inside)"); is_evicted = false; if (ecn) { TORCH_CHECK(head_remat); @@ -454,7 +459,6 @@ void CheckpointTensorCell::fill(const Tensor& t) { STATS.track("CheckpointTensorCell::fill"); if (!(this->t)) { this->t = std::make_unique(t.detach()); - STATS.track("CheckpointTensorCell::fill(after_alloc)"); pool->set_not_evicted(pool); if (!defined) { defined = true; @@ -462,13 +466,6 @@ void CheckpointTensorCell::fill(const Tensor& t) { key_set_ = t.key_set(); dtype_ = t.dtype(); optional_device_ = t.optional_device(); - if (! is_undefined_tensor) { - dim_ = t.dim(); - numel_ = t.numel(); - itemsize_ = t.itemsize(); - sizes_ = t.sizes().vec(); - strides_ = t.strides().vec(); - } } } } @@ -562,6 +559,7 @@ MakeRawResult make_raw(const rematerialize_function_t& remat_f, } } auto e = intrusive_ptr::make(t, alias_pool, remat); + pool.exts.push_back(weak_intrusive_ptr(e)); alias_pool->tensors.push_back(weak(e->value)); outputs.push_back(e); aliases.push_back(alias); @@ -678,4 +676,8 @@ void CheckpointTensorImpl::release_resources() { ref.reset(); } +CheckpointTensorImpl::CheckpointTensorImpl(const Tensor& t) : CheckpointTensorImpl(intrusive_ptr::make(t)) { + pool.exts.push_back(weak_intrusive_ptr(ref->value)); +} + } diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index 6a712c048b4..742c4efeecd 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -21,6 +21,9 @@ #include #include +#define likely(x) __builtin_expect(!!(x), 1) +#define unlikely(x) __builtin_expect(!!(x), 0) + // System Description: // Every Tensor is managed by a CheckpointTensor, // that describe how it is computed, (the function and the inputs) @@ -280,37 +283,10 @@ struct CheckpointTensorCell : intrusive_ptr_target { TORCH_CHECK(defined); return optional_device_; } - int64_t dim_, numel_; - size_t itemsize_; - std::vector sizes_, strides_; // A Tensor is evictable iff it's AliasPool is evictable. // A evictable tensor must have Rematerializer. intrusive_ptr pool; intrusive_ptr remat; - int64_t dim() const { - TORCH_CHECK(defined && !is_undefined_tensor); - return dim_; - } - int64_t numel() const { - TORCH_CHECK(defined && !is_undefined_tensor); - return numel_; - } - IntArrayRef sizes() const { - TORCH_CHECK(defined && !is_undefined_tensor); - return sizes_; - } - int64_t size(int64_t d) const { - TORCH_CHECK(defined && !is_undefined_tensor); - return sizes_[d]; - } - IntArrayRef strides() const { - TORCH_CHECK(defined && !is_undefined_tensor); - return strides_; - } - int64_t stride(int64_t d) const { - TORCH_CHECK(defined && !is_undefined_tensor); - return strides_[d]; - } void evict() { TORCH_CHECK(remat); t.reset(); @@ -325,9 +301,6 @@ struct CheckpointTensorCell : intrusive_ptr_target { pool(pool), remat(remat) { fill(t); } - size_t itemsize() { - return itemsize_; - } size_t memory() { TORCH_CHECK(defined); return pool->memory; @@ -406,8 +379,7 @@ struct CheckpointTensorImpl : TensorImpl { explicit CheckpointTensorImpl(const intrusive_ptr& e) : CheckpointTensorImpl(Ref>::make(e)) { } - explicit CheckpointTensorImpl(const Tensor& t) : - CheckpointTensorImpl(intrusive_ptr::make(t)) { } + explicit CheckpointTensorImpl(const Tensor& t); static Tensors make(const std::string& name, const rematerialize_function_t& remat, @@ -422,22 +394,22 @@ struct CheckpointTensorImpl : TensorImpl { bool allow_tensor_metadata_change) const override; void shallow_copy_from(const c10::intrusive_ptr& impl) override; int64_t dim() const override { - return ref->value->value->dim(); + return ref->value->value->get().dim(); } int64_t numel() const override { - return ref->value->value->numel(); + return ref->value->value->get().numel(); } IntArrayRef sizes() const override { - return ref->value->value->sizes(); + return ref->value->value->get().sizes(); } int64_t size(int64_t d) const override { - return ref->value->value->size(d); + return ref->value->value->get().size(d); } IntArrayRef strides() const override { - return ref->value->value->strides(); + return ref->value->value->get().strides(); } int64_t stride(int64_t d) const override { - return ref->value->value->stride(d); + return ref->value->value->get().stride(d); } bool has_storage() const override { return false; @@ -447,6 +419,7 @@ struct CheckpointTensorImpl : TensorImpl { // CheckpointPool keep a list of AliasPool, and search over them to choose the best one to evict. struct CheckpointPool { std::vector> aps; + std::vector> exts; bool has_memory_budget = false; long memory_budget; void evict(); From 984ec1025d1aaf6fbe97e77009c86c4a2de34028 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 1 Jun 2020 22:05:57 -0700 Subject: [PATCH 32/42] save --- aten/src/ATen/CheckpointTensorImpl.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 7935f794488..099c8bd5ad8 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -439,9 +439,7 @@ std::set AliasPool::neighbor_ecn() { } } - CheckpointTensorCell* ptr = nullptr; - neighbors.resize(back, weak_intrusive_ptr(ptr)); - + neighbors.erase(neighbors.begin() + back); return ptr_set; } From 4607be498695b372fafdb02ac4e58188fa9c08bd Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 1 Jun 2020 23:08:29 -0700 Subject: [PATCH 33/42] save --- aten/src/ATen/CheckpointTensorImpl.cpp | 29 ++++++++++---------------- aten/src/ATen/Logger.h | 18 +++++++++++++--- c10/cuda/CUDACachingAllocator.cpp | 21 ------------------- c10/cuda/CUDACachingAllocator.h | 4 ---- 4 files changed, 26 insertions(+), 46 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 099c8bd5ad8..bd1c9166913 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -22,7 +22,7 @@ struct Timer { ~Timer(); }; -bool stats = true; +constexpr bool stats = true; struct PerfStats { using TimerStats = std::tuple; @@ -109,10 +109,6 @@ long current_memory() { return device_stat.allocated_bytes[0].current; } -void checkpoint_auto_evict() { - pool.auto_evict(); -} - void CheckpointPool::auto_evict() { STATS.track("CheckpointPool::auto_evict"); if (has_memory_budget) { @@ -165,10 +161,7 @@ void CheckpointPool::evict() { } } -CheckpointPool::CheckpointPool() { - STATS.track("CheckpointPool::CheckpointPool"); - c10::set_evict_func(checkpoint_auto_evict); -} +CheckpointPool::CheckpointPool() { } bool use_log = false; @@ -331,6 +324,7 @@ void Rematerializer::remat() { } Tensors ts = uncheckpoint(inputs); auto ret = func(ts); + pool.auto_evict(); TORCH_CHECK(ret.size() == outputs.size()); for (size_t i = 0; i < outputs.size(); ++i) { if (auto output_cell = outputs[i].lock()) { @@ -422,24 +416,22 @@ void swap(int* a, int* b) std::set AliasPool::neighbor_ecn() { STATS.track("AliasPool::neighbor_ecn"); std::set ptr_set; - STATS.track("AliasPool::neighbor_ecn(process nodes)"); - int back = neighbors.size() - 1; - for (size_t i = 0; i < back;) { + int size = neighbors.size(); + for (size_t i = 0; i < size;) { if (auto cptc = neighbors[i].lock()) { - STATS.track("AliasPool::neighbor_ecn(true-branch)"); if (cptc->pool->ecn) { ptr_set.insert(cptc->pool->ecn); } ++i; } else { - STATS.track("AliasPool::neighbor_ecn(false-branch)"); - neighbors[i] = neighbors[back]; - back = back - 1; + neighbors[i] = neighbors[size - 1]; + size = size - 1; } } - - neighbors.erase(neighbors.begin() + back); + if (size < neighbors.size()) { + neighbors.erase(neighbors.begin() + size); + } return ptr_set; } @@ -539,6 +531,7 @@ MakeRawResult make_raw(const rematerialize_function_t& remat_f, time_t pre = std::chrono::system_clock::now(); auto raw_outputs = remat_f(raw_inputs); time_t post = std::chrono::system_clock::now(); + pool.auto_evict(); std::vector> outputs; std::vector aliases; weaks weak_outputs; diff --git a/aten/src/ATen/Logger.h b/aten/src/ATen/Logger.h index 190c2f28433..7e3c01a9af9 100644 --- a/aten/src/ATen/Logger.h +++ b/aten/src/ATen/Logger.h @@ -39,6 +39,7 @@ constexpr bool log_json = true; const std::string INSTRUCTION = "INSTRUCTION"; const std::string ANNOTATION = "ANNOTATION"; const std::string RELEASE = "RELEASE"; +const std::string PIN = "PIN"; const std::string TIME = "TIME"; const std::string ARGS = "ARGS"; const std::string MEMORY = "MEMORY"; @@ -139,14 +140,25 @@ void DTRLogMutate(const std::string& name, } } -void DTRLogRelease(const std::string& counter_name) { +void DTRLogRelease(const std::string& name) { if (log_json) { json j; j[INSTRUCTION] = RELEASE; - j[NAME] = counter_name; + j[NAME] = name; + DTRLogger::logger().log(j.dump()); + } else { + DTRLogger::logger().log(RELEASE + ": " + name); + } +} + +void DTRLogPin(const std::string& name) { + if (log_json) { + json j; + j[INSTRUCTION] = PIN; + j[NAME] = name; DTRLogger::logger().log(j.dump()); } else { - DTRLogger::logger().log(RELEASE + ": " + counter_name); + DTRLogger::logger().log(RELEASE + ": " + name); } } diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index f8f7fb421b7..d309288291a 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -22,15 +22,6 @@ namespace c10 { C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); -evict_func_t evict_func = nullptr; -void set_evict_func(evict_func_t ef) { - evict_func = ef; -} - -evict_func_t get_evict_func() { - return evict_func; -} - namespace cuda { namespace CUDACachingAllocator { @@ -210,19 +201,7 @@ class THCCachingAllocator { // Thus, do not call a public method from another public method. /** allocates a block which is safe to use from the provided stream */ - // Technically speaking, it is still allocating more memory then it should, - // But it doesn't do anything with it until more memory are found, so it is morally ok - no experimental result will be changed. - // TODO: fix it and make it to be always below limit, so ppl can set the limit to be GPU max memory and it will still work. void malloc(void** devPtr, size_t size, cudaStream_t stream) - { - malloc_inner(devPtr, size, stream); - auto evict_func = get_evict_func(); - if (evict_func) { - (*evict_func)(); - } - } - - void malloc_inner(void** devPtr, size_t size, cudaStream_t stream) { std::lock_guard lock(mutex); int device; diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index e119f259b96..fecd12179e2 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -27,10 +27,6 @@ C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); #define REGISTER_FREE_MEMORY_CALLBACK(name, ...) \ C10_REGISTER_CLASS(FreeCudaMemoryCallbacksRegistry, name, __VA_ARGS__); -typedef void (*evict_func_t)(); -C10_CUDA_API void set_evict_func(evict_func_t); -C10_CUDA_API evict_func_t get_evict_func(); - namespace cuda { // TODO: Turn this into an honest to goodness class. I briefly attempted to do From 8ad4f93d66146e2b7e7ff830d1a69b3eec563b1c Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Tue, 2 Jun 2020 01:04:28 -0700 Subject: [PATCH 34/42] save --- aten/src/ATen/CheckpointTensorImpl.cpp | 13 +++++++++++++ aten/src/ATen/native/native_functions.yaml | 5 +++++ 2 files changed, 18 insertions(+) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index bd1c9166913..ca720dc801f 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -164,6 +164,7 @@ void CheckpointPool::evict() { CheckpointPool::CheckpointPool() { } bool use_log = false; +long compute_time_ = 0; namespace native { @@ -246,6 +247,14 @@ void set_memory_budget(long budget) { pool.has_memory_budget = true; } +void reset_compute_time() { + compute_time_ = 0; +} + +long compute_time() { + return compute_time_; +} + } [[inline]] @@ -323,8 +332,11 @@ void Rematerializer::remat() { s->pool->lock(); } Tensors ts = uncheckpoint(inputs); + time_t pre = std::chrono::system_clock::now(); auto ret = func(ts); + time_t post = std::chrono::system_clock::now(); pool.auto_evict(); + compute_time_ += (post - pre).count(); TORCH_CHECK(ret.size() == outputs.size()); for (size_t i = 0; i < outputs.size(); ++i) { if (auto output_cell = outputs[i].lock()) { @@ -532,6 +544,7 @@ MakeRawResult make_raw(const rematerialize_function_t& remat_f, auto raw_outputs = remat_f(raw_inputs); time_t post = std::chrono::system_clock::now(); pool.auto_evict(); + compute_time_ += (post - pre).count(); std::vector> outputs; std::vector aliases; weaks weak_outputs; diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index d9eb137705d..59459cafef5 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -43,6 +43,11 @@ - func: unset_memory_budget() -> () variants: function +- func: reset_compute_time() -> () + variants: function + +- func: compute_time() -> int + variants: function # Temporary type cast operators. These are needed to trace type-casts now since # Type's are not supported in the IR. Instead, we call down to these # specialized operators for each datatype. From 74dd61cfa4b90bdbdfd0219c6f5d74ca46a635f2 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 3 Jun 2020 10:07:58 -0700 Subject: [PATCH 35/42] save --- aten/src/ATen/CheckpointTensorImpl.cpp | 50 +++++++++++++++++++------- aten/src/ATen/CheckpointTensorImpl.h | 41 +++++++++++++-------- 2 files changed, 64 insertions(+), 27 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index ca720dc801f..48be99fe850 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -94,6 +94,28 @@ struct PerfStats { static PerfStats STATS = PerfStats(); +size_t memory_sum = 0; +size_t memory_max = 0; +size_t memory_count = 0; + +void reset_memory_stat() { + memory_sum = 0; + memory_max = 0; + memory_count = 0; +} + +inline size_t memory(const Tensor& t) { + if (! t.has_storage()) { + return 0; + } + auto& storage = t.storage(); + size_t res = storage.numel() * storage.itemsize(); + memory_sum += res; + memory_max = std::max(memory_max, res); + memory_count += 1; + return res; +} + Timer::~Timer() { Time now = Clock::now(); Duration elapsed = now - start; @@ -102,6 +124,11 @@ Timer::~Timer() { } CheckpointPool pool; +void CheckpointPool::add(const intrusive_ptr& p) { + if (p->memory > 0 && (memory_count == 0 || p->memory >= 0.01 * double(memory_sum/memory_count))) { + aps.push_back(weak_intrusive_ptr(p)); + } +} long current_memory() { STATS.track("current_memory"); @@ -137,7 +164,11 @@ void CheckpointPool::evict() { auto ap_strong = aps[i].lock(); if (!ap_strong.defined()) { cannot_evict(); - } else { + } + else if (ap_strong->ecn) { + cannot_evict(); + } + else { if (ap_strong->evictable()) { double score = ap_strong->score(current_time); if (score < evict_score) { @@ -297,9 +328,9 @@ void AliasPool::evict() { merge(merge_cpi, ecn, necn); } auto b4 = current_memory(); - // cTORCH_CHECK(memory > 0); - // TORCH_CHECK(lock_count == 0); - // TORCH_CHECK(!is_evicted); + TORCH_CHECK(memory > 0); + TORCH_CHECK(lock_count == 0); + TORCH_CHECK(!is_evicted); is_evicted = true; for (const weak& w : tensors) { if (auto cell = w.lock()) { @@ -320,9 +351,7 @@ double AliasPool::score(time_t current_time) { } void External::release_resources() { - if (value->remat) { - value->evict(); - } + value->pool->release_external(); value.reset(); } @@ -428,7 +457,6 @@ void swap(int* a, int* b) std::set AliasPool::neighbor_ecn() { STATS.track("AliasPool::neighbor_ecn"); std::set ptr_set; - STATS.track("AliasPool::neighbor_ecn(process nodes)"); int size = neighbors.size(); for (size_t i = 0; i < size;) { if (auto cptc = neighbors[i].lock()) { @@ -457,7 +485,7 @@ void AliasPool::set_not_evicted(const intrusive_ptr& self) { update_t(ecn, CheckpointInfo(cpi.compute_cost - head_remat->compute_cost)); ecn.reset(); } - pool.aps.push_back(weak_intrusive_ptr(self)); + pool.add(self); } } @@ -556,9 +584,7 @@ MakeRawResult make_raw(const rematerialize_function_t& remat_f, if (alias == -1) { auto m = memory(t); alias_pool = intrusive_ptr::make(Unsafe(), remat, m); - if (m > 0) { - pool.aps.push_back(weak_intrusive_ptr(alias_pool)); - } + pool.add(alias_pool); } else { alias_pool = inputs[alias]->pool; diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index 742c4efeecd..ab8ee368480 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -21,8 +21,8 @@ #include #include -#define likely(x) __builtin_expect(!!(x), 1) -#define unlikely(x) __builtin_expect(!!(x), 0) +#define likely(x) __builtin_expect(!!(x), 1) +#define unlikely(x) __builtin_expect(!!(x), 0) // System Description: // Every Tensor is managed by a CheckpointTensor, @@ -120,13 +120,7 @@ intrusive_ptr> merge(const std::function struct RefCell final : intrusive_ptr_target { @@ -229,6 +223,7 @@ struct AliasPool : intrusive_ptr_target { // so, it is crucial that we dont try to evict those tensors - doing so will not evict anything. // lock_count count how many time a tensor is referenced by get. size_t lock_count = 0; + size_t external_count = 0; void lock() { ++lock_count; } @@ -253,6 +248,19 @@ struct AliasPool : intrusive_ptr_target { ecn_ptr ecn; double score(time_t current_time); void evict(); + void register_external() { + ++external_count; + } + void release_external() { + --external_count; + if (external_count == 0) { + if (lock_count > 0) {return;} + TORCH_CHECK(lock_count == 0); + if (memory > 0 && (!ecn) && head_remat) { + evict(); + } + } + } // if it was evicted, refresh it. otherwise do nothing. // have to check so, because when we rematerialize a single tensor in an aliaspool, // we will set it to non-evicted, and when we rematerialize it's tensor they will also reset this. @@ -335,16 +343,18 @@ struct CheckpointTensorCell : intrusive_ptr_target { // We keep this invariant by only allowing CheckpointTensorImpl to make new External, // When new CheckpointTensorImpl is constructed. struct External : intrusive_ptr_target { - External(const strong& value) : value(value) { } + External(const strong& value) : value(value) { + value->pool->register_external(); + } External(const Tensor& value) : - value(intrusive_ptr::make(value, - intrusive_ptr::make(Unsafe(), - intrusive_ptr(), - memory(value)))) { } + External(strong::make(value, + intrusive_ptr::make(Unsafe(), + intrusive_ptr(), + memory(value)))) { } External(const Tensor& value, const intrusive_ptr& pool, const intrusive_ptr& remat) : - value(intrusive_ptr::make(value, pool, remat)) { } + External(strong::make(value, pool, remat)) { } strong value; void release_resources() override; }; @@ -425,6 +435,7 @@ struct CheckpointPool { void evict(); void auto_evict(); void clear_checkpointpool(); + void add(const intrusive_ptr&); CheckpointPool(); }; From f8c8f325352243d094a2c4ad4b0b96ba462cec87 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 3 Jun 2020 14:25:51 -0700 Subject: [PATCH 36/42] mike try this --- aten/src/ATen/CheckpointTensorImpl.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 48be99fe850..0fa29a39e49 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -4,6 +4,8 @@ #include #include +#include +#include namespace at { @@ -156,6 +158,9 @@ void CheckpointPool::evict() { aps[i] = aps[aps.size() - 1]; aps.pop_back(); }; + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> distrib(1, 10 * std::max(1, static_cast(std::sqrt(aps.size())))); for (size_t i = 0; i < aps.size();) { auto cannot_evict = [&]() { shrinked = true; @@ -176,7 +181,7 @@ void CheckpointPool::evict() { evict_idx = i; } } - ++i; + i += distrib(gen); } } if (evict_idx == -1) { From f2f095c2b255b41fd3f3af598f37a17c28f172db Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 3 Jun 2020 14:43:29 -0700 Subject: [PATCH 37/42] mike try this --- aten/src/ATen/CheckpointTensorImpl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 0fa29a39e49..bd2e415966c 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -160,7 +160,7 @@ void CheckpointPool::evict() { }; std::random_device rd; std::mt19937 gen(rd()); - std::uniform_int_distribution<> distrib(1, 10 * std::max(1, static_cast(std::sqrt(aps.size())))); + std::uniform_int_distribution<> distrib(1, 1 * std::max(1, static_cast(std::sqrt(aps.size())))); for (size_t i = 0; i < aps.size();) { auto cannot_evict = [&]() { shrinked = true; From 7c040705b8acb0cd1cc45ae484752d88ef2927db Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 3 Jun 2020 15:17:21 -0700 Subject: [PATCH 38/42] save --- aten/src/ATen/CheckpointTensorImpl.cpp | 5 ++--- aten/src/ATen/CheckpointTensorImpl.h | 3 +++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index bd2e415966c..a5209a094ba 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -158,8 +158,6 @@ void CheckpointPool::evict() { aps[i] = aps[aps.size() - 1]; aps.pop_back(); }; - std::random_device rd; - std::mt19937 gen(rd()); std::uniform_int_distribution<> distrib(1, 1 * std::max(1, static_cast(std::sqrt(aps.size())))); for (size_t i = 0; i < aps.size();) { auto cannot_evict = [&]() { @@ -559,6 +557,7 @@ struct MakeRawResult { void add_neighbor(const strong& l, const strong& r) { l->pool->neighbors.push_back(weak(r)); + r->pool->neighbors.push_back(weak(l)); } // remat take a single vector of tensors, @@ -607,7 +606,7 @@ MakeRawResult make_raw(const rematerialize_function_t& remat_f, remat->outputs = weak_outputs; for (size_t i = 0; i < inputs.size(); ++i) { for (size_t j = 0; j < outputs.size(); ++j) { - if (is_alias(raw_inputs[i], raw_outputs[j])) { + if (!is_alias(raw_inputs[i], raw_outputs[j])) { add_neighbor(inputs[i], outputs[j]->value); } } diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index ab8ee368480..9c7a3ef457d 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -430,6 +431,8 @@ struct CheckpointTensorImpl : TensorImpl { struct CheckpointPool { std::vector> aps; std::vector> exts; + std::random_device rd; + std::mt19937 gen = std::mt19937(rd()); bool has_memory_budget = false; long memory_budget; void evict(); From 3cb823bb04d034001cbc07bf21ae25170254fb88 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 3 Jun 2020 15:27:13 -0700 Subject: [PATCH 39/42] undo system optimization for now --- aten/src/ATen/CheckpointTensorImpl.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index a5209a094ba..4a0b0981a3c 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -127,7 +127,7 @@ Timer::~Timer() { CheckpointPool pool; void CheckpointPool::add(const intrusive_ptr& p) { - if (p->memory > 0 && (memory_count == 0 || p->memory >= 0.01 * double(memory_sum/memory_count))) { + if (p->memory > 0 && (true || memory_count == 0 || p->memory >= 0.01 * double(memory_sum/memory_count))) { aps.push_back(weak_intrusive_ptr(p)); } } @@ -179,7 +179,8 @@ void CheckpointPool::evict() { evict_idx = i; } } - i += distrib(gen); + i += 1; + //i += distrib(gen); } } if (evict_idx == -1) { From f6312c93d4ed66af4c2b4408e1d8752c8e0772a4 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 3 Jun 2020 17:20:40 -0700 Subject: [PATCH 40/42] save --- aten/src/ATen/CheckpointTensorImpl.cpp | 10 +++++----- aten/src/ATen/CheckpointTensorImpl.h | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 4a0b0981a3c..a96a97f1047 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -34,7 +34,7 @@ struct PerfStats { PerfStats() : start(Clock::now()), calls(0), timers() {} - Timer track(std::string name) { + /*Timer track(std::string name) { if (stats) { auto it = this->calls.find(name); if (it != this->calls.end()) { @@ -46,7 +46,8 @@ struct PerfStats { return Timer(name, Clock::now()); } return Timer(); - } + }*/ + void track(const char*) { } ~PerfStats() { if (!stats) { return; } @@ -127,7 +128,7 @@ Timer::~Timer() { CheckpointPool pool; void CheckpointPool::add(const intrusive_ptr& p) { - if (p->memory > 0 && (true || memory_count == 0 || p->memory >= 0.01 * double(memory_sum/memory_count))) { + if (p->memory > 0 && (memory_count == 0 || p->memory >= 0.01 * double(memory_sum/memory_count))) { aps.push_back(weak_intrusive_ptr(p)); } } @@ -179,8 +180,7 @@ void CheckpointPool::evict() { evict_idx = i; } } - i += 1; - //i += distrib(gen); + i += distrib(gen); } } if (evict_idx == -1) { diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index 9c7a3ef457d..fc34b22038b 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -24,6 +24,7 @@ #define likely(x) __builtin_expect(!!(x), 1) #define unlikely(x) __builtin_expect(!!(x), 0) +#define TORCH_CHECK(x) // profile mode // System Description: // Every Tensor is managed by a CheckpointTensor, From ee8f2a963c711da2ec7df7411b14ba01ef418a8c Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 3 Jun 2020 17:29:29 -0700 Subject: [PATCH 41/42] save --- aten/src/ATen/CheckpointTensorImpl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index fc34b22038b..1d328574951 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -24,7 +24,7 @@ #define likely(x) __builtin_expect(!!(x), 1) #define unlikely(x) __builtin_expect(!!(x), 0) -#define TORCH_CHECK(x) // profile mode +#define TORCH_CHECK(a, ...) // profile mode // System Description: // Every Tensor is managed by a CheckpointTensor, From 454027703740fc06092e1c16cb2a363fd0695caf Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Tue, 23 Jun 2020 14:48:14 -0700 Subject: [PATCH 42/42] save --- aten/src/ATen/CheckpointTensorImpl.cpp | 86 ++++---------------------- aten/src/ATen/CheckpointTensorImpl.h | 4 +- 2 files changed, 13 insertions(+), 77 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index a96a97f1047..72ad9e63fda 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -153,13 +153,14 @@ void CheckpointPool::evict() { TORCH_CHECK(aps.size() > 0); bool shrinked = false; int evict_idx = -1; - double evict_score = INFINITY; + double evict_cost = INFINITY; time_t current_time = std::chrono::system_clock::now(); auto remove_from_aps = [&](size_t i) { aps[i] = aps[aps.size() - 1]; aps.pop_back(); }; std::uniform_int_distribution<> distrib(1, 1 * std::max(1, static_cast(std::sqrt(aps.size())))); + // sampling a random independent subset of all evictable tensors to find the cheapest tensor to evict. for (size_t i = 0; i < aps.size();) { auto cannot_evict = [&]() { shrinked = true; @@ -174,9 +175,9 @@ void CheckpointPool::evict() { } else { if (ap_strong->evictable()) { - double score = ap_strong->score(current_time); - if (score < evict_score) { - evict_score = score; + double cost = ap_strong->cost(current_time); + if (cost < evict_cost) { + evict_cost = cost; evict_idx = i; } } @@ -216,14 +217,14 @@ Tensor checkpoint(const Tensor& t) { Tensor uncheckpoint(const Tensor& t) { STATS.track("uncheckpoint"); auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); - // CHECK(cpti != nullptr); + TORCH_CHECK(cpti != nullptr); return cpti->ref->value->value->get(); } void pin(const Tensor& t) { STATS.track("pin"); auto* cpti = dynamic_cast(t.unsafeGetTensorImpl()); - // CHECK(cpti != nullptr); + TORCH_CHECK(cpti != nullptr); cpti->ref->value->value->pin(); } @@ -302,7 +303,7 @@ Tensors uncheckpoint(const strongs& inputs) { Tensors ret; ret.reserve(inputs.size()); for (const strong& input : inputs) { - // Jared: inlined manually + // inlined manually ret.push_back(input->get()); } return ret; @@ -341,17 +342,17 @@ void AliasPool::evict() { cell->evict(); } } - // TORCH_CHECK(current_memory() < b4); + TORCH_CHECK(current_memory() < b4); // somehow it is still evicting unevictable stuff. } -double AliasPool::score(time_t current_time) { +double AliasPool::cost(time_t current_time) { auto cpi = head_remat->get_cpi(); auto ecns = neighbor_ecn(); for (const auto& necn : ecns) { cpi = merge_cpi(cpi, get_t(necn)); } - return cpi.score(memory, (current_time - last_used_time).count()); + return cpi.cost(memory, (current_time - last_used_time).count()); } void External::release_resources() { @@ -393,71 +394,6 @@ CheckpointInfo Rematerializer::get_cpi() { return CheckpointInfo(ecn ? duration_t(0) : compute_cost); } -// A utility function to swap two elements -void swap(int* a, int* b) -{ - int t = *a; - *a = *b; - *b = t; -} - -// /* This function is same in both iterative and recursive*/ -// template -// int partition(std::vector& arr, int l, int h) -// { -// int x = arr[h]; -// int i = (l - 1); - -// for (int j = l; j <= h - 1; j++) { -// if (arr[j] <= x) { -// i++; -// swap(&arr[i], &arr[j]); -// } -// } -// swap(&arr[i + 1], &arr[h]); -// return (i + 1); -// } - -// template -// void qsort_dedup(std::vector& arr, int l, int h) -// { -// // Create an auxiliary stack -// std::vector stack; -// stack.reserve(h - l + 1); - -// // initialize top of stack -// int top = -1; - -// // push initial values of l and h to stack -// stack[++top] = l; -// stack[++top] = h; - -// // Keep popping from stack while is not empty -// while (top >= 0) { -// // Pop h and l -// h = stack[top--]; -// l = stack[top--]; - -// // Set pivot element at its correct position -// // in sorted array -// int p = partition(arr, l, h); - -// // If there are elements on left side of pivot, -// // then push left side to stack -// if (p - 1 > l) { -// stack[++top] = l; -// stack[++top] = p - 1; -// } - -// // If there are elements on right side of pivot, -// // then push right side to stack -// if (p + 1 < h) { -// stack[++top] = p + 1; -// stack[++top] = h; -// } -// } -// } - std::set AliasPool::neighbor_ecn() { STATS.track("AliasPool::neighbor_ecn"); std::set ptr_set; diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index 1d328574951..3e0f3e4d203 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -155,7 +155,7 @@ using duration_t = std::chrono::system_clock::duration; struct CheckpointInfo { duration_t compute_cost; // @ZACH: Floating Point instability? - double score(size_t memory, size_t staleness) const { + double cost(size_t memory, size_t staleness) const { TORCH_CHECK(memory > 0); TORCH_CHECK(staleness > 0); return compute_cost.count() / static_cast(memory * staleness); @@ -248,7 +248,7 @@ struct AliasPool : intrusive_ptr_target { } // if it is evicted, then hold the evicted tensor group. ecn_ptr ecn; - double score(time_t current_time); + double cost(time_t current_time); void evict(); void register_external() { ++external_count;