Skip to content

Commit

Permalink
Fix logger (special handling for small constants) (#19)
Browse files Browse the repository at this point in the history
* save

* fix comment
  • Loading branch information
MarisaKirisame committed Apr 15, 2020
1 parent 6e6c4e1 commit dab0aa8
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 26 deletions.
78 changes: 61 additions & 17 deletions aten/src/ATen/CheckpointTensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const std::string ARGS = "ARGS";
const std::string MEMORY = "MEMORY";
const std::string NAME = "NAME";
const std::string CONSTANT = "CONSTANT";
const std::string CONSTANTS = "CONSTANTS";

void DTRLogConstant(const std::string& name) {
if (log_json) {
Expand Down Expand Up @@ -106,12 +107,27 @@ Tensor checkpoint_raw(const Tensor& t) {
return Tensor(intrusive_ptr<CheckpointTensorImpl>::make(t.detach()));
}

// remat take a single vector of tensors,
// while there are two vector, one storing nonconstants and one storing constants.
// the constants are small and they will not be considered for eviction.
// however, we have to stitch the two vectors together to pass it in remat.
// the size_t in constants decide the location to stitch them in, while input_values fill in the rest.
std::tuple<Tensors, duration_t> make_raw(const rematerialize_function_t& remat,
const strongs& input_values) {
const strongs& input_values,
const std::vector<std::tuple<Tensor, size_t>>& constants) {
std::vector<Tensor> input;
for (const strong& s: input_values) {
CHECK(!s->t.key_set().has(DispatchKey::CheckpointTensorId));
input.push_back(s->t);
size_t i = 0, j = 0;
while (i != input_values.size() || j != constants.size()) {
if (j < constants.size() && std::get<1>(constants[j]) == input.size()) {
input.push_back(std::get<0>(constants[j]));
++j;
}
else {
CHECK(i < input_values.size());
CHECK(!input_values[i]->t.key_set().has(DispatchKey::CheckpointTensorId));
input.push_back(input_values[i]->t);
++i;
}
}
time_t pre = std::chrono::system_clock::now();
auto output = remat(input);
Expand All @@ -123,16 +139,22 @@ std::string from_time(duration_t t) {
return std::to_string(std::chrono::nanoseconds(t).count());
}

void DTRLogCall(const std::vector<std::string>& res, const std::string& name, const std::vector<std::string>& args, const std::string& time) {
void DTRLogCall(const std::vector<std::string>& res,
const std::string& name,
const std::vector<std::string>& args,
const std::vector<size_t>& constants,
const std::string& time) {
if (log_json) {
json j;
j[INSTRUCTION] = "CALL";
j[NAME] = name;
j["RESULT"] = res;
j[ARGS] = args;
j[CONSTANTS] = constants;
j[TIME] = time;
DTRLog(j.dump());
} else {
CHECK(constants.size() == 0); //TODO: implement.
std::string arg = name + "(";
for (const auto& s : arg) {
arg += s;
Expand All @@ -156,38 +178,52 @@ Tensors CheckpointTensorImpl::make(const std::string& name,
const rematerialize_function_t& remat,
const Tensors& input) {
strongs input_values;
std::vector<std::tuple<Tensor, size_t>> constants;
std::vector<size_t> constant_idx;
std::vector<std::string> args;
for (const Tensor& t: input) {
auto ft = from_tensor(t);
input_values.push_back(std::get<0>(ft));
args.push_back(std::get<1>(ft));
if (auto* cpt = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl())) {
input_values.push_back(cpt->ref->value);
args.push_back(cpt->counter_name());
}
else {
size_t idx = input_values.size() + constants.size();
constants.push_back({t, idx});
constant_idx.push_back(idx);
}
}
std::vector<std::string> res;
auto ret = make_raw(remat, input_values);
auto ret = make_raw(remat, input_values, constants);
Tensors tensors;
for (const Tensor& t: std::get<0>(ret)) {
auto cp = checkpoint_raw(t);
tensors.push_back(cp);
res.push_back(get_cpti(cp)->counter_name());
}
DTRLogCall(res, name, args, from_time(std::get<1>(ret)));
DTRLogCall(res, name, args, constant_idx, from_time(std::get<1>(ret)));
for (const Tensor& t: tensors) {
auto cpti = get_cpti(t);
DTRLogMemory(cpti->counter_name(), cpti->ref->value->memory());
}
return tensors;
}

void DTRLogMutate(const std::string& name, const std::vector<std::string>& args, const std::vector<size_t>& mutate, const std::string& time) {
void DTRLogMutate(const std::string& name,
const std::vector<std::string>& args,
const std::vector<size_t>& constants,
const std::vector<size_t>& mutate,
const std::string& time) {
if (log_json) {
json j;
j[INSTRUCTION] = "MUTATE";
j[NAME] = name;
j[ARGS] = args;
j[CONSTANTS] = constants;
j["MUTATE"] = mutate;
j[TIME] = time;
DTRLog(j.dump());
} else {
CHECK(constants.size() == 0); //TODO: implement.
std::string log = name;
log += "(";
for (const auto& s : args) {
Expand Down Expand Up @@ -222,18 +258,26 @@ void CheckpointTensorImpl::mutate(const std::string& name,
return new_input_values;
};
strongs input_values;
std::vector<std::tuple<Tensor, size_t>> constants;
std::vector<size_t> constant_idx;
std::vector<std::string> args;
for (const Tensor& t : inputs) {
auto ft = from_tensor(t);
args.push_back(std::get<1>(ft));
input_values.push_back(std::get<0>(ft));
for (const Tensor& t: inputs) {
if (auto* cpt = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl())) {
input_values.push_back(cpt->ref->value);
args.push_back(cpt->counter_name());
}
else {
size_t idx = input_values.size() + constants.size();
constants.push_back({t, idx});
constant_idx.push_back(idx);
}
}
auto ret = make_raw(remat, input_values);
auto ret = make_raw(remat, input_values, constants);
const auto& modified = std::get<0>(ret);
for (size_t idx: mutate_idx) {
cell_from_tensor(inputs[idx])->value = intrusive_ptr<CheckpointTensorCell>::make(modified[idx]);
}
DTRLogMutate(name, args, mutate_idx, from_time(std::get<1>(ret)));
DTRLogMutate(name, args, constant_idx, mutate_idx, from_time(std::get<1>(ret)));
}

void DTRLogRelease(const std::string& counter_name) {
Expand Down
9 changes: 0 additions & 9 deletions aten/src/ATen/CheckpointTensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,6 @@ inline CheckpointTensorImpl* get_cpti(const Tensor& t) {
return cpti;
}

inline std::tuple<strong, std::string> from_tensor(const Tensor& t) {
auto* cpt = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl());
if(cpt != nullptr) {
return {cpt->ref->value, cpt->counter_name()};
} else {
return from_tensor(native::checkpoint(t));
}
}

inline Tensor get(const strong& s) {
return s->t;
}
Expand Down

0 comments on commit dab0aa8

Please # to comment.