-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
[Paddle Inference] Add add eye trt converter #48937
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
// Declare inputs attr | ||
const int num_rows = PADDLE_GET_CONST(int, op_desc.GetAttr("num_rows")); | ||
int num_columns = PADDLE_GET_CONST(int, op_desc.GetAttr("num_columns")); | ||
const int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以改成
auto dtype = static_cast<framework::proto::VarType::Type>(
PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")));
if (-1 == num_columns) { | ||
input_shape.d[1] = num_rows; | ||
num_columns = num_rows; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
num_columns为默认值-1时,num_columns=num_rows, 这段逻辑可以放在input_shape.d赋值前
} | ||
|
||
std::vector<T> constant_arr(num_rows * num_columns, 0); | ||
for (int i = 0; i < std::min(num_rows, num_columns); i++) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std::min(num_rows, num_columns) 放循环外避免多次调用std::min
nvinfer1::DataType nv_type = nvinfer1::DataType::kFLOAT; | ||
switch (dtype) { | ||
case paddle::framework::proto::VarType::FP32: | ||
nv_type = nvinfer1::DataType::kFLOAT; | ||
typedef float T; | ||
break; | ||
case paddle::framework::proto::VarType::FP16: | ||
nv_type = nvinfer1::DataType::kHALF; | ||
typedef uint16_t T; | ||
break; | ||
case paddle::framework::proto::VarType::INT32: | ||
nv_type = nvinfer1::DataType::kINT32; | ||
typedef int32_t T; | ||
break; | ||
default: | ||
paddle::platform::errors::InvalidArgument( | ||
"Paddle-TRT loads weighths failed, found not supported data type " | ||
"%s.", | ||
dtype); | ||
break; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里T类型声明编译报错,可以参考convert/fill_constant_op.cc 中写法,进行不同类型数组赋值
for _ in range(6): | ||
if np.random.random() > 0.5: | ||
num_rows = generate_input_attr1() | ||
attr_dic = {"num_rows": num_rows} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以加一个num_columns为-1。两处attr_dic中添加一个dtype字段,37-38行之间添加一层for循环用于dtype赋值,
for dtype in [2, 4, 5]
@zhangjun 打扰了,问一下这个 |
很抱歉,经过我们的反复讨论,你的PR暂未达到合入标准,请阅读飞桨原生算子开发规范,你可以重新提交新的PR,我们先将此PR关闭,感谢你的贡献。 |
PR types
Others
PR changes
Others
Describe
#48292