Skip to content

Commit 60df888

Browse files
committed
feat(prim::NumToTensor): Implement evaluator for NumToTensor
Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 17099fa commit 60df888

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

core/conversion/evaluators/prim.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "ATen/core/List.h"
66
#include "ATen/core/stack.h"
77
#include "c10/util/intrusive_ptr.h"
8+
#include "torch/torch.h"
89

910
#include "core/conversion/evaluators/evaluators.h"
1011

@@ -23,6 +24,11 @@ auto prim_registrations = RegisterNodeEvaluators()
2324
}
2425
return torch::jit::toIValue(n->output());
2526
}
27+
}).evaluator({
28+
torch::jit::prim::NumToTensor,
29+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
30+
return at::scalar_to_tensor(args.at(&(n->output()[0])).IValue()->toScalar());
31+
}
2632
}).evaluator({
2733
torch::jit::prim::ListConstruct,
2834
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {

0 commit comments

Comments
 (0)