Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Fix logger #19

Merged
merged 2 commits into from
Apr 15, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 56 additions & 17 deletions aten/src/ATen/CheckpointTensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const std::string ARGS = "ARGS";
const std::string MEMORY = "MEMORY";
const std::string NAME = "NAME";
const std::string CONSTANT = "CONSTANT";
const std::string CONSTANTS = "CONSTANTS";

void DTRLogConstant(const std::string& name) {
if (log_json) {
Expand Down Expand Up @@ -107,11 +108,21 @@ Tensor checkpoint_raw(const Tensor& t) {
}

std::tuple<Tensors, duration_t> make_raw(const rematerialize_function_t& remat,
const strongs& input_values) {
const strongs& input_values,
const std::vector<std::tuple<Tensor, size_t>>& constants) {
slyubomirsky marked this conversation as resolved.
Show resolved Hide resolved
std::vector<Tensor> input;
for (const strong& s: input_values) {
CHECK(!s->t.key_set().has(DispatchKey::CheckpointTensorId));
input.push_back(s->t);
size_t i = 0, j = 0;
while (i != input_values.size() || j != constants.size()) {
if (j < constants.size() && std::get<1>(constants[j]) == input.size()) {
input.push_back(std::get<0>(constants[j]));
++j;
}
else {
CHECK(i < input_values.size());
CHECK(!input_values[i]->t.key_set().has(DispatchKey::CheckpointTensorId));
input.push_back(input_values[i]->t);
++i;
}
}
time_t pre = std::chrono::system_clock::now();
auto output = remat(input);
Expand All @@ -123,16 +134,22 @@ std::string from_time(duration_t t) {
return std::to_string(std::chrono::nanoseconds(t).count());
}

void DTRLogCall(const std::vector<std::string>& res, const std::string& name, const std::vector<std::string>& args, const std::string& time) {
void DTRLogCall(const std::vector<std::string>& res,
const std::string& name,
const std::vector<std::string>& args,
const std::vector<size_t>& constants,
const std::string& time) {
if (log_json) {
json j;
j[INSTRUCTION] = "CALL";
j[NAME] = name;
j["RESULT"] = res;
j[ARGS] = args;
j[CONSTANTS] = constants;
j[TIME] = time;
DTRLog(j.dump());
} else {
CHECK(constants.size() == 0); //TODO: implement.
std::string arg = name + "(";
for (const auto& s : arg) {
arg += s;
Expand All @@ -156,38 +173,52 @@ Tensors CheckpointTensorImpl::make(const std::string& name,
const rematerialize_function_t& remat,
const Tensors& input) {
strongs input_values;
std::vector<std::tuple<Tensor, size_t>> constants;
std::vector<size_t> constant_idx;
std::vector<std::string> args;
for (const Tensor& t: input) {
auto ft = from_tensor(t);
input_values.push_back(std::get<0>(ft));
args.push_back(std::get<1>(ft));
if (auto* cpt = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl())) {
input_values.push_back(cpt->ref->value);
args.push_back(cpt->counter_name());
}
else {
size_t idx = input_values.size() + constants.size();
constants.push_back({t, idx});
constant_idx.push_back(idx);
}
}
std::vector<std::string> res;
auto ret = make_raw(remat, input_values);
auto ret = make_raw(remat, input_values, constants);
Tensors tensors;
for (const Tensor& t: std::get<0>(ret)) {
auto cp = checkpoint_raw(t);
tensors.push_back(cp);
res.push_back(get_cpti(cp)->counter_name());
}
DTRLogCall(res, name, args, from_time(std::get<1>(ret)));
DTRLogCall(res, name, args, constant_idx, from_time(std::get<1>(ret)));
for (const Tensor& t: tensors) {
auto cpti = get_cpti(t);
DTRLogMemory(cpti->counter_name(), cpti->ref->value->memory());
}
return tensors;
}

void DTRLogMutate(const std::string& name, const std::vector<std::string>& args, const std::vector<size_t>& mutate, const std::string& time) {
void DTRLogMutate(const std::string& name,
const std::vector<std::string>& args,
const std::vector<size_t>& constants,
const std::vector<size_t>& mutate,
const std::string& time) {
if (log_json) {
json j;
j[INSTRUCTION] = "MUTATE";
j[NAME] = name;
j[ARGS] = args;
j[CONSTANTS] = constants;
j["MUTATE"] = mutate;
j[TIME] = time;
DTRLog(j.dump());
} else {
CHECK(constants.size() == 0); //TODO: implement.
std::string log = name;
log += "(";
for (const auto& s : args) {
Expand Down Expand Up @@ -222,18 +253,26 @@ void CheckpointTensorImpl::mutate(const std::string& name,
return new_input_values;
};
strongs input_values;
std::vector<std::tuple<Tensor, size_t>> constants;
std::vector<size_t> constant_idx;
std::vector<std::string> args;
for (const Tensor& t : inputs) {
auto ft = from_tensor(t);
args.push_back(std::get<1>(ft));
input_values.push_back(std::get<0>(ft));
for (const Tensor& t: inputs) {
if (auto* cpt = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl())) {
input_values.push_back(cpt->ref->value);
args.push_back(cpt->counter_name());
}
else {
size_t idx = input_values.size() + constants.size();
constants.push_back({t, idx});
constant_idx.push_back(idx);
}
}
auto ret = make_raw(remat, input_values);
auto ret = make_raw(remat, input_values, constants);
const auto& modified = std::get<0>(ret);
for (size_t idx: mutate_idx) {
cell_from_tensor(inputs[idx])->value = intrusive_ptr<CheckpointTensorCell>::make(modified[idx]);
}
DTRLogMutate(name, args, mutate_idx, from_time(std::get<1>(ret)));
DTRLogMutate(name, args, constant_idx, mutate_idx, from_time(std::get<1>(ret)));
}

void DTRLogRelease(const std::string& counter_name) {
Expand Down
9 changes: 0 additions & 9 deletions aten/src/ATen/CheckpointTensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,6 @@ inline CheckpointTensorImpl* get_cpti(const Tensor& t) {
return cpti;
}

inline std::tuple<strong, std::string> from_tensor(const Tensor& t) {
auto* cpt = dynamic_cast<CheckpointTensorImpl*>(t.unsafeGetTensorImpl());
if(cpt != nullptr) {
return {cpt->ref->value, cpt->counter_name()};
} else {
return from_tensor(native::checkpoint(t));
}
}

inline Tensor get(const strong& s) {
return s->t;
}
Expand Down