Skip to content

Commit

Permalink
Merge pull request #22 from slyubomirsky/densenet-overload
Browse files Browse the repository at this point in the history
Overloads for mean and mean.dim
  • Loading branch information
MarisaKirisame authored Apr 18, 2020
2 parents 0680ddb + 3008ba5 commit 97c7265
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
17 changes: 17 additions & 0 deletions aten/src/ATen/native/Checkpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,23 @@ Tensor checkpoint_repeat(const at::Tensor& a, c10::ArrayRef<long> b) {
return CheckpointTensorImpl::make("repeat", rt, {a})[0];
}

Tensor checkpoint_mean(const Tensor& self, c10::optional<c10::ScalarType> dtype) {
rematerialize_function_t rt =
[=](const Tensors& vec) -> Tensors {
return {at::native::mean_cpu_gpu(vec[0], dtype)};
};
return CheckpointTensorImpl::make("mean", rt, {self})[0];
}

Tensor checkpoint_mean(const Tensor& self, IntArrayRef dim, bool keepdim, c10::optional<c10::ScalarType> dtype) {
std::vector<long> dim_ = dim.vec();
rematerialize_function_t rt =
[=](const Tensors& vec) -> Tensors {
return {at::native::mean_cpu_gpu(vec[0], dim_, keepdim, dtype)};
};
return CheckpointTensorImpl::make("mean.dim", rt, {self})[0];
}

Tensor checkpoint__cat(c10::ArrayRef<Tensor> a, long b) {
rematerialize_function_t rt =
[=](const Tensors& vec) -> Tensors {
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1848,6 +1848,7 @@
CPU: mean_cpu_gpu
CUDA: mean_cpu_gpu
QuantizedCPU: quantized_mean_cpu
Checkpoint: checkpoint_mean

- func: mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
variants: function, method
Expand All @@ -1856,6 +1857,7 @@
CPU: mean_cpu_gpu
CUDA: mean_cpu_gpu
QuantizedCPU: quantized_mean_cpu
Checkpoint: checkpoint_mean

- func: mean.out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
supports_named_tensor: True
Expand Down

0 comments on commit 97c7265

Please # to comment.