From 30639521225c232382cc57ffd52d408fd3077e55 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Thu, 9 Apr 2020 14:55:52 -0700 Subject: [PATCH 1/2] save --- aten/src/ATen/CheckpointTensorImpl.cpp | 73 ++++++++++++++++++++------ aten/src/ATen/CheckpointTensorImpl.h | 9 ---- 2 files changed, 56 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 67e5748b559..619e50c2b81 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -35,6 +35,7 @@ const std::string ARGS = "ARGS"; const std::string MEMORY = "MEMORY"; const std::string NAME = "NAME"; const std::string CONSTANT = "CONSTANT"; +const std::string CONSTANTS = "CONSTANTS"; void DTRLogConstant(const std::string& name) { if (log_json) { @@ -107,11 +108,21 @@ Tensor checkpoint_raw(const Tensor& t) { } std::tuple make_raw(const rematerialize_function_t& remat, - const strongs& input_values) { + const strongs& input_values, + const std::vector>& constants) { std::vector input; - for (const strong& s: input_values) { - CHECK(!s->t.key_set().has(DispatchKey::CheckpointTensorId)); - input.push_back(s->t); + size_t i = 0, j = 0; + while (i != input_values.size() || j != constants.size()) { + if (j < constants.size() && std::get<1>(constants[j]) == input.size()) { + input.push_back(std::get<0>(constants[j])); + ++j; + } + else { + CHECK(i < input_values.size()); + CHECK(!input_values[i]->t.key_set().has(DispatchKey::CheckpointTensorId)); + input.push_back(input_values[i]->t); + ++i; + } } time_t pre = std::chrono::system_clock::now(); auto output = remat(input); @@ -123,16 +134,22 @@ std::string from_time(duration_t t) { return std::to_string(std::chrono::nanoseconds(t).count()); } -void DTRLogCall(const std::vector& res, const std::string& name, const std::vector& args, const std::string& time) { +void DTRLogCall(const std::vector& res, + const std::string& name, + const std::vector& args, + const std::vector& constants, + const std::string& time) { if (log_json) { json j; j[INSTRUCTION] = "CALL"; j[NAME] = name; j["RESULT"] = res; j[ARGS] = args; + j[CONSTANTS] = constants; j[TIME] = time; DTRLog(j.dump()); } else { + CHECK(constants.size() == 0); //TODO: implement. std::string arg = name + "("; for (const auto& s : arg) { arg += s; @@ -156,21 +173,29 @@ Tensors CheckpointTensorImpl::make(const std::string& name, const rematerialize_function_t& remat, const Tensors& input) { strongs input_values; + std::vector> constants; + std::vector constant_idx; std::vector args; for (const Tensor& t: input) { - auto ft = from_tensor(t); - input_values.push_back(std::get<0>(ft)); - args.push_back(std::get<1>(ft)); + if (auto* cpt = dynamic_cast(t.unsafeGetTensorImpl())) { + input_values.push_back(cpt->ref->value); + args.push_back(cpt->counter_name()); + } + else { + size_t idx = input_values.size() + constants.size(); + constants.push_back({t, idx}); + constant_idx.push_back(idx); + } } std::vector res; - auto ret = make_raw(remat, input_values); + auto ret = make_raw(remat, input_values, constants); Tensors tensors; for (const Tensor& t: std::get<0>(ret)) { auto cp = checkpoint_raw(t); tensors.push_back(cp); res.push_back(get_cpti(cp)->counter_name()); } - DTRLogCall(res, name, args, from_time(std::get<1>(ret))); + DTRLogCall(res, name, args, constant_idx, from_time(std::get<1>(ret))); for (const Tensor& t: tensors) { auto cpti = get_cpti(t); DTRLogMemory(cpti->counter_name(), cpti->ref->value->memory()); @@ -178,16 +203,22 @@ Tensors CheckpointTensorImpl::make(const std::string& name, return tensors; } -void DTRLogMutate(const std::string& name, const std::vector& args, const std::vector& mutate, const std::string& time) { +void DTRLogMutate(const std::string& name, + const std::vector& args, + const std::vector& constants, + const std::vector& mutate, + const std::string& time) { if (log_json) { json j; j[INSTRUCTION] = "MUTATE"; j[NAME] = name; j[ARGS] = args; + j[CONSTANTS] = constants; j["MUTATE"] = mutate; j[TIME] = time; DTRLog(j.dump()); } else { + CHECK(constants.size() == 0); //TODO: implement. std::string log = name; log += "("; for (const auto& s : args) { @@ -222,18 +253,26 @@ void CheckpointTensorImpl::mutate(const std::string& name, return new_input_values; }; strongs input_values; + std::vector> constants; + std::vector constant_idx; std::vector args; - for (const Tensor& t : inputs) { - auto ft = from_tensor(t); - args.push_back(std::get<1>(ft)); - input_values.push_back(std::get<0>(ft)); + for (const Tensor& t: inputs) { + if (auto* cpt = dynamic_cast(t.unsafeGetTensorImpl())) { + input_values.push_back(cpt->ref->value); + args.push_back(cpt->counter_name()); + } + else { + size_t idx = input_values.size() + constants.size(); + constants.push_back({t, idx}); + constant_idx.push_back(idx); + } } - auto ret = make_raw(remat, input_values); + auto ret = make_raw(remat, input_values, constants); const auto& modified = std::get<0>(ret); for (size_t idx: mutate_idx) { cell_from_tensor(inputs[idx])->value = intrusive_ptr::make(modified[idx]); } - DTRLogMutate(name, args, mutate_idx, from_time(std::get<1>(ret))); + DTRLogMutate(name, args, constant_idx, mutate_idx, from_time(std::get<1>(ret))); } void DTRLogRelease(const std::string& counter_name) { diff --git a/aten/src/ATen/CheckpointTensorImpl.h b/aten/src/ATen/CheckpointTensorImpl.h index e4c32a973df..259bcfe2f23 100644 --- a/aten/src/ATen/CheckpointTensorImpl.h +++ b/aten/src/ATen/CheckpointTensorImpl.h @@ -112,15 +112,6 @@ inline CheckpointTensorImpl* get_cpti(const Tensor& t) { return cpti; } -inline std::tuple from_tensor(const Tensor& t) { - auto* cpt = dynamic_cast(t.unsafeGetTensorImpl()); - if(cpt != nullptr) { - return {cpt->ref->value, cpt->counter_name()}; - } else { - return from_tensor(native::checkpoint(t)); - } -} - inline Tensor get(const strong& s) { return s->t; } From cf5ccd5c4c6471d49266b004e33a6d4ce706d500 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Tue, 14 Apr 2020 22:44:33 -0700 Subject: [PATCH 2/2] fix comment --- aten/src/ATen/CheckpointTensorImpl.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/aten/src/ATen/CheckpointTensorImpl.cpp b/aten/src/ATen/CheckpointTensorImpl.cpp index 619e50c2b81..03891985460 100644 --- a/aten/src/ATen/CheckpointTensorImpl.cpp +++ b/aten/src/ATen/CheckpointTensorImpl.cpp @@ -107,6 +107,11 @@ Tensor checkpoint_raw(const Tensor& t) { return Tensor(intrusive_ptr::make(t.detach())); } +// remat take a single vector of tensors, +// while there are two vector, one storing nonconstants and one storing constants. +// the constants are small and they will not be considered for eviction. +// however, we have to stitch the two vectors together to pass it in remat. +// the size_t in constants decide the location to stitch them in, while input_values fill in the rest. std::tuple make_raw(const rematerialize_function_t& remat, const strongs& input_values, const std::vector>& constants) {