diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp
index d7f1e1ba9dd..1a593e9e252 100644
--- a/aten/src/ATen/CheckpointTensorImpl.cpp
+++ b/aten/src/ATen/CheckpointTensorImpl.cpp
@@ -4,58 +4,6 @@
 namespace at {
 
-void AliasPool::evict() {
-  TORCH_CHECK(lock_count == 0);
-  for (const weak& w : tensors) {
-    if (auto cell = w.lock()) {
-      cell->evict();
-    }
-  }
-}
-
-void External::release_resources() {
-  value->evict();
-  value.reset();
-}
-
-Tensors stitch(const strongs& input_values,
-               const std::vector<std::tuple<Tensor, size_t>>& constants) {
-  Tensors input;
-  size_t i = 0, j = 0;
-  while (i != input_values.size() || j != constants.size()) {
-    if (j < constants.size() && std::get<1>(constants[j]) == input.size()) {
-      Tensor t = std::get<0>(constants[j]);
-      TORCH_CHECK(!t.key_set().has(DispatchKey::CheckpointTensorId));
-      input.push_back(t);
-      ++j;
-    }
-    else {
-      CHECK(i < input_values.size());
-      input.push_back(input_values[i]->get());
-      ++i;
-    }
-  }
-  return input;
-}
-
-void Rematerializer::remat() {
-  // TODO: refactor using RAII for exception safety.
-  for (const strong& s : input_values) {
-    ++(s->pool->lock_count);
-  }
-  Tensors ts = stitch(input_values, constants);
-  auto ret = func(ts);
-  TORCH_CHECK(ret.size() == outputs.size());
-  for (size_t i = 0; i < outputs.size(); ++i) {
-    if (auto output_cell = outputs[i].lock()) {
-      output_cell->fill(ret[i]);
-    }
-  }
-  for (const strong& s : input_values) {
-    --(s->pool->lock_count);
-  }
-}
-
 namespace native {
 
 Tensor checkpoint(const Tensor& t) {
@@ -87,6 +35,10 @@ bool is_checkpoint(const Tensor& t) {
   return cpti != nullptr;
 }
 
+Tensor try_checkpoint(const Tensor& t) {
+  return is_checkpoint(t) ? t : checkpoint(t);
+}
+
 void new_log(std::string str) {
   DTRLogger::logger().out = std::ofstream(DTRLogger::logger().get_filename(str));
 }
@@ -108,6 +60,58 @@ void clear_checkpointpool() {
 
 }
 
+Tensor uncheckpoint(const strong& input) {
+  return input->get();
+}
+
+Tensors uncheckpoint(const strongs& inputs) {
+  Tensors ret;
+  for (const strong& input : inputs) {
+    ret.push_back(uncheckpoint(input));
+  }
+  return ret;
+};
+
+Tensors try_checkpoint(const Tensors& inputs) {
+  Tensors ret;
+  for (const Tensor& input : inputs) {
+    ret.push_back(at::native::try_checkpoint(input));
+  }
+  return ret;
+}
+
+void AliasPool::evict() {
+  TORCH_CHECK(lock_count == 0);
+  for (const weak& w : tensors) {
+    if (auto cell = w.lock()) {
+      cell->evict();
+    }
+  }
+}
+
+void External::release_resources() {
+  value->evict();
+  value.reset();
+}
+
+void Rematerializer::remat() {
+  // TODO: refactor using RAII for exception safety.
+  for (const strong& s : inputs) {
+    ++(s->pool->lock_count);
+  }
+  Tensors ts = uncheckpoint(inputs);
+  auto ret = func(ts);
+  TORCH_CHECK(ret.size() == outputs.size());
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    if (auto output_cell = outputs[i].lock()) {
+      output_cell->fill(ret[i]);
+    }
+  }
+  for (const strong& s : inputs) {
+    --(s->pool->lock_count);
+  }
+}
+
 intrusive_ptr<TensorImpl> CheckpointTensorImpl::shallow_copy_and_detach(const VariableVersion& version_counter,
                                                                         bool allow_tensor_metadata_change) const {
   auto ret = intrusive_ptr<CheckpointTensorImpl>::make(ref);
@@ -153,29 +157,23 @@ struct MakeRawResult {
 // however, we have to stitch the two vectors together to pass it in remat.
 // the size_t in constants decide the location to stitch them in, while input_values fill in the rest.
 MakeRawResult make_raw(const rematerialize_function_t& remat_f,
-                       // We need this to assign alias pool.
-                       // This is ugly as fuck but after refactoring we dont even need stitching anymore.
-                       const Tensors& raw_input,
-                       const strongs& input_values,
-                       const std::vector<std::tuple<Tensor, size_t>>& constants) {
-  Tensors inputs = stitch(input_values, constants);
+                       const strongs& inputs) {
+  Tensors raw_inputs = uncheckpoint(inputs);
   time_t pre = std::chrono::system_clock::now();
-  auto outputs_raw = remat_f(inputs);
+  auto outputs_raw = remat_f(raw_inputs);
   time_t post = std::chrono::system_clock::now();
   std::vector<intrusive_ptr<External>> outputs;
   std::vector<int> aliases;
   weaks weak_outputs;
-  auto remat = intrusive_ptr<Rematerializer>::make(Unsafe(), input_values, constants, remat_f);
+  auto remat = intrusive_ptr<Rematerializer>::make(Unsafe(), remat_f, inputs);
   for (const Tensor& t : outputs_raw) {
-    int alias = get_alias(inputs, t);
+    int alias = get_alias(raw_inputs, t);
     intrusive_ptr<AliasPool> pool;
     if (alias == -1) {
       pool = intrusive_ptr<AliasPool>::make(Unsafe(), true, memory(t));
     }
-    else if (auto* cpti = dynamic_cast<CheckpointTensorImpl*>(raw_input[alias].unsafeGetTensorImpl())) {
-      pool = cpti->ref->value->value->pool;
-    } else { // alias to an constant. unevictable.
-      pool = intrusive_ptr<AliasPool>::make(Unsafe(), false, memory(t));
+    else {
+      pool = inputs[alias]->pool;
     }
     auto e = intrusive_ptr<External>::make(t, pool, remat);
     pool->tensors.push_back(weak(e->value));
@@ -194,30 +192,24 @@ std::string from_time(duration_t t) {
 Tensors CheckpointTensorImpl::make(const std::string& name,
                                    const rematerialize_function_t& remat,
                                    const Tensors& inputs) {
+  Tensors checkpointed_inputs = try_checkpoint(inputs);
   strongs input_values;
-  std::vector<std::tuple<Tensor, size_t>> constants;
-  std::vector<size_t> constant_idx;
   std::vector<std::string> args;
-  for (const Tensor& t: inputs) {
-    if (auto* cpti = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl())) {
-      input_values.push_back(cpti->ref->value->value);
-      args.push_back(cpti->counter_name());
-    }
-    else {
-      size_t idx = input_values.size() + constants.size();
-      constants.push_back({t, idx});
-      constant_idx.push_back(idx);
-    }
+  for (const Tensor& t: checkpointed_inputs) {
+    auto* cpti = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl());
+    TORCH_CHECK(cpti);
+    input_values.push_back(cpti->ref->value->value);
+    args.push_back(cpti->counter_name());
   }
   std::vector<std::string> res;
-  auto ret = make_raw(remat, inputs, input_values, constants);
+  auto ret = make_raw(remat, input_values);
   Tensors tensors;
   for (const auto& t: ret.outputs) {
     auto cp = Tensor(intrusive_ptr<CheckpointTensorImpl>::make(t));
     tensors.push_back(cp);
     res.push_back(get_cpti(cp)->counter_name());
   }
-  DTRLogCall(res, name, args, constant_idx, from_time(ret.time));
+  DTRLogCall(res, name, args, from_time(ret.time));
   for (size_t i = 0; i < tensors.size(); ++i) {
     Tensor t = tensors[i];
     auto cpti = get_cpti(t);
@@ -240,27 +232,21 @@ void CheckpointTensorImpl::mutate(const std::string& name,
     mutate(new_input_values);
     return new_input_values;
   };
+  Tensors checkpointed_inputs = try_checkpoint(inputs);
   strongs input_values;
-  std::vector<std::tuple<Tensor, size_t>> constants;
-  std::vector<size_t> constant_idx;
   std::vector<std::string> args;
-  for (const Tensor& t: inputs) {
-    if (auto* cpti = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl())) {
-      input_values.push_back(cpti->ref->value->value);
-      args.push_back(cpti->counter_name());
-    }
-    else {
-      size_t idx = input_values.size() + constants.size();
-      constants.push_back({t, idx});
-      constant_idx.push_back(idx);
-    }
+  for (const Tensor& t: checkpointed_inputs) {
+    auto* cpti = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl());
+    TORCH_CHECK(cpti);
+    input_values.push_back(cpti->ref->value->value);
+    args.push_back(cpti->counter_name());
   }
-  auto ret = make_raw(remat, inputs, input_values, constants);
+  auto ret = make_raw(remat, input_values);
   const auto& modified = ret.outputs;
   for (size_t idx: mutate_idx) {
     cell_from_tensor(inputs[idx])->value = modified[idx];
   }
-  DTRLogMutate(name, args, constant_idx, mutate_idx, from_time(ret.time));
+  DTRLogMutate(name, args, mutate_idx, from_time(ret.time));
 }
 
 void CheckpointTensorImpl::release_resources() {
diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h
index 6166f4a44e9..7aa9530f124 100644
--- a/aten/src/ATen/CheckpointTensorImpl.h
+++ b/aten/src/ATen/CheckpointTensorImpl.h
@@ -138,34 +138,18 @@ struct AliasPool : intrusive_ptr_target {
 // To build the cycle remat support a default constructor,
 // And allow you to fill in the member later.
 struct Rematerializer : intrusive_ptr_target {
-  // I am trying to represent a list of either checkpointedtensor or rawtensor.
-  // Is stitch the best way to do this?
-  // Maybe another approach is to use a list of tensor, and do dynamic downcasting?
-  // WHY DONT WE SIMPLY MAKE ALL CONSTANTS CHECKPOINTED TENSORS AS IS IN THE PREVIOUS VERSION?
-  // Oh, I remember, we are afraid that small tensors will get banished
-  // and make the big tensors unevictable.
-  // It sounds like a shitty reason - we can simply have an unbanishable flag
-  // as we do not rely on weak pointers anymore.
-  // And if we choose infinite staleness, then there is no need to deal with them specially -
-  // because they dont have rematerializer it will never get evicted.
-  // We should probably refactor and fix this, but it will take some nontrivial effort.
-  strongs input_values;
-  std::vector<std::tuple<Tensor, size_t>> constants;
-  weaks outputs;
   rematerialize_function_t func;
+  strongs inputs;
+  weaks outputs;
   Rematerializer(const Unsafe&,
-                 const strongs& input_values,
-                 const std::vector<std::tuple<Tensor, size_t>>& constants,
-                 const rematerialize_function_t& func) :
-    input_values(input_values),
-    constants(constants),
-    func(func) {
+                 const rematerialize_function_t& func,
+                 const strongs& inputs) :
+    func(func), inputs(inputs) {
   }
   void release_resources() final {
-    input_values.clear();
-    constants.clear();
-    outputs.clear();
     func = rematerialize_function_t();
+    inputs.clear();
+    outputs.clear();
   }
   void remat();
 };
diff --git a/aten/src/ATen/Logger.h b/aten/src/ATen/Logger.h
index 0868fb38050..190c2f28433 100644
--- a/aten/src/ATen/Logger.h
+++ b/aten/src/ATen/Logger.h
@@ -45,7 +45,6 @@ const std::string MEMORY = "MEMORY";
 const std::string ALIAS = "ALIAS";
 const std::string NAME = "NAME";
 const std::string CONSTANT = "CONSTANT";
-const std::string CONSTANTS = "CONSTANTS";
 
 void DTRLogConstant(const std::string& name) {
   if (log_json) {
@@ -108,7 +107,6 @@ void DTRLogCopy(const std::string& new_name, const std::string& old_name) {
 
 void DTRLogMutate(const std::string& name,
                   const std::vector<std::string>& args,
-                  const std::vector<size_t>& constants,
                   const std::vector<size_t>& mutate,
                   const std::string& time) {
   if (log_json) {
@@ -116,12 +114,10 @@ void DTRLogMutate(const std::string& name,
     j[INSTRUCTION] = "MUTATE";
     j[NAME] = name;
     j[ARGS] = args;
-    j[CONSTANTS] = constants;
     j["MUTATE"] = mutate;
     j[TIME] = time;
     DTRLogger::logger().log(j.dump());
   } else {
-    CHECK(constants.size() == 0); //TODO: implement.
     std::string log = name;
     log += "(";
     for (const auto& s : args) {
@@ -157,7 +153,6 @@ void DTRLogRelease(const std::string& counter_name) {
 void DTRLogCall(const std::vector<std::string>& res,
                 const std::string& name,
                 const std::vector<std::string>& args,
-                const std::vector<size_t>& constants,
                 const std::string& time) {
   if (log_json) {
     json j;
@@ -165,11 +160,9 @@ void DTRLogCall(const std::vector<std::string>& res,
     j[NAME] = name;
     j["RESULT"] = res;
     j[ARGS] = args;
-    j[CONSTANTS] = constants;
     j[TIME] = time;
     DTRLogger::logger().log(j.dump());
   } else {
-    CHECK(constants.size() == 0); //TODO: implement.
     std::string arg = name + "(";
     for (const auto& s : args) {
       arg += s;
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 667250ce86e..87cd7cb7165 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3,6 +3,9 @@
 - func: checkpoint(Tensor self) -> Tensor
   variants: method
 
+- func: try_checkpoint(Tensor self) -> Tensor
+  variants: method
+
 - func: is_checkpoint(Tensor self) -> bool
   variants: method
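
Usage sketch (illustrative note, not part of the patch): try_checkpoint is what lets CheckpointTensorImpl::make and ::mutate assume every input is already a CheckpointTensorImpl, which is why the stitch/constants machinery can be deleted. The snippet below assumes the Tensor method bindings generated from the native_functions.yaml entries above; the tensor values and names are made up for illustration.

    #include <ATen/ATen.h>

    void try_checkpoint_example() {
      at::Tensor a = at::ones({2, 2});      // plain, non-checkpoint tensor
      at::Tensor ca = a.try_checkpoint();   // wrapped into a checkpoint tensor
      at::Tensor cb = ca.try_checkpoint();  // already checkpointed; returned unchanged
      TORCH_CHECK(ca.is_checkpoint() && cb.is_checkpoint());
    }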