forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_operator.h
59 lines (47 loc) · 2.06 KB
/
custom_operator.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#pragma once
#include <torch/csrc/jit/operator.h>
#include <ATen/core/stack.h>
#include <ATen/core/op_registration/op_registration.h>
namespace torch {
namespace jit {
/// Registration class for new operators. Effectively calls
/// `torch::jit::registerOperator` for every supplied operator, but allows doing
/// so in the global scope when a `RegisterOperators` object is assigned to a
/// static variable. Also handles registration of user-defined, "custom"
/// operators.
struct TORCH_API RegisterOperators {
RegisterOperators() = default;
/// Registers a vector of already created `Operator`s.
RegisterOperators(std::vector<Operator> operators) {
for (Operator& o : operators) {
registerOperator(std::move(o));
}
}
/// Calls `op(...)` with the given operator name and implementation.
template <typename Implementation>
C10_DEPRECATED_MESSAGE("torch::jit::RegisterOperators is deprecated. Please use torch::RegisterOperators instead.")
RegisterOperators(const std::string& name, Implementation&& implementation) {
op_(name, std::forward<Implementation>(implementation));
}
template <typename Implementation>
C10_DEPRECATED_MESSAGE("torch::jit::RegisterOperators is deprecated. Please use torch::RegisterOperators instead.")
RegisterOperators& op(
const std::string& name,
Implementation&& implementation) {
op_(name, std::forward<Implementation>(implementation));
return *this;
}
private:
template <typename Implementation>
void op_(const std::string& name, Implementation&& implementation) {
registrars_.emplace_back(std::make_shared<c10::RegisterOperators>(name, std::forward<Implementation>(implementation)));
}
// A c10::RegisterOperators instance is not copyable, so to make
// torch::jit::RegisterOperators copyable, we use shared_ptrs.
// We need to keep the c10::RegisterOperators instances around
// because this is an RAII pattern. In the destructor, the registered
// ops get de-registered.
std::vector<std::shared_ptr<c10::RegisterOperators>> registrars_;
};
} // namespace jit
} // namespace torch