meet deploy problem when using Group Fisher in rtmdet with custom data #494

Closed
zhubochao opened this issue Apr 4, 2023 · 6 comments
Labels: bug (Something isn't working), Pruning

Comments

@zhubochao

Checklist

  • I have searched related issues but cannot get the expected help.
  • I have read related documents and don't know what to do.

Describe the question you meet

I used Group Fisher to prune RTMDet. When I deploy the pruned model with this script, it fails with: Exception: Forward failed, there may be an error in demo input. I only want to export an ONNX or TensorRT model.

Post related information

  1. The output of pip list | grep "mmcv\|mmrazor\|^torch"
    mmcv 2.0.0rc4
    mmrazor 1.0.0rc2 /home/yaojiahao/zbc/mmrazor
    torch 1.13.0+cu116
    torchaudio 0.13.0+cu116
    torchvision 0.14.0+cu116
  2. Your config file if you modified it or created a new one.
_base_ = 'mmyolo::rtmdet/rtmdet_s_syncbn_fast_1xb64-300e_fs_4k.py'
fix_subnet = {
    "backbone.stem.0.conv_(0, 16)_16": 13,
    "backbone.stem.1.conv_(0, 16)_16": 16,
    "backbone.stem.2.conv_(0, 32)_32": 32,
    "backbone.stage1.0.conv_(0, 64)_64": 63,
    "backbone.stage1.1.short_conv.conv_(0, 32)_32": 32,
    "backbone.stage1.1.main_conv.conv_(0, 32)_32": 32,
    "backbone.stage1.1.blocks.0.conv1.conv_(0, 32)_32": 31,
    "backbone.stage1.1.final_conv.conv_(0, 64)_64": 64,
    "backbone.stage2.0.conv_(0, 128)_128": 128,
    "backbone.stage2.1.short_conv.conv_(0, 64)_64": 64,
    "backbone.stage2.1.main_conv.conv_(0, 64)_64": 64,
    "backbone.stage2.1.blocks.0.conv1.conv_(0, 64)_64": 64,
    "backbone.stage2.1.blocks.1.conv1.conv_(0, 64)_64": 63,
    "backbone.stage2.1.final_conv.conv_(0, 128)_128": 128,
    "backbone.stage3.0.conv_(0, 256)_256": 256,
    "backbone.stage3.1.short_conv.conv_(0, 128)_128": 127,
    "backbone.stage3.1.main_conv.conv_(0, 128)_128": 128,
    "backbone.stage3.1.blocks.0.conv1.conv_(0, 128)_128": 128,
    "backbone.stage3.1.blocks.1.conv1.conv_(0, 128)_128": 128,
    "backbone.stage3.1.final_conv.conv_(0, 256)_256": 256,
    "backbone.stage4.0.conv_(0, 512)_512": 512,
    "backbone.stage4.1.conv1.conv_(0, 256)_256": 256,
    "backbone.stage4.1.conv2.conv_(0, 512)_512": 512,
    "backbone.stage4.2.short_conv.conv_(0, 256)_256": 256,
    "backbone.stage4.2.main_conv.conv_(0, 256)_256": 256,
    "backbone.stage4.2.blocks.0.conv1.conv_(0, 256)_256": 256,
    "backbone.stage4.2.blocks.0.conv2.pointwise_conv.conv_(0, 256)_256": 256,
    "backbone.stage4.2.final_conv.conv_(0, 512)_512": 511,
    "neck.reduce_layers.2.conv_(0, 256)_256": 256,
    "neck.top_down_layers.0.0.short_conv.conv_(0, 128)_128": 128,
    "neck.top_down_layers.0.0.main_conv.conv_(0, 128)_128": 128,
    "neck.top_down_layers.0.0.blocks.0.conv1.conv_(0, 128)_128": 128,
    "neck.top_down_layers.0.0.blocks.0.conv2.pointwise_conv.conv_(0, 128)_128": 128,
    "neck.top_down_layers.0.0.final_conv.conv_(0, 256)_256": 256,
    "neck.top_down_layers.0.1.conv_(0, 128)_128": 128,
    "neck.top_down_layers.1.short_conv.conv_(0, 64)_64": 64,
    "neck.top_down_layers.1.main_conv.conv_(0, 64)_64": 62,
    "neck.top_down_layers.1.blocks.0.conv1.conv_(0, 64)_64": 61,
    "neck.top_down_layers.1.blocks.0.conv2.pointwise_conv.conv_(0, 64)_64": 64,
    "neck.top_down_layers.1.final_conv.conv_(0, 128)_128": 127,
    "neck.downsample_layers.0.conv_(0, 128)_128": 128,
    "neck.bottom_up_layers.0.short_conv.conv_(0, 128)_128": 128,
    "neck.bottom_up_layers.0.main_conv.conv_(0, 128)_128": 128,
    "neck.bottom_up_layers.0.blocks.0.conv1.conv_(0, 128)_128": 128,
    "neck.bottom_up_layers.0.blocks.0.conv2.pointwise_conv.conv_(0, 128)_128": 128,
    "neck.bottom_up_layers.0.final_conv.conv_(0, 256)_256": 254,
    "neck.downsample_layers.1.conv_(0, 256)_256": 255,
    "neck.bottom_up_layers.1.short_conv.conv_(0, 256)_256": 255,
    "neck.bottom_up_layers.1.main_conv.conv_(0, 256)_256": 256,
    "neck.bottom_up_layers.1.blocks.0.conv1.conv_(0, 256)_256": 255,
    "neck.bottom_up_layers.1.blocks.0.conv2.pointwise_conv.conv_(0, 256)_256": 256,
    "neck.bottom_up_layers.1.final_conv.conv_(0, 512)_512": 406
}
divisor = 16

##############################################################################

architecture = _base_.model

model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='GroupFisherDeploySubModel',
    architecture=architecture,
    fix_subnet=fix_subnet,
    divisor=divisor,
)
  3. Your train log file if you meet the problem during training.

/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/mmdet/utils/setup_env.py:83: UserWarning: The current default scope "mmyolo" is not "mmdet", register_all_modules will force the currentdefault scope to be "mmdet". If this is not expected, please set init_default_scope=False.
  warnings.warn('The current default scope '
/home/yaojiahao/zbc/mmrazor/mmrazor/utils/setup_env.py:77: UserWarning: The current default scope "mmdet" is not "mmrazor", register_all_modules will force the currentdefault scope to be "mmrazor". If this is not expected, please set init_default_scope=False.
  warnings.warn(
/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/mmdet/utils/setup_env.py:83: UserWarning: The current default scope "mmrazor" is not "mmdet", register_all_modules will force the currentdefault scope to be "mmdet". If this is not expected, please set init_default_scope=False.
  warnings.warn('The current default scope '
/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/mmdet/utils/setup_env.py:83: UserWarning: The current default scope "mmyolo" is not "mmdet", register_all_modules will force the currentdefault scope to be "mmdet". If this is not expected, please set init_default_scope=False.
  warnings.warn('The current default scope '
/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py:564: UserWarning: Was not able to add assertion to guarantee correct input data_samples to specialized function. It is up to the user to make sure that your inputs match the inputs you specialized the function with.
  warnings.warn(
/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)
  return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
Process Process-2:
Traceback (most recent call last):
  File "/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/yaojiahao/zbc/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 107, in __call__
    ret = func(*args, **kwargs)
  File "/home/yaojiahao/zbc/mmdeploy/mmdeploy/apis/pytorch2onnx.py", line 63, in torch2onnx
    torch_model = task_processor.build_pytorch_model(model_checkpoint)
  File "/home/yaojiahao/zbc/mmdeploy/mmdeploy/codebase/mmrazor/deploy/mmrazor.py", line 133, in build_pytorch_model
    model.post_process_for_mmdeploy()
  File "/home/yaojiahao/zbc/mmrazor/mmrazor/implementations/pruning/group_fisher/prune_deploy_sub_model.py", line 19, in post_process_for_mmdeploy
    s = make_channel_divisible(model, divisor=divisor)
  File "/home/yaojiahao/zbc/mmrazor/mmrazor/models/utils/expandable_utils/tools.py", line 76, in make_channel_divisible
    mutator = to_expandable_model(model)
  File "/home/yaojiahao/zbc/mmrazor/mmrazor/models/utils/expandable_utils/tools.py", line 23, in to_expandable_model
    mutator.prepare_from_supernet(model)
  File "/home/yaojiahao/zbc/mmrazor/mmrazor/models/mutators/channel_mutator/channel_mutator.py", line 110, in prepare_from_supernet
    units = self._prepare_from_tracer(supernet, self.parse_cfg)
  File "/home/yaojiahao/zbc/mmrazor/mmrazor/models/mutators/channel_mutator/channel_mutator.py", line 307, in _prepare_from_tracer
    unit_configs = tracer.analyze(model)
  File "/home/yaojiahao/zbc/mmrazor/mmrazor/models/task_modules/tracer/channel_analyzer.py", line 127, in analyze
    return self._find_mutable_units(model, unit_configs)
  File "/home/yaojiahao/zbc/mmrazor/mmrazor/models/task_modules/tracer/channel_analyzer.py", line 167, in _find_mutable_units
    raise Exception(
Exception: Forward failed, there may be an error in demo input.
  4. Other code you modified in the mmrazor folder.

I changed the group_fisher_act_{action}_{model}_coco.py configs.
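The fix_subnet keys in the config above appear to encode a module path, a channel range and the original channel count, mapped to the number of channels kept after pruning; divisor = 16 then asks for channel counts that are multiples of 16. The sketch below lists which units a given divisor would have to round up, assuming that key layout (inferred from the keys themselves, not from mmrazor documentation). round_up_to_divisor and changed_units are hypothetical helpers and do not reproduce mmrazor's make_channel_divisible; they only show why divisor = 1, the workaround tried in the first follow-up, leaves every unit unchanged.

```python
import re

# ASSUMPTION: keys look like "<module path>_(<start>, <end>)_<total channels>".
# This layout is inferred from the fix_subnet keys in the config, not from mmrazor docs.
KEY_PATTERN = re.compile(
    r"^(?P<module>.+)_\((?P<start>\d+), (?P<end>\d+)\)_(?P<total>\d+)$")


def parse_unit_key(key):
    """Split a fix_subnet key into (module path, original channel count)."""
    match = KEY_PATTERN.match(key)
    if match is None:
        raise ValueError(f"unexpected key format: {key!r}")
    return match.group("module"), int(match.group("total"))


def round_up_to_divisor(kept, total, divisor):
    """Hypothetical helper: round kept channels up to a multiple of divisor.

    Integer ceiling division avoids float rounding; the result is capped at
    the original channel count. This is NOT mmrazor's make_channel_divisible.
    """
    rounded = ((kept + divisor - 1) // divisor) * divisor
    return min(rounded, total)


def changed_units(fix_subnet, divisor):
    """List (module, kept, rounded) for every unit the divisor would change."""
    changes = []
    for key, kept in fix_subnet.items():
        module, total = parse_unit_key(key)
        rounded = round_up_to_divisor(kept, total, divisor)
        if rounded != kept:
            changes.append((module, kept, rounded))
    return changes


if __name__ == "__main__":
    # Three entries copied from the config above.
    sample = {
        "backbone.stem.0.conv_(0, 16)_16": 13,
        "backbone.stage1.0.conv_(0, 64)_64": 63,
        "neck.bottom_up_layers.1.final_conv.conv_(0, 512)_512": 406,
    }
    for module, kept, rounded in changed_units(sample, divisor=16):
        print(f"{module}: {kept} -> {rounded}")
    # With divisor=1 every count is already "divisible", so nothing changes.
    print(changed_units(sample, divisor=1))
```

With divisor=16 the three sample units would go from 13, 63 and 406 channels to 16, 64 and 416; with divisor=1 the list is empty.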

@zhubochao
Author

Changing the divisor to 1 avoids this error, but it leads to another one:

RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: InstanceData
04/05 12:39:01 - mmengine - ERROR - /home/yaojiahao/zbc/mmdeploy/mmdeploy/apis/core/pipeline_manager.py - pop_mp_output - 80 - `mmdeploy.apis.pytorch2onnx.torch2onnx` with Call id: 0 failed. exit.
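The RuntimeError says the traced model returned an InstanceData object, which the JIT tracer behind torch.onnx.export cannot flatten: as the message states, only tensors and tuples/lists of them are supported (dicts and strings are accepted but discouraged). A minimal sketch of one way to satisfy that rule is to turn an InstanceData-like result into a plain tuple before it leaves forward. flatten_for_export and the field order are hypothetical choices for illustration, not mmdeploy's own rewrite, and SimpleNamespace with plain lists stands in for InstanceData and tensors so the sketch runs without torch.

```python
from types import SimpleNamespace

# ASSUMPTION: field order chosen for illustration; mmdet detection results
# carry bboxes, scores and labels fields.
DEFAULT_KEYS = ("bboxes", "scores", "labels")


def flatten_for_export(result, keys=DEFAULT_KEYS):
    """Return the named fields of an InstanceData-like object as a plain tuple.

    A tracing-based export can only return tensors and tuples/lists of them,
    so a wrapper module would return this tuple instead of the container.
    """
    missing = [key for key in keys if not hasattr(result, key)]
    if missing:
        raise AttributeError(f"result is missing field(s): {missing}")
    return tuple(getattr(result, key) for key in keys)


if __name__ == "__main__":
    # SimpleNamespace stands in for InstanceData; lists stand in for tensors.
    fake = SimpleNamespace(bboxes=[[0, 0, 10, 10]], scores=[0.9], labels=[3])
    print(flatten_for_export(fake))  # ([[0, 0, 10, 10]], [0.9], [3])
```

Seeing InstanceData at the output suggests that the model's original predict path, rather than a tensor-returning deploy path, was traced; the sketch only illustrates the output constraint and does not diagnose why.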

@LKJacky
Collaborator

LKJacky commented Apr 6, 2023

I'm sorry, but I cannot reproduce this problem. I can deploy the pruned RTMDet with ONNX Runtime on CPU successfully. Here are the versions of my packages:

mmcls 1.0.0rc5
mmcv 2.0.0rc4
mmdet 3.0.0rc6
mmdet3d 1.1.0rc3
mmengine 0.7.0
mmpose 1.0.0rc0
mmrazor 1.0.0rc2
mmrotate 1.0.0rc0
mmsegmentation 1.0.0rc1
mmyolo 0.2.0

Please try the following commands and share the error output if they fail:

razor_config={your config}.py
deploy_config=mmdeploy/configs/mmdet/detection/detection_onnxruntime_static.py

python mmdeploy/tools/deploy.py $deploy_config \
    $razor_config \
    {your checkpoint} \
    mmdeploy/tests/data/tiger.jpeg \
    --work-dir ./work_dirs/mmdeploy
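Both deploy commands in this thread pass the same positional arguments in the same order (deploy config, model config, checkpoint, test image) followed by --work-dir, with --device and --dump-info used only in the reporter's variant. A small sketch that assembles that argument list in Python, assuming only the arguments and flags that appear in these commands; build_deploy_cmd and the file names in the usage block are hypothetical placeholders.

```python
import shlex


def build_deploy_cmd(deploy_cfg, model_cfg, checkpoint, image, work_dir,
                     script="tools/deploy.py", device=None, dump_info=False):
    """Assemble the tools/deploy.py argument list used in this thread.

    Positional order and the --work-dir/--device/--dump-info flags mirror
    the commands posted here; nothing else is assumed about the CLI.
    """
    cmd = ["python", script, deploy_cfg, model_cfg, checkpoint, image,
           "--work-dir", work_dir]
    if device is not None:
        cmd += ["--device", device]
    if dump_info:
        cmd.append("--dump-info")
    return cmd


if __name__ == "__main__":
    # Placeholder paths (hypothetical); substitute your own config and checkpoint.
    cmd = build_deploy_cmd(
        "mmdeploy/configs/mmdet/detection/detection_onnxruntime_static.py",
        "my_razor_config.py",
        "my_checkpoint.pth",
        "mmdeploy/tests/data/tiger.jpeg",
        "./work_dirs/mmdeploy",
        script="mmdeploy/tools/deploy.py",
    )
    print(shlex.join(cmd))
    # To actually run it: subprocess.run(cmd, check=True)
```

Keeping the command as a list avoids shell quoting issues and makes it easy to run the maintainer's command and the reporter's --device cuda:0 variant side by side.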

@zhubochao
Author

zhubochao commented Apr 6, 2023

This is my environment:

mmcv       2.0.0rc4   https://github.com/open-mmlab/mmcv
mmdet      3.0.0rc6   https://github.com/open-mmlab/mmdetection
mmengine   0.6.0      https://github.com/open-mmlab/mmengine
mmrazor    1.0.0rc2   /home/yaojiahao/zbc/mmrazor
mmyolo     0.5.0      /home/yaojiahao/zbc/mmyolo
mmdeploy   1.0.0rc3   /home/yaojiahao/zbc/mmdeploy
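Among the packages both comments list, mmengine (0.7.0 vs 0.6.0) and mmyolo (0.2.0 vs 0.5.0) differ, and the maintainer's list gives no mmdeploy version. A small sketch that diffs the two reported environments and reads the local one with importlib.metadata (a Python alternative to the template's pip list | grep command), assuming the distribution names match the package names in the lists; the dictionaries hold only the versions quoted in this thread.

```python
from importlib import metadata

# Versions as reported in this thread (package -> version).
MAINTAINER_ENV = {"mmcv": "2.0.0rc4", "mmdet": "3.0.0rc6", "mmengine": "0.7.0",
                  "mmrazor": "1.0.0rc2", "mmyolo": "0.2.0"}
REPORTER_ENV = {"mmcv": "2.0.0rc4", "mmdet": "3.0.0rc6", "mmengine": "0.6.0",
                "mmrazor": "1.0.0rc2", "mmyolo": "0.5.0", "mmdeploy": "1.0.0rc3"}


def installed_versions(packages):
    """Look up installed distribution versions; None if a package is absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def diff_versions(left, right):
    """Return {package: (left_version, right_version)} where the two differ."""
    return {
        name: (left.get(name), right.get(name))
        for name in sorted(left.keys() | right.keys())
        if left.get(name) != right.get(name)
    }


if __name__ == "__main__":
    print(diff_versions(MAINTAINER_ENV, REPORTER_ENV))
    # Compare the local environment against the maintainer's working setup.
    print(diff_versions(MAINTAINER_ENV, installed_versions(MAINTAINER_ENV)))
```

The first call reports only mmdeploy (missing on the maintainer's side), mmengine and mmyolo; whether either version gap explains the failure is not established in this thread.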

Here is my command:

python tools/deploy.py configs/mmdet/detection/detection_onnxruntime_static.py \
    ../mmrazor/configs/pruning/mmdet/group_fisher/custom/group_fisher_act_deploy_rtmdet_1x_coco.py \
    ../mmrazor/work_dirs/group_fisher_act_finetune_rtmdet_1x_coco_0.8/best_coco/bbox_mAP_epoch_286.pth \
    tests/auto_08856.jpg \
    --device cuda:0 \
    --work-dir work_dirs/deploy_rtmdet_prune_onnx/ \
    --dump-info

Here is the error output:

04/06 19:06:54 - mmengine - INFO - Export PyTorch model to ONNX: work_dirs/deploy_rtmdet_prune_onnx/end2end.onnx.
04/06 19:06:54 - mmengine - WARNING - Can not find mmdet.models.dense_heads.DETRHead.forward_single, function rewrite will not be applied

/home/yaojiahao/zbc/mmdeploy/mmdeploy/codebase/mmdet/models/detectors/single_stage.py:84: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).

  img_shape = [int(val) for val in img_shape]
/home/yaojiahao/zbc/mmdeploy/mmdeploy/codebase/mmdet/models/detectors/single_stage.py:84: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  img_shape = [int(val) for val in img_shape]

/home/yaojiahao/zbc/mmdeploy/mmdeploy/core/optimizers/function_marker.py:160: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  ys_shape = tuple(int(s) for s in ys.shape)

/home/yaojiahao/zbc/mmyolo/mmyolo/models/task_modules/coders/distance_point_bbox_coder.py:45: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert points.size(-2) == pred_bboxes.size(-2)

/home/yaojiahao/zbc/mmyolo/mmyolo/models/task_modules/coders/distance_point_bbox_coder.py:46: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert points.size(-1) == 2

/home/yaojiahao/zbc/mmyolo/mmyolo/models/task_modules/coders/distance_point_bbox_coder.py:47: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert pred_bboxes.size(-1) == 4

/home/yaojiahao/zbc/mmyolo/mmyolo/models/dense_heads/yolov5_head.py:372: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,

/home/yaojiahao/zbc/mmyolo/mmyolo/models/dense_heads/yolov5_head.py:394: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if scores.shape[0] == 0:

/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/mmdet/models/utils/misc.py:336: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  num_topk = min(topk, valid_idxs.size(0))

/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/mmengine/structures/instance_data.py:296: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return len(self.values()[0])

/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/mmengine/structures/instance_data.py:139: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  assert len(value) == len(self), 'The length of ' \

/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/mmdet/models/dense_heads/base_dense_head.py:477: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if with_nms and results.bboxes.numel() > 0:

/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/mmcv/ops/nms.py:276: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if boxes.size(-1) == 5:

/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/mmcv/ops/nms.py:293: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  max_coordinate + torch.tensor(1).to(boxes))

/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/mmcv/ops/nms.py:301: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if boxes_for_nms.shape[0] < split_thr:

/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/mmcv/ops/nms.py:316: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  for id in torch.unique(idxs):

/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/mmcv/ops/nms.py:123: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert boxes.size(1) == 4

/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/mmcv/ops/nms.py:124: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert boxes.size(0) == scores.size(0)

Process Process-2:
Traceback (most recent call last):
  File "/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()

  File "/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)

  File "/home/yaojiahao/zbc/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 107, in __call__
    ret = func(*args, **kwargs)

  File "/home/yaojiahao/zbc/mmdeploy/mmdeploy/apis/pytorch2onnx.py", line 110, in torch2onnx
    export(

  File "/home/yaojiahao/zbc/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 356, in _wrap
    return self.call_function(func_name_, *args, **kwargs)

  File "/home/yaojiahao/zbc/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 326, in call_function
    return self.call_function_local(func_name, *args, **kwargs)

  File "/home/yaojiahao/zbc/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 275, in call_function_local
    return pipe_caller(*args, **kwargs)

  File "/home/yaojiahao/zbc/mmdeploy/mmdeploy/apis/core/pipeline_manager.py", line 107, in __call__
    ret = func(*args, **kwargs)

  File "/home/yaojiahao/zbc/mmdeploy/mmdeploy/apis/onnx/export.py", line 131, in export
    torch.onnx.export(

  File "/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/torch/onnx/utils.py", line 504, in export
    _export(

  File "/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/torch/onnx/utils.py", line 1529, in _export
    graph, params_dict, torch_out = _model_to_graph(

  File "/home/yaojiahao/zbc/mmdeploy/mmdeploy/apis/onnx/optimizer.py", line 11, in model_to_graph__custom_optimizer
    graph, params_dict, torch_out = ctx.origin_func(*args, **kwargs)

  File "/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/torch/onnx/utils.py", line 1111, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)

  File "/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/torch/onnx/utils.py", line 987, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)

  File "/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/torch/onnx/utils.py", line 891, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(

  File "/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/torch/jit/_trace.py", line 1184, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)

  File "/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)

  File "/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(

  File "/home/yaojiahao/miniconda3/envs/mmyolo/lib/python3.9/site-packages/torch/jit/_trace.py", line 121, in wrapper
    out_vars, _ = _flatten(outs)

RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: InstanceData
04/06 19:06:58 - mmengine - ERROR - /home/yaojiahao/zbc/mmdeploy/mmdeploy/apis/core/pipeline_manager.py - pop_mp_output - 80 - `mmdeploy.apis.pytorch2onnx.torch2onnx` with Call id: 0 failed. exit.

However, if I add the following lines

# print("test raw onnx start")
# test_model = torch_model.eval()
# dummy_input = torch.randn(1, 3, 640, 640, device='cuda:0')
# torch.onnx._export(test_model, dummy_input, "output_cuda.onnx", verbose=True, opset_version=11)

before the export call in pytorch2onnx.py to export the ONNX model manually, the export succeeds. This is what confuses me.

@LKJacky
Collaborator

LKJacky commented Apr 7, 2023

We found a bug when deploying a pruned model on CUDA and fixed it in this PR. Please try it.

@pppppM pppppM added Pruning bug Something isn't working labels Apr 10, 2023
@LKJacky
Collaborator

LKJacky commented Apr 14, 2023

We closed the issue because there has been no response for a long time. You can reopen it if needed.

@LKJacky LKJacky closed this as completed Apr 14, 2023
@zhubochao
Author

Sorry for not replying, but this PR still does not work in my environment. The error is the same as before.
