forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathregistry.h
46 lines (35 loc) · 1.2 KB
/
registry.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#pragma once
#include <memory>
#include <string>

#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
namespace torch {
namespace jit {
namespace mobile {
namespace nnc {

/// Signature shared by every NNC kernel registered here: the kernel receives
/// an array of opaque argument pointers and returns an int.
/// NOTE(review): the layout of the argument array and the meaning of the
/// return value are not defined in this header — presumably 0 means success;
/// confirm against the kernel generator before relying on it.
using nnc_kernel_function_type = int(void**);

/// Abstract, polymorphic handle to a single NNC kernel. Concrete instances
/// are created through NNCKernelRegistry (populated by REGISTER_NNC_KERNEL).
struct TORCH_API NNCKernel {
  virtual ~NNCKernel() = default;
  /// Runs the kernel on the given opaque argument array and returns its
  /// int result.
  virtual int execute(void** /* args */) = 0;
};

/// Registry mapping a string kernel id to a factory producing
/// std::unique_ptr<NNCKernel> (see the helpers in `registry` below).
C10_DECLARE_REGISTRY(NNCKernelRegistry, NNCKernel);

/// Registers the C-linkage function `kernel` under the string key `id`.
///
/// The expansion:
///   1. declares (does not define) `kernel` as an `extern "C"` function of
///      type nnc_kernel_function_type — its definition must be linked in
///      from elsewhere;
///   2. defines a leaf adapter `NNCKernel_<kernel>` whose execute() simply
///      forwards the argument array to `kernel`;
///   3. registers that adapter in NNCKernelRegistry under `id` via
///      C10_REGISTER_TYPED_CLASS.
///
/// nnc_kernel_function_type, NNCKernel and NNCKernelRegistry are named
/// unqualified, so the macro must be expanded where those names resolve
/// (namespace torch::jit::mobile::nnc). The expansion already ends with `;`,
/// so call sites need no trailing semicolon. Any arguments after `kernel`
/// are accepted and ignored.
#define REGISTER_NNC_KERNEL(id, kernel, ...)           \
  extern "C" {                                         \
  nnc_kernel_function_type kernel;                     \
  }                                                    \
  struct NNCKernel_##kernel final : public NNCKernel { \
    int execute(void** args) override {                \
      return kernel(args);                             \
    }                                                  \
  };                                                   \
  C10_REGISTER_TYPED_CLASS(NNCKernelRegistry, id, NNCKernel_##kernel);

namespace registry {

/// Returns true if a kernel has been registered under `id`.
inline bool has_nnc_kernel(const std::string& id) {
  return NNCKernelRegistry()->Has(id);
}

/// Creates a fresh kernel instance for `id` through the registry factory.
/// NOTE(review): for an id that was never registered, c10::Registry::Create
/// presumably returns nullptr rather than throwing — callers should check
/// has_nnc_kernel() first or test the result; confirm against
/// c10/util/Registry.h.
inline std::unique_ptr<NNCKernel> get_nnc_kernel(const std::string& id) {
  return NNCKernelRegistry()->Create(id);
}

} // namespace registry
} // namespace nnc
} // namespace mobile
} // namespace jit
} // namespace torch