From 43a712729c96b4c6a10458001b3079040cb6486e Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Fri, 26 Aug 2022 19:29:52 +0800 Subject: [PATCH] Add python deployment for squeezesegv3 --- deploy/squeezesegv3/python/infer.py | 164 ++++++++++++++++++++++++++++ docs/models/squeezesegv3/README.md | 40 +++++++ 2 files changed, 204 insertions(+) create mode 100644 deploy/squeezesegv3/python/infer.py diff --git a/deploy/squeezesegv3/python/infer.py b/deploy/squeezesegv3/python/infer.py new file mode 100644 index 00000000..5a6e5d84 --- /dev/null +++ b/deploy/squeezesegv3/python/infer.py @@ -0,0 +1,164 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import cv2 +import numpy as np +import paddle +from paddle.inference import Config, create_predictor + +from paddle3d import transforms as T +from paddle3d.sample import Sample +from paddle3d.transforms.normalize import NormalizeRangeImage +from paddle3d.transforms.reader import LoadSemanticKITTIRange + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_file", + type=str, + help="Model filename, Specify this when your model is a combined model.", + required=True) + parser.add_argument( + "--params_file", + type=str, + help= + "Parameter filename, Specify this when your model is a combined model.", + required=True) + parser.add_argument( + '--lidar_file', type=str, help='The lidar path.', required=True) + parser.add_argument( + '--img_mean', + type=str, + help='The mean value of range-view image.', + required=True) + parser.add_argument( + '--img_std', + type=str, + help='The variance value of range-view image.', + required=True) + parser.add_argument("--gpu_id", type=int, default=0, help="GPU card id.") + parser.add_argument( + "--use_trt", + type=int, + default=0, + help="Whether to use tensorrt to accelerate when using gpu.") + parser.add_argument( + "--trt_precision", + type=int, + default=0, + help="Precision type of tensorrt, 0: kFloat32, 1: kHalf.") + parser.add_argument( + "--trt_use_static", + type=int, + default=0, + help="Whether to load the tensorrt graph optimization from a disk path." + ) + parser.add_argument( + "--trt_static_dir", + type=str, + help="Path of a tensorrt graph optimization directory.") + + return parser.parse_args() + + +def preprocess(file_path, img_mean, img_std): + if isinstance(img_mean, str): + img_mean = eval(img_mean) + if isinstance(img_std, str): + img_std = eval(img_std) + + sample = Sample(path=file_path, modality="lidar") + + transforms = T.Compose([ + LoadSemanticKITTIRange(project_label=False), + NormalizeRangeImage(mean=img_mean, std=img_std) + ]) + + sample = transforms(sample) + + if "proj_mask" in sample.meta: + sample.data *= sample.meta.pop("proj_mask") + return np.expand_dims(sample.data, + 0), sample.meta.proj_x, sample.meta.proj_y + + +def init_predictor(model_file, + params_file, + gpu_id=0, + use_trt=False, + trt_precision=0, + trt_use_static=False, + trt_static_dir=None): + config = Config(model_file, params_file) + config.enable_memory_optim() + config.enable_use_gpu(1000, gpu_id) + if use_trt: + precision_mode = paddle.inference.PrecisionType.Float32 + if trt_precision == 1: + precision_mode = paddle.inference.PrecisionType.Half + config.enable_tensorrt_engine( + workspace_size=1 << 20, + max_batch_size=1, + min_subgraph_size=3, + precision_mode=precision_mode, + use_static=trt_use_static, + use_calib_mode=False) + if trt_use_static: + config.set_optim_cache_dir(trt_static_dir) + + predictor = create_predictor(config) + return predictor + + +def run(predictor, points): + # copy img data to input tensor + input_names = predictor.get_input_names() + input_tensor = predictor.get_input_handle(input_names[0]) + input_tensor.reshape(points.shape) + input_tensor.copy_from_cpu(points.copy()) + + # do the inference + predictor.run() + + results = [] + # get out data from output tensor + output_names = predictor.get_output_names() + output_tensor = predictor.get_output_handle(output_names[0]) + pred_label = output_tensor.copy_to_cpu() + + return pred_label[0] + + +def postprocess(pred_img_label, proj_x, proj_y): + return pred_img_label[proj_y, proj_x] + + +def main(args): + predictor = init_predictor(args.model_file, args.params_file, args.gpu_id, + args.use_trt, args.trt_precision, + args.trt_use_static, args.trt_static_dir) + range_img, proj_x, proj_y = preprocess(args.lidar_file, args.img_mean, + args.img_std) + pred_img_label = run(predictor, range_img) + pred_point_label = postprocess(pred_img_label, proj_x, proj_y) + return pred_point_label + + +if __name__ == '__main__': + args = parse_args() + + main(args) diff --git a/docs/models/squeezesegv3/README.md b/docs/models/squeezesegv3/README.md index 7b72fecd..433f28f4 100644 --- a/docs/models/squeezesegv3/README.md +++ b/docs/models/squeezesegv3/README.md @@ -10,6 +10,7 @@ * [训练](#h3-id52h3) * [评估](#h3-id53h3) * [模型导出](#h3-id54h3) + * [模型部署](#h3-id55h3) ##

引用

@@ -125,3 +126,42 @@ python tools/export.py \ | model | 待导出模型参数`model.pdparams`路径 | 是 | - | | input_shape | 指定模型的输入尺寸,支持`N, C, H, W`或`H, W`格式 | 是 | - | | save_dir | 保存导出模型的路径,`save_dir`下将会生成三个文件:`squeezesegv3.pdiparams `、`squeezesegv3.pdiparams.info`和`squeezesegv3.pdmodel` | 否 | `deploy` | + + + +###

模型部署

+ +#### C++部署 + +Coming soon... + +#### Python部署 + +命令参数说明如下: + +| 参数 | 说明 | +| -- | -- | +| model_file | 导出模型的结构文件`squeezesegv3.pdmodel`所在路径 | +| params_file | 导出模型的参数文件`squeezesegv3.pdiparams`所在路径 | +| lidar_file | 待预测的点云文件所在路径 | +| img_mean | 点云投影到range-view后所成图像的均值,例如为`12.12,10.88,0.23,-1.04,0.21` | +| img_std | 点云投影到range-view后所成图像的方差,例如为`12.32,11.47,6.91,0.86,0.16` | +| use_trt | 是否使用TensorRT进行加速,默认0| +| trt_precision | 当use_trt设置为1时,模型精度可设置0或1,0表示fp32, 1表示fp16。默认0 | +| trt_use_static | 当trt_use_static设置为1时,**在首次运行程序的时候会将TensorRT的优化信息进行序列化到磁盘上,下次运行时直接加载优化的序列化信息而不需要重新生成**。默认0 | +| trt_static_dir | 当trt_use_static设置为1时,保存优化信息的路径 | + + +运行以下命令,执行预测: + +``` +python infer.py --model_file /path/to/squeezesegv3.pdmodel --params_file /path/to/squeezesegv3.pdiparams --lidar_file /path/to/lidar.pcd.bin --img_mean 12.12,10.88,0.23,-1.04,0.21 --img_std 12.32,11.47,6.91,0.86,0.16 +``` + +如果要开启TensorRT的话,请卸载掉原有的`paddlepaddel_gpu`,至[Paddle官网](https://paddleinference.paddlepaddle.org.cn/user_guides/download_lib.html#python)下载与TensorRT连编的预编译Paddle Inferece安装包,选择符合本地环境CUDA/cuDNN/TensorRT版本的安装包完成安装即可。 + +运行以下命令,开启TensorRT加速模型预测: + +``` +python infer.py --model_file /path/to/squeezesegv3.pdmodel --params_file /path/to/squeezesegv3.pdiparams --lidar_file /path/to/lidar.pcd.bin --img_mean 12.12,10.88,0.23,-1.04,0.21 --img_std 12.32,11.47,6.91,0.86,0.16 --use_trt 1 +```