From 3008ba5c223dcaf9bb997a976a6721baacaf09e5 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 17 Apr 2020 20:09:38 -0700 Subject: [PATCH] Overloads for mean and mean.dim (needed for densenet) --- aten/src/ATen/native/Checkpoint.cpp | 17 +++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 2 ++ 2 files changed, 19 insertions(+) diff --git a/aten/src/ATen/native/Checkpoint.cpp b/aten/src/ATen/native/Checkpoint.cpp index 4a00913fe1e..cfa7ecfaf68 100644 --- a/aten/src/ATen/native/Checkpoint.cpp +++ b/aten/src/ATen/native/Checkpoint.cpp @@ -746,6 +746,23 @@ Tensor checkpoint_repeat(const at::Tensor& a, c10::ArrayRef b) { return CheckpointTensorImpl::make("repeat", rt, {a})[0]; } +Tensor checkpoint_mean(const Tensor& self, c10::optional dtype) { + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::native::mean_cpu_gpu(vec[0], dtype)}; + }; + return CheckpointTensorImpl::make("mean", rt, {self})[0]; +} + +Tensor checkpoint_mean(const Tensor& self, IntArrayRef dim, bool keepdim, c10::optional dtype) { + std::vector dim_ = dim.vec(); + rematerialize_function_t rt = + [=](const Tensors& vec) -> Tensors { + return {at::native::mean_cpu_gpu(vec[0], dim_, keepdim, dtype)}; + }; + return CheckpointTensorImpl::make("mean.dim", rt, {self})[0]; +} + Tensor checkpoint__cat(c10::ArrayRef a, long b) { rematerialize_function_t rt = [=](const Tensors& vec) -> Tensors { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 04ff2643319..4dcea3d4e6e 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1848,6 +1848,7 @@ CPU: mean_cpu_gpu CUDA: mean_cpu_gpu QuantizedCPU: quantized_mean_cpu + Checkpoint: checkpoint_mean - func: mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method @@ -1856,6 +1857,7 @@ CPU: mean_cpu_gpu CUDA: mean_cpu_gpu QuantizedCPU: quantized_mean_cpu + Checkpoint: checkpoint_mean - func: mean.out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True