Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

refactor - remove stitch #38

Merged
merged 2 commits into from
May 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 79 additions & 93 deletions aten/src/ATen/CheckpointTensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,58 +4,6 @@

namespace at {

// Evict every still-live tensor cell that shares this alias pool's storage.
// Precondition: nothing is pinning the pool (lock_count must be zero).
void AliasPool::evict() {
  TORCH_CHECK(lock_count == 0);
  for (const weak& entry : tensors) {
    // Entries are weak references; cells that have already been
    // destroyed simply fail to lock and are skipped.
    auto cell = entry.lock();
    if (cell) {
      cell->evict();
    }
  }
}

// Release this handle's resources: evict the owned cell's tensor data,
// then drop the strong reference so the cell itself can be freed.
void External::release_resources() {
  value->evict();
  value.reset();
}

// Merge checkpointed input values and raw constant tensors into a single
// argument list. Each constant carries the position (second tuple member)
// at which it must appear; input_values fill every remaining slot, in order.
Tensors stitch(const strongs& input_values,
               const std::vector<std::tuple<Tensor, size_t>>& constants) {
  Tensors merged;
  size_t value_pos = 0, const_pos = 0;
  while (value_pos != input_values.size() || const_pos != constants.size()) {
    // A constant wins the slot exactly when its recorded index matches
    // the next position to be filled.
    const bool place_constant =
        const_pos < constants.size() &&
        std::get<1>(constants[const_pos]) == merged.size();
    if (place_constant) {
      const Tensor& c = std::get<0>(constants[const_pos]);
      // Constants must be raw (non-checkpointed) tensors.
      TORCH_CHECK(!c.key_set().has(DispatchKey::CheckpointTensorId));
      merged.push_back(c);
      ++const_pos;
    } else {
      CHECK(value_pos < input_values.size());
      merged.push_back(input_values[value_pos]->get());
      ++value_pos;
    }
  }
  return merged;
}

void Rematerializer::remat() {
// TODO: refactor using RAII for exception safety.
for (const strong& s : input_values) {
++(s->pool->lock_count);
}
Tensors ts = stitch(input_values, constants);
auto ret = func(ts);
TORCH_CHECK(ret.size() == outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
if (auto output_cell = outputs[i].lock()) {
output_cell->fill(ret[i]);
}
}
for (const strong& s : input_values) {
--(s->pool->lock_count);
}
}

namespace native {

Tensor checkpoint(const Tensor& t) {
Expand Down Expand Up @@ -87,6 +35,10 @@ bool is_checkpoint(const Tensor& t) {
return cpti != nullptr;
}

// Return t unchanged if it is already a checkpoint tensor;
// otherwise wrap it in a new checkpoint tensor.
Tensor try_checkpoint(const Tensor& t) {
  if (is_checkpoint(t)) {
    return t;
  }
  return checkpoint(t);
}

// Redirect DTR logging to a fresh log file whose name is derived from str.
void new_log(std::string str) {
  auto&& logger = DTRLogger::logger();
  logger.out = std::ofstream(logger.get_filename(str));
}
Expand All @@ -108,6 +60,58 @@ void clear_checkpointpool() {

}

// Materialize a single checkpointed cell into its underlying raw tensor.
// NOTE(review): get() presumably rematerializes the value if it has been
// evicted — confirm against CheckpointTensorCell.
Tensor uncheckpoint(const strong& input) {
  return input->get();
}

// Materialize a whole list of checkpointed cells into raw tensors,
// preserving order. (Also drops the stray ';' that followed the
// original function body.)
Tensors uncheckpoint(const strongs& inputs) {
  Tensors ret;
  ret.reserve(inputs.size());  // size known up front; avoid reallocations
  for (const strong& input : inputs) {
    ret.push_back(uncheckpoint(input));
  }
  return ret;
}

// Wrap every input in a checkpoint tensor (a no-op for tensors that are
// already checkpointed), so downstream code can assume a uniform type.
Tensors try_checkpoint(const Tensors& inputs) {
  Tensors ret;
  ret.reserve(inputs.size());  // size known up front; avoid reallocations
  for (const Tensor& input : inputs) {
    ret.push_back(at::native::try_checkpoint(input));
  }
  return ret;
}

// Evict every live tensor cell sharing this alias pool's storage.
// Precondition: no in-flight rematerialization pins this pool
// (lock_count == 0).
void AliasPool::evict() {
  TORCH_CHECK(lock_count == 0);
  for (const weak& w : tensors) {
    // Weak references: cells already destroyed fail to lock and are skipped.
    if (auto cell = w.lock()) {
      cell->evict();
    }
  }
}

// Release this handle's resources: evict the owned cell's tensor data,
// then drop the strong reference so the cell itself can be freed.
void External::release_resources() {
  value->evict();
  value.reset();
}

void Rematerializer::remat() {
// TODO: refactor using RAII for exception safety.
for (const strong& s : inputs) {
++(s->pool->lock_count);
}
Tensors ts = uncheckpoint(inputs);
auto ret = func(ts);
TORCH_CHECK(ret.size() == outputs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
if (auto output_cell = outputs[i].lock()) {
output_cell->fill(ret[i]);
}
}
for (const strong& s : inputs) {
--(s->pool->lock_count);
}
}

intrusive_ptr<TensorImpl> CheckpointTensorImpl::shallow_copy_and_detach(const VariableVersion& version_counter,
bool allow_tensor_metadata_change) const {
auto ret = intrusive_ptr<CheckpointTensorImpl>::make(ref);
Expand Down Expand Up @@ -153,29 +157,23 @@ struct MakeRawResult {
// however, we have to stitch the two vectors together to pass it in remat.
// the size_t in constants decide the location to stitch them in, while input_values fill in the rest.
MakeRawResult make_raw(const rematerialize_function_t& remat_f,
// We need this to assign alias pool.
// This is ugly as fuck but after refactoring we dont even need stitching anymore.
const Tensors& raw_input,
const strongs& input_values,
const std::vector<std::tuple<Tensor, size_t>>& constants) {
Tensors inputs = stitch(input_values, constants);
const strongs& inputs) {
Tensors raw_inputs = uncheckpoint(inputs);
time_t pre = std::chrono::system_clock::now();
auto outputs_raw = remat_f(inputs);
auto outputs_raw = remat_f(raw_inputs);
time_t post = std::chrono::system_clock::now();
std::vector<intrusive_ptr<External>> outputs;
std::vector<int> aliases;
weaks weak_outputs;
auto remat = intrusive_ptr<Rematerializer>::make(Unsafe(), input_values, constants, remat_f);
auto remat = intrusive_ptr<Rematerializer>::make(Unsafe(), remat_f, inputs);
for (const Tensor& t : outputs_raw) {
int alias = get_alias(inputs, t);
int alias = get_alias(raw_inputs, t);
intrusive_ptr<AliasPool> pool;
if (alias == -1) {
pool = intrusive_ptr<AliasPool>::make(Unsafe(), true, memory(t));
}
else if (auto* cpti = dynamic_cast<CheckpointTensorImpl*>(raw_input[alias].unsafeGetTensorImpl())) {
pool = cpti->ref->value->value->pool;
} else { // alias to an constant. unevictable.
pool = intrusive_ptr<AliasPool>::make(Unsafe(), false, memory(t));
else {
pool = inputs[alias]->pool;
}
auto e = intrusive_ptr<External>::make(t, pool, remat);
pool->tensors.push_back(weak(e->value));
Expand All @@ -194,30 +192,24 @@ std::string from_time(duration_t t) {
Tensors CheckpointTensorImpl::make(const std::string& name,
const rematerialize_function_t& remat,
const Tensors& inputs) {
Tensors checkpointed_inputs = try_checkpoint(inputs);
strongs input_values;
std::vector<std::tuple<Tensor, size_t>> constants;
std::vector<size_t> constant_idx;
std::vector<std::string> args;
for (const Tensor& t: inputs) {
if (auto* cpti = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl())) {
input_values.push_back(cpti->ref->value->value);
args.push_back(cpti->counter_name());
}
else {
size_t idx = input_values.size() + constants.size();
constants.push_back({t, idx});
constant_idx.push_back(idx);
}
for (const Tensor& t: checkpointed_inputs) {
auto* cpti = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl());
TORCH_CHECK(cpti);
input_values.push_back(cpti->ref->value->value);
args.push_back(cpti->counter_name());
}
std::vector<std::string> res;
auto ret = make_raw(remat, inputs, input_values, constants);
auto ret = make_raw(remat, input_values);
Tensors tensors;
for (const auto& t: ret.outputs) {
auto cp = Tensor(intrusive_ptr<CheckpointTensorImpl>::make(t));
tensors.push_back(cp);
res.push_back(get_cpti(cp)->counter_name());
}
DTRLogCall(res, name, args, constant_idx, from_time(ret.time));
DTRLogCall(res, name, args, from_time(ret.time));
for (size_t i = 0; i < tensors.size(); ++i) {
Tensor t = tensors[i];
auto cpti = get_cpti(t);
Expand All @@ -240,27 +232,21 @@ void CheckpointTensorImpl::mutate(const std::string& name,
mutate(new_input_values);
return new_input_values;
};
Tensors checkpointed_inputs = try_checkpoint(inputs);
strongs input_values;
std::vector<std::tuple<Tensor, size_t>> constants;
std::vector<size_t> constant_idx;
std::vector<std::string> args;
for (const Tensor& t: inputs) {
if (auto* cpti = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl())) {
input_values.push_back(cpti->ref->value->value);
args.push_back(cpti->counter_name());
}
else {
size_t idx = input_values.size() + constants.size();
constants.push_back({t, idx});
constant_idx.push_back(idx);
}
for (const Tensor& t: checkpointed_inputs) {
auto* cpti = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl());
TORCH_CHECK(cpti);
input_values.push_back(cpti->ref->value->value);
args.push_back(cpti->counter_name());
}
auto ret = make_raw(remat, inputs, input_values, constants);
auto ret = make_raw(remat, input_values);
const auto& modified = ret.outputs;
for (size_t idx: mutate_idx) {
cell_from_tensor(inputs[idx])->value = modified[idx];
}
DTRLogMutate(name, args, constant_idx, mutate_idx, from_time(ret.time));
DTRLogMutate(name, args, mutate_idx, from_time(ret.time));
}

void CheckpointTensorImpl::release_resources() {
Expand Down
30 changes: 7 additions & 23 deletions aten/src/ATen/CheckpointTensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,34 +138,18 @@ struct AliasPool : intrusive_ptr_target {
// To build the cycle remat support a default constructor,
// And allow you to fill in the member later.
struct Rematerializer : intrusive_ptr_target {
// I am trying to represent a list of either checkpointedtensor or rawtensor.
// Is stitch the best way to do this?
// Maybe another approach is to use a list of tensor, and do dynamic downcasting?
// WHY DONT WE SIMPLY MAKE ALL CONSTANTS CHECKPOINTED TENSORS AS IS IN THE PREVIOUS VERSION?
// Oh, I remember, we are afraid that small tensors will get banished
// and make the big tensors unevictable.
// It sounds like a shitty reason - we can simply have an unbanishable flag
// as we do not rely on weak pointers anymore.
// And if we choose infinite staleness, then there is no need to deal with them specially -
// because they dont have rematerializer it will never get evicted.
// We should probably refactor and fix this, but it will take some nontrivial effort.
strongs input_values;
std::vector<std::tuple<Tensor, size_t>> constants;
weaks outputs;
rematerialize_function_t func;
strongs inputs;
weaks outputs;
Rematerializer(const Unsafe&,
const strongs& input_values,
const std::vector<std::tuple<Tensor, size_t>>& constants,
const rematerialize_function_t& func) :
input_values(input_values),
constants(constants),
func(func) {
const rematerialize_function_t& func,
const strongs& inputs) :
func(func), inputs(inputs) {
}
void release_resources() final {
input_values.clear();
constants.clear();
outputs.clear();
func = rematerialize_function_t();
inputs.clear();
outputs.clear();
}
void remat();
};
Expand Down
7 changes: 0 additions & 7 deletions aten/src/ATen/Logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ const std::string MEMORY = "MEMORY";
const std::string ALIAS = "ALIAS";
const std::string NAME = "NAME";
const std::string CONSTANT = "CONSTANT";
const std::string CONSTANTS = "CONSTANTS";

void DTRLogConstant(const std::string& name) {
if (log_json) {
Expand Down Expand Up @@ -108,20 +107,17 @@ void DTRLogCopy(const std::string& new_name, const std::string& old_name) {

void DTRLogMutate(const std::string& name,
const std::vector<std::string>& args,
const std::vector<size_t>& constants,
const std::vector<size_t>& mutate,
const std::string& time) {
if (log_json) {
json j;
j[INSTRUCTION] = "MUTATE";
j[NAME] = name;
j[ARGS] = args;
j[CONSTANTS] = constants;
j["MUTATE"] = mutate;
j[TIME] = time;
DTRLogger::logger().log(j.dump());
} else {
CHECK(constants.size() == 0); //TODO: implement.
std::string log = name;
log += "(";
for (const auto& s : args) {
Expand Down Expand Up @@ -157,19 +153,16 @@ void DTRLogRelease(const std::string& counter_name) {
void DTRLogCall(const std::vector<std::string>& res,
const std::string& name,
const std::vector<std::string>& args,
const std::vector<size_t>& constants,
const std::string& time) {
if (log_json) {
json j;
j[INSTRUCTION] = "CALL";
j[NAME] = name;
j["RESULT"] = res;
j[ARGS] = args;
j[CONSTANTS] = constants;
j[TIME] = time;
DTRLogger::logger().log(j.dump());
} else {
CHECK(constants.size() == 0); //TODO: implement.
std::string arg = name + "(";
for (const auto& s : args) {
arg += s;
Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
- func: checkpoint(Tensor self) -> Tensor
variants: method

- func: try_checkpoint(Tensor self) -> Tensor
variants: method

- func: is_checkpoint(Tensor self) -> bool
variants: method

Expand Down