diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp
index 184dcfcb813..be787fefa4c 100644
--- a/aten/src/ATen/native/Checkpoint.cpp
+++ b/aten/src/ATen/native/Checkpoint.cpp
@@ -1249,6 +1249,33 @@ checkpoint_layer_norm_backward(const Tensor& grad_out, const Tensor& input, cons
   return {ret[0], ret[1], ret[2]};
 }
 
+// Checkpoint (DTR) wrapper for topk: wraps the op in a rematerialization
+// closure so the (values, indices) pair can be recomputed from `self` if the
+// checkpointed results are later evicted.
+std::tuple<Tensor, Tensor>
+checkpoint_topk(const Tensor& self, int64_t k, int64_t dim, bool largest, bool sorted) {
+  // `k`/`dim` are int64_t to match the `int` schema in native_functions.yaml.
+  rematerialize_function_t rt =
+    [=](const Tensors& vec) -> Tensors {
+      auto ret = at::topk(vec.at(0), k, dim, largest, sorted);
+      return {std::get<0>(ret), std::get<1>(ret)};
+    };
+  auto ret = CheckpointTensorImpl::make("topk", rt, {self});
+  return {ret[0], ret[1]};
+}
+
+// Out-variant of topk under checkpointing: registers a mutation of `values`
+// and `indices` (argument positions {0, 1}) and returns them, matching the
+// topk.values(..., Tensor(a!) values, Tensor(b!) indices) schema.
+std::tuple<Tensor&, Tensor&>
+checkpoint_topk_values(Tensor& values, Tensor& indices, const Tensor& self, int64_t k, int64_t dim, bool largest, bool sorted) {
+  mutate_function_t mt =
+    [=](const Tensors& vec) {
+      Tensor values_ = vec.at(0);
+      Tensor indices_ = vec.at(1);
+      at::topk_out(values_, indices_, vec.at(2), k, dim, largest, sorted);
+    };
+  CheckpointTensorImpl::mutate("topk_values", mt, {values, indices, self}, {0, 1});
+  return {values, indices};
+}
+
 bool checkpoint_equal(const Tensor& self, const Tensor& other) {
   // there can't possibly be a reason to rematerialize
   // a single bool so we'll just compute it now
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 89013ef1a94..90032615686 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5343,6 +5343,7 @@
   dispatch:
     CPU: topk_out_cpu
     CUDA: legacy::cuda::_th_topk_out
+    Checkpoint: checkpoint_topk_values
 
 - func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
   variants: method, function
@@ -5350,6 +5351,7 @@
     CPU: topk
     CUDA: topk
     QuantizedCPU: quantized_topk_cpu
+    Checkpoint: checkpoint_topk
 
 - func: all(Tensor self) -> Tensor
   use_c10_dispatcher: full