Skip to content

Commit

Permalink
Merge pull request #38 from uwsampl/refactor
Browse files Browse the repository at this point in the history
refactor - remove stitch
  • Loading branch information
MarisaKirisame committed May 24, 2020
2 parents 0b770f3 + ec0fe99 commit 0118183
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 123 deletions.
172 changes: 79 additions & 93 deletions aten/src/ATen/CheckpointTensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,58 +4,6 @@

namespace at {

void AliasPool::evict() {
  // Eviction is only legal when no in-flight rematerialization has pinned
  // this pool (lock_count is incremented around remat calls).
  TORCH_CHECK(lock_count == 0);
  // Evict every live tensor cell that shares this pool's storage; weak
  // references that have already expired are skipped.
  for (const weak& weak_cell : tensors) {
    auto cell = weak_cell.lock();
    if (cell) {
      cell->evict();
    }
  }
}

// Called by intrusive_ptr when the last owner goes away: evict the
// underlying checkpointed value first, then drop our reference to it.
void External::release_resources() {
  value->evict();
  value.reset();
}

// Interleave checkpointed input values with raw constant tensors. Each
// constant carries the output position it must occupy; every other slot is
// filled with the checkpointed values in their original order.
Tensors stitch(const strongs& input_values,
               const std::vector<std::tuple<Tensor, size_t>>& constants) {
  Tensors merged;
  const size_t total = input_values.size() + constants.size();
  size_t value_idx = 0;
  size_t const_idx = 0;
  while (merged.size() != total) {
    const bool place_constant =
        const_idx < constants.size() &&
        std::get<1>(constants[const_idx]) == merged.size();
    if (place_constant) {
      Tensor t = std::get<0>(constants[const_idx]);
      // Constants must be raw tensors, never checkpointed ones.
      TORCH_CHECK(!t.key_set().has(DispatchKey::CheckpointTensorId));
      merged.push_back(t);
      ++const_idx;
    } else {
      CHECK(value_idx < input_values.size());
      merged.push_back(input_values[value_idx]->get());
      ++value_idx;
    }
  }
  return merged;
}

void Rematerializer::remat() {
// TODO: refactor using RAII for exception safety.
for (const strong& s : input_values) {
++(s->pool->lock_count);
}
Tensors ts = stitch(input_values, constants);
auto ret = func(ts);
TORCH_CHECK(ret.size() == outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
if (auto output_cell = outputs[i].lock()) {
output_cell->fill(ret[i]);
}
}
for (const strong& s : input_values) {
--(s->pool->lock_count);
}
}

namespace native {

Tensor checkpoint(const Tensor& t) {
Expand Down Expand Up @@ -87,6 +35,10 @@ bool is_checkpoint(const Tensor& t) {
return cpti != nullptr;
}

// Idempotent checkpoint: tensors that are already checkpointed pass
// through unchanged; everything else gets wrapped.
Tensor try_checkpoint(const Tensor& t) {
  if (is_checkpoint(t)) {
    return t;
  }
  return checkpoint(t);
}

// Redirect DTR logging to a fresh output file derived from `str`.
void new_log(std::string str) {
  DTRLogger::logger().out = std::ofstream(DTRLogger::logger().get_filename(str));
}
Expand All @@ -108,6 +60,58 @@ void clear_checkpointpool() {

}

// Materialize the raw tensor behind a checkpointed cell (rematerializing
// it on demand if it was evicted — behavior of get() per its callers here;
// TODO confirm against CheckpointTensorCell::get).
Tensor uncheckpoint(const strong& input) {
  return input->get();
}

// Materialize the raw tensors behind a list of checkpointed cells,
// preserving order.
Tensors uncheckpoint(const strongs& inputs) {
  Tensors ret;
  // Size is known up front; reserve to avoid intermediate reallocations.
  ret.reserve(inputs.size());
  for (const strong& input : inputs) {
    ret.push_back(uncheckpoint(input));
  }
  return ret;
}
// NOTE: the stray ';' that followed the closing brace was removed.

// Checkpoint every tensor in `inputs` that is not already checkpointed,
// preserving order. Delegates per-element work to at::native::try_checkpoint.
Tensors try_checkpoint(const Tensors& inputs) {
  Tensors ret;
  // Size is known up front; reserve to avoid intermediate reallocations.
  ret.reserve(inputs.size());
  for (const Tensor& input : inputs) {
    ret.push_back(at::native::try_checkpoint(input));
  }
  return ret;
}

// Evict every live tensor cell sharing this pool's storage. Only legal
// when no rematerialization currently pins the pool (lock_count == 0).
void AliasPool::evict() {
  TORCH_CHECK(lock_count == 0);
  for (const weak& w : tensors) {
    // Skip cells whose weak reference has already expired.
    if (auto cell = w.lock()) {
      cell->evict();
    }
  }
}

// Called by intrusive_ptr when the last owner goes away: evict the
// underlying checkpointed value first, then drop our reference to it.
void External::release_resources() {
  value->evict();
  value.reset();
}

void Rematerializer::remat() {
// TODO: refactor using RAII for exception safety.
for (const strong& s : inputs) {
++(s->pool->lock_count);
}
Tensors ts = uncheckpoint(inputs);
auto ret = func(ts);
TORCH_CHECK(ret.size() == outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
if (auto output_cell = outputs[i].lock()) {
output_cell->fill(ret[i]);
}
}
for (const strong& s : inputs) {
--(s->pool->lock_count);
}
}

intrusive_ptr<TensorImpl> CheckpointTensorImpl::shallow_copy_and_detach(const VariableVersion& version_counter,
bool allow_tensor_metadata_change) const {
auto ret = intrusive_ptr<CheckpointTensorImpl>::make(ref);
Expand Down Expand Up @@ -153,29 +157,23 @@ struct MakeRawResult {
// however, we have to stitch the two vectors together to pass it in remat.
// the size_t in constants decide the location to stitch them in, while input_values fill in the rest.
MakeRawResult make_raw(const rematerialize_function_t& remat_f,
// We need this to assign alias pool.
// This is ugly as fuck but after refactoring we dont even need stitching anymore.
const Tensors& raw_input,
const strongs& input_values,
const std::vector<std::tuple<Tensor, size_t>>& constants) {
Tensors inputs = stitch(input_values, constants);
const strongs& inputs) {
Tensors raw_inputs = uncheckpoint(inputs);
time_t pre = std::chrono::system_clock::now();
auto outputs_raw = remat_f(inputs);
auto outputs_raw = remat_f(raw_inputs);
time_t post = std::chrono::system_clock::now();
std::vector<intrusive_ptr<External>> outputs;
std::vector<int> aliases;
weaks weak_outputs;
auto remat = intrusive_ptr<Rematerializer>::make(Unsafe(), input_values, constants, remat_f);
auto remat = intrusive_ptr<Rematerializer>::make(Unsafe(), remat_f, inputs);
for (const Tensor& t : outputs_raw) {
int alias = get_alias(inputs, t);
int alias = get_alias(raw_inputs, t);
intrusive_ptr<AliasPool> pool;
if (alias == -1) {
pool = intrusive_ptr<AliasPool>::make(Unsafe(), true, memory(t));
}
else if (auto* cpti = dynamic_cast<CheckpointTensorImpl*>(raw_input[alias].unsafeGetTensorImpl())) {
pool = cpti->ref->value->value->pool;
} else { // alias to an constant. unevictable.
pool = intrusive_ptr<AliasPool>::make(Unsafe(), false, memory(t));
else {
pool = inputs[alias]->pool;
}
auto e = intrusive_ptr<External>::make(t, pool, remat);
pool->tensors.push_back(weak(e->value));
Expand All @@ -194,30 +192,24 @@ std::string from_time(duration_t t) {
Tensors CheckpointTensorImpl::make(const std::string& name,
const rematerialize_function_t& remat,
const Tensors& inputs) {
Tensors checkpointed_inputs = try_checkpoint(inputs);
strongs input_values;
std::vector<std::tuple<Tensor, size_t>> constants;
std::vector<size_t> constant_idx;
std::vector<std::string> args;
for (const Tensor& t: inputs) {
if (auto* cpti = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl())) {
input_values.push_back(cpti->ref->value->value);
args.push_back(cpti->counter_name());
}
else {
size_t idx = input_values.size() + constants.size();
constants.push_back({t, idx});
constant_idx.push_back(idx);
}
for (const Tensor& t: checkpointed_inputs) {
auto* cpti = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl());
TORCH_CHECK(cpti);
input_values.push_back(cpti->ref->value->value);
args.push_back(cpti->counter_name());
}
std::vector<std::string> res;
auto ret = make_raw(remat, inputs, input_values, constants);
auto ret = make_raw(remat, input_values);
Tensors tensors;
for (const auto& t: ret.outputs) {
auto cp = Tensor(intrusive_ptr<CheckpointTensorImpl>::make(t));
tensors.push_back(cp);
res.push_back(get_cpti(cp)->counter_name());
}
DTRLogCall(res, name, args, constant_idx, from_time(ret.time));
DTRLogCall(res, name, args, from_time(ret.time));
for (size_t i = 0; i < tensors.size(); ++i) {
Tensor t = tensors[i];
auto cpti = get_cpti(t);
Expand All @@ -240,27 +232,21 @@ void CheckpointTensorImpl::mutate(const std::string& name,
mutate(new_input_values);
return new_input_values;
};
Tensors checkpointed_inputs = try_checkpoint(inputs);
strongs input_values;
std::vector<std::tuple<Tensor, size_t>> constants;
std::vector<size_t> constant_idx;
std::vector<std::string> args;
for (const Tensor& t: inputs) {
if (auto* cpti = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl())) {
input_values.push_back(cpti->ref->value->value);
args.push_back(cpti->counter_name());
}
else {
size_t idx = input_values.size() + constants.size();
constants.push_back({t, idx});
constant_idx.push_back(idx);
}
for (const Tensor& t: checkpointed_inputs) {
auto* cpti = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl());
TORCH_CHECK(cpti);
input_values.push_back(cpti->ref->value->value);
args.push_back(cpti->counter_name());
}
auto ret = make_raw(remat, inputs, input_values, constants);
auto ret = make_raw(remat, input_values);
const auto& modified = ret.outputs;
for (size_t idx: mutate_idx) {
cell_from_tensor(inputs[idx])->value = modified[idx];
}
DTRLogMutate(name, args, constant_idx, mutate_idx, from_time(ret.time));
DTRLogMutate(name, args, mutate_idx, from_time(ret.time));
}

void CheckpointTensorImpl::release_resources() {
Expand Down
30 changes: 7 additions & 23 deletions aten/src/ATen/CheckpointTensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,34 +138,18 @@ struct AliasPool : intrusive_ptr_target {
// To build the cycle remat support a default constructor,
// And allow you to fill in the member later.
struct Rematerializer : intrusive_ptr_target {
// I am trying to represent a list of either checkpointedtensor or rawtensor.
// Is stitch the best way to do this?
// Maybe another approach is to use a list of tensor, and do dynamic downcasting?
// WHY DONT WE SIMPLY MAKE ALL CONSTANTS CHECKPOINTED TENSORS AS IS IN THE PREVIOUS VERSION?
// Oh, I remember, we are afraid that small tensors will get banished
// and make the big tensors unevictable.
// It sounds like a shitty reason - we can simply have an unbanishable flag
// as we do not rely on weak pointers anymore.
// And if we choose infinite staleness, then there is no need to deal with them specially -
// because they dont have rematerializer it will never get evicted.
// We should probably refactor and fix this, but it will take some nontrivial effort.
strongs input_values;
std::vector<std::tuple<Tensor, size_t>> constants;
weaks outputs;
rematerialize_function_t func;
strongs inputs;
weaks outputs;
Rematerializer(const Unsafe&,
const strongs& input_values,
const std::vector<std::tuple<Tensor, size_t>>& constants,
const rematerialize_function_t& func) :
input_values(input_values),
constants(constants),
func(func) {
const rematerialize_function_t& func,
const strongs& inputs) :
func(func), inputs(inputs) {
}
void release_resources() final {
input_values.clear();
constants.clear();
outputs.clear();
func = rematerialize_function_t();
inputs.clear();
outputs.clear();
}
void remat();
};
Expand Down
7 changes: 0 additions & 7 deletions aten/src/ATen/Logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ const std::string MEMORY = "MEMORY";
const std::string ALIAS = "ALIAS";
const std::string NAME = "NAME";
const std::string CONSTANT = "CONSTANT";
const std::string CONSTANTS = "CONSTANTS";

void DTRLogConstant(const std::string& name) {
if (log_json) {
Expand Down Expand Up @@ -108,20 +107,17 @@ void DTRLogCopy(const std::string& new_name, const std::string& old_name) {

void DTRLogMutate(const std::string& name,
const std::vector<std::string>& args,
const std::vector<size_t>& constants,
const std::vector<size_t>& mutate,
const std::string& time) {
if (log_json) {
json j;
j[INSTRUCTION] = "MUTATE";
j[NAME] = name;
j[ARGS] = args;
j[CONSTANTS] = constants;
j["MUTATE"] = mutate;
j[TIME] = time;
DTRLogger::logger().log(j.dump());
} else {
CHECK(constants.size() == 0); //TODO: implement.
std::string log = name;
log += "(";
for (const auto& s : args) {
Expand Down Expand Up @@ -157,19 +153,16 @@ void DTRLogRelease(const std::string& counter_name) {
void DTRLogCall(const std::vector<std::string>& res,
const std::string& name,
const std::vector<std::string>& args,
const std::vector<size_t>& constants,
const std::string& time) {
if (log_json) {
json j;
j[INSTRUCTION] = "CALL";
j[NAME] = name;
j["RESULT"] = res;
j[ARGS] = args;
j[CONSTANTS] = constants;
j[TIME] = time;
DTRLogger::logger().log(j.dump());
} else {
CHECK(constants.size() == 0); //TODO: implement.
std::string arg = name + "(";
for (const auto& s : args) {
arg += s;
Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
- func: checkpoint(Tensor self) -> Tensor
variants: method

- func: try_checkpoint(Tensor self) -> Tensor
variants: method

- func: is_checkpoint(Tensor self) -> bool
variants: method

Expand Down

0 comments on commit 0118183

Please sign in to comment.