diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index 4a00913fe1e..cfa7ecfaf68 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -746,6 +746,23 @@ Tensor checkpoint_repeat(const at::Tensor& a, c10::ArrayRef b) { return CheckpointTensorImpl::make("repeat", rt, {a})[0]; } +Tensor checkpoint_mean(const Tensor& self, c10::optional dtype) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::native::mean_cpu_gpu(vec[0], dtype)}; + }; + return CheckpointTensorImpl::make("mean", rt, {self})[0]; +} + +Tensor checkpoint_mean(const Tensor& self, IntArrayRef dim, bool keepdim, c10::optional dtype) { + std::vector dim_ = dim.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::native::mean_cpu_gpu(vec[0], dim_, keepdim, dtype)}; + }; + return CheckpointTensorImpl::make("mean.dim", rt, {self})[0]; +} + Tensor checkpoint__cat(c10::ArrayRef a, long b) { rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 04ff2643319..4dcea3d4e6e 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1848,6 +1848,7 @@ CPU: mean_cpu_gpu CUDA: mean_cpu_gpu QuantizedCPU: quantized_mean_cpu + Checkpoint: checkpoint_mean - func: mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method @@ -1856,6 +1857,7 @@ CPU: mean_cpu_gpu CUDA: mean_cpu_gpu QuantizedCPU: quantized_mean_cpu + Checkpoint: checkpoint_mean - func: mean.out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True