RuntimeError: CUDA error: an illegal memory access was encountered when training cfa #614
Comments
Please try to install mmcv by
Do other models also raise errors? Or just this one?
I've tried reinstalling mmcv using the command above, but this error still exists.
I also encountered this problem. Training runs normally at first, but the error then occurs at random points during training or validation.
Hi @Lebron0126
It's weird. CFA, rotated_reppoints and oriented_reppoints can all be trained and tested successfully on our device.
I encountered the same problem using oriented_reppoints.
Several feasible solutions: #405
Same error, but it occurs in epoch 8. It seems this error can occur at any point during training.
Update on this bug:
According to open-mmlab/mmcv#2407, adding small random noise to the input of mmcv.ops.min_area_polygons seems to be a feasible solution. As stated in that thread, this bug may be caused by numerical instability of the min_area_polygons CUDA op. Alternatively, cv2.minAreaRect could fix this bug, since it returns the true minimum-area polygon without adding any noise to the input. However, training may be slower because cv2.minAreaRect runs on the CPU.
FYI, @yangxue0827. Unfortunately, none of the solutions you provided worked in my case.
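The cv2.minAreaRect alternative mentioned in the update computes a minimum-area enclosing rectangle on the CPU (in OpenCV that would be `cv2.minAreaRect` on an N×2 float32 array, followed by `cv2.boxPoints` to get the four corners). The sketch below is a hypothetical pure-Python stand-in, not OpenCV's or mmcv's implementation. It relies on the standard fact that an optimal rectangle has a side parallel to some convex-hull edge, and simply tries the direction of every point pair, which is enough for small point sets.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def min_area_rect(points: List[Point]) -> List[Point]:
    """Return the 4 corners of a minimum-area rectangle enclosing `points`.

    Brute force: every convex-hull edge joins two input points, so trying the
    direction of every point pair covers the optimal orientation (O(n^3)).
    """
    if len(points) < 2:
        raise ValueError("need at least two points")
    best = None  # (area, theta, min_u, max_u, min_v, max_v)
    for i in range(len(points)):
        for j in range(i + 1, len(points)):
            dx = points[j][0] - points[i][0]
            dy = points[j][1] - points[i][1]
            if dx == 0 and dy == 0:
                continue  # coincident points define no direction
            theta = math.atan2(dy, dx)
            c, s = math.cos(theta), math.sin(theta)
            # Project all points onto axes u (along the pair) and v (normal to it).
            us = [x * c + y * s for x, y in points]
            vs = [-x * s + y * c for x, y in points]
            area = (max(us) - min(us)) * (max(vs) - min(vs))
            if best is None or area < best[0]:
                best = (area, theta, min(us), max(us), min(vs), max(vs))
    if best is None:
        raise ValueError("all points are identical")
    _, theta, u0, u1, v0, v1 = best
    c, s = math.cos(theta), math.sin(theta)
    # Rotate the (u, v) corners back into image coordinates.
    return [(u * c - v * s, u * s + v * c)
            for u, v in ((u0, v0), (u1, v0), (u1, v1), (u0, v1))]
```

For a diamond such as `[(1, 0), (0, 1), (-1, 0), (0, -1)]` this returns the rotated square of area 2 rather than the axis-aligned box of area 4. Because it only uses comparisons of projected extents, it has no collinearity branch of the kind discussed further down in this thread.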
@yangxue0827 @zytx121 What can I do about this issue? I get the same error. P.S. It may be related to this, and the same line exists in mmrotate_handler.py.
Update on this bug: for the min_area_polygons function, I tested inputs where the convex is extremely small, for example where all points of the convex are 0, or where the points are random numbers drawn with mean 0 and variance 1e-30. The function works normally in these cases, so I tested it further and found a convex that reproduces the problem:
As you can see, the points of this convex lie in the corner and along the border of the image (all images in my dataset are 1024x1024). I don't know the exact cause, but convex points in the corner or on the border of the image seem to make this function crash. For the convex_giou function, I also tested extremely small convex inputs (where the convex area is extremely small). In that case convex_iou works normally, but convex_giou returns a negative IoU with zero gradient. It does not raise an error, though, and training keeps running. After further testing, I found that a specific predicted convex combined with a specific target quadrangle can make convex_giou crash. Here is an example that reproduces the convex_giou error:
In this example, the target quadrangle intersects the boundary of the image, and the error only occurs when all convex points are at the bottom-right corner of the image; even a convex whose points are all 0 does not trigger it. I don't know why this happens, and it's very weird. In this case, simply adding small random noise with std=1e-3 to every convex point avoids the error.
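The noise workaround (from open-mmlab/mmcv#2407, with std=1e-3 as reported above) amounts to an independent Gaussian perturbation of each coordinate. On the real tensors this would presumably be something like `pts + torch.randn_like(pts) * 1e-3` applied before calling `convex_giou` or `min_area_polygons`. The helper `jitter_points` below is hypothetical and pure Python, only to show the idea; a fixed seed makes the perturbation reproducible. Note that the noise slightly changes the geometry the loss sees.

```python
import random
from typing import List, Optional, Tuple

Point = Tuple[float, float]


def jitter_points(points: List[Point], std: float = 1e-3,
                  seed: Optional[int] = None) -> List[Point]:
    """Add i.i.d. Gaussian noise (mean 0, standard deviation `std`) to every coordinate.

    Presumably this helps because it breaks exact ties (coincident or exactly
    collinear points) that the hull code handles poorly.
    """
    rng = random.Random(seed)
    return [(x + rng.gauss(0.0, std), y + rng.gauss(0.0, std)) for x, y in points]


# Nine points collapsed onto the bottom-right corner of a 1024x1024 image.
corner = [(1024.0, 1024.0)] * 9
print(jitter_points(corner, std=1e-3, seed=0))
```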
@sltlls In my case, it works for the min_area_polygons test.
And for the convex_giou test, I found the same error.
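The error occurred when judging the collinearity of the points. For example, as shown in the figure, the value of

```cpp
// Excerpt (as quoted in this thread) from the gift-wrapping convex-hull step
// in mmcv's convex_giou CUDA code: keep wrapping until max_index is reached.
while (k_index != max_index) {
  p_k = p_max;
  k_index = max_index;
  for (int i = 1; i < n_poly; i++) {
    // Orientation of in_poly[i] relative to the last hull vertex and p_k.
    sign = cross(in_poly[Stack[top2]], in_poly[i], p_k);
    // Take point i if sign < 0, or if it is exactly collinear (sign == 0)
    // and farther away. This exact comparison is where collinearity is judged.
    if ((sign < 0) || ((sign == 0) && (dis(in_poly[Stack[top2]], in_poly[i]) >
                                       dis(in_poly[Stack[top2]], p_k)))) {
      p_k = in_poly[i];
      k_index = i;
    }
  }
  // Push the selected vertex onto the hull stack.
  top2++;
  Stack[top2] = k_index;
}
```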
In this case, the convex hull of
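To make the role of the `sign == 0` test concrete, here is a hypothetical Python sketch of a full gift-wrapping (Jarvis march) hull with the same structure as the loop quoted above, using its own cross-product convention rather than mmcv's. It adds one thing the excerpt does not have: a bound check before pushing a new vertex. One plausible (unconfirmed) reading of the crash is that rounding makes the orientation tests disagree, the wrap never reaches its target index, and the stack index runs past the end of a fixed-size array; in the sketch that situation raises an exception instead.

```python
from typing import List, Tuple

Point = Tuple[float, float]


def cross(o: Point, a: Point, b: Point) -> float:
    """z-component of (a - o) x (b - o); > 0 when b lies left of the ray o -> a."""
    return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])


def dist2(a: Point, b: Point) -> float:
    """Squared Euclidean distance."""
    return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2


def jarvis_march(points: List[Point]) -> List[int]:
    """Counter-clockwise hull vertex indices by gift wrapping.

    Assumes pairwise-distinct points. Raises instead of growing the hull past
    the number of input points if the orientation tests are inconsistent.
    """
    n = len(points)
    if n < 3:
        raise ValueError("need at least three points")
    # The lowest point (smallest x on ties) is always a hull vertex.
    start = min(range(n), key=lambda i: (points[i][1], points[i][0]))
    hull = [start]
    current = start
    while True:
        k = (current + 1) % n  # arbitrary initial candidate
        for i in range(n):
            if i == current or i == k:
                continue
            sign = cross(points[current], points[k], points[i])
            # Point i is right of current -> k, or exactly collinear and farther:
            # it becomes the candidate. The exact `sign == 0` comparison is the
            # rounding-sensitive collinearity test.
            if sign < 0 or (sign == 0 and dist2(points[current], points[i]) >
                            dist2(points[current], points[k])):
                k = i
        if k == start:
            return hull
        hull.append(k)
        # A hull can never have more vertices than input points (hypothetical guard).
        if len(hull) > n:
            raise RuntimeError("gift wrapping did not close; orientation tests disagree")
        current = k


square = [(0.0, 0.0), (2.0, 0.0), (2.0, 2.0), (0.0, 2.0), (1.0, 1.0), (1.0, 0.0)]
print(jarvis_march(square))  # [0, 1, 2, 3]
```

The interior point `(1, 1)` and the collinear point `(1, 0)` on the bottom edge are both skipped; the collinear one is dropped precisely by the `sign == 0` distance tie-break.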
Fun fact: when the dtype of pred_pt and target is set to torch.float64, the convex_giou test passes. Maybe this is another feasible solution? @DapengFeng
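A small illustration of why float64 helps near coordinates like 1024 (it does not reproduce the crash itself): float32 has far coarser spacing there, so small offsets between nearly coincident or nearly collinear points can vanish before any orientation test runs. The helper `to_float32` below is hypothetical and only emulates float32 rounding with the standard `struct` module.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (an IEEE double) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


# For values in [1024, 2048), float32 spacing is 2**-13 (about 1.2e-4),
# while float64 spacing in the same range is 2**-42.
x = 1024.0 + 1e-5
print(to_float32(x) == 1024.0)  # True: the 1e-5 offset is lost in float32
print(x == 1024.0)              # False: float64 still resolves it
```

This matches the reply that a wider dtype relieves the instability: it pushes the problem to much smaller offsets but does not remove the exact-equality test.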
Changing the data type relieves the numerical instability but does not solve the problem. Pre-sorting the points guarantees that each newly added point always lies outside the convex hull formed so far. Here is an interesting website about convex hulls.
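Pre-sorting is the core of Andrew's monotone chain algorithm: after sorting by (x, y), each new point extends the chain built so far, and each point is pushed at most once per chain, so termination and output size are bounded by the input size regardless of rounding. The sketch below is a generic pure-Python version, offered as an illustration of the suggestion above rather than as the fix adopted in mmcv.

```python
from typing import List, Tuple

Point = Tuple[float, float]


def cross(o: Point, a: Point, b: Point) -> float:
    """z-component of (a - o) x (b - o); > 0 means o -> a -> b turns left."""
    return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])


def monotone_chain_hull(points: List[Point]) -> List[Point]:
    """Convex hull via Andrew's monotone chain, counter-clockwise, no collinear vertices."""
    # Pre-sort by (x, y) and drop exact duplicates.
    pts = sorted(set(points))
    if len(pts) <= 2:
        return pts
    lower: List[Point] = []
    for p in pts:
        # Pop while the last two points and p do not make a strict left turn.
        while len(lower) >= 2 and cross(lower[-2], lower[-1], p) <= 0:
            lower.pop()
        lower.append(p)
    upper: List[Point] = []
    for p in reversed(pts):
        while len(upper) >= 2 and cross(upper[-2], upper[-1], p) <= 0:
            upper.pop()
        upper.append(p)
    # The last point of each chain is the first point of the other chain.
    return lower[:-1] + upper[:-1]


square = [(0.0, 0.0), (2.0, 0.0), (2.0, 2.0), (0.0, 2.0), (1.0, 1.0), (1.0, 0.0)]
print(monotone_chain_hull(square))  # [(0.0, 0.0), (2.0, 0.0), (2.0, 2.0), (0.0, 2.0)]
```

Degenerate inputs such as nine copies of `(1024.0, 1024.0)` simply collapse to a single point instead of entering an unbounded loop.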
Prerequisite
Task
I'm using the official example scripts/configs for the officially supported tasks/models/datasets.
Branch
1.x branch https://github.com/open-mmlab/mmrotate/tree/1.x
Environment
sys.platform: linux
Python: 3.8.13 (default, Mar 28 2022, 11:38:47) [GCC 7.5.0]
CUDA available: True
numpy_random_seed: 2147483648
GPU 0,1,2,3: TITAN X (Pascal)
CUDA_HOME: /usr/local/cuda-10.1
NVCC: Cuda compilation tools, release 10.1, V10.1.10
GCC: gcc (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
PyTorch: 1.7.1+cu101
PyTorch compiling details: PyTorch built with:
TorchVision: 0.8.2+cu101
OpenCV: 4.6.0
MMEngine: 0.3.0
MMRotate: 1.0.0rc0+
Reproduces the problem - code sample
python tools/train.py configs/cfa/cfa-qbox_r50_fpn_1x_dota.py
Reproduces the problem - command or script
CUDA_LAUNCH_BLOCKING=1 python tools/train.py configs/cfa/cfa-qbox_r50_fpn_1x_dota.py
Reproduces the problem - error message
Result has been saved to /media/amax/partion2/lsy/work_dir/mmlabV2/mmrotate/cfa/cfa_qbox_r50_fpn_1x_dota1/modules_statistic_results.json
11/13 14:51:42 - mmengine - INFO - Distributed training is not used, all SyncBatchNorm (SyncBN) layers in the model will be automatically reverted to BatchNormXd layers if they are used.
11/13 14:51:46 - mmengine - WARNING - Failed to search registry with scope "mmrotate" in the "optim_wrapper" registry tree. As a workaround, the current "optim_wrapper" registry in "mmengine" is used to build instance. This may cause unexpected failure when running the built modules. Please check whether "mmrotate" is a correct scope, or whether the registry is initialized.
fatal: Not a git repository (or any parent up to mount point /media/amax/partion2)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).
11/13 14:51:47 - mmengine - INFO - load model from: torchvision://resnet50
11/13 14:51:47 - mmengine - INFO - torchvision loads checkpoint from path: torchvision://resnet50
11/13 14:51:48 - mmengine - WARNING - The model and loaded state dict do not match exactly
unexpected key in source state_dict: fc.weight, fc.bias
11/13 14:51:48 - mmengine - INFO - Checkpoints will be saved to /media/amax/partion2/lsy/work_dir/mmlabV2/mmrotate/cfa/cfa_qbox_r50_fpn_1x_dota1.
/media/amax/partion2/lsy/workspace/mmlab_V2/mmrotate/mmrotate/structures/bbox/quadri_boxes.py:146: UserWarning: The `clip` function does nothing in `QuadriBoxes`.
warnings.warn('The `clip` function does nothing in `QuadriBoxes`.')
/media/amax/partion2/lsy/workspace/mmlab_V2/mmrotate/mmrotate/structures/bbox/quadri_boxes.py:146: UserWarning: The `clip` function does nothing in `QuadriBoxes`.
warnings.warn('The `clip` function does nothing in `QuadriBoxes`.')
Traceback (most recent call last):
File "tools/train.py", line 122, in
main()
File "tools/train.py", line 118, in main
runner.train()
File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmengine/runner/runner.py", line 1661, in train
model = self.train_loop.run() # type: ignore
File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmengine/runner/loops.py", line 90, in run
self.run_epoch()
File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmengine/runner/loops.py", line 106, in run_epoch
self.run_iter(idx, data_batch)
File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmengine/runner/loops.py", line 122, in run_iter
outputs = self.runner.model.train_step(
File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 114, in train_step
losses = self._run_forward(data, mode='loss') # type: ignore
File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 320, in _run_forward
results = self(**data, mode=mode)
File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/media/amax/partion2/lsy/workspace/mmlab_V2/mmdetection3.x/mmdet/models/detectors/base.py", line 92, in forward
return self.loss(inputs, data_samples)
File "/media/amax/partion2/lsy/workspace/mmlab_V2/mmdetection3.x/mmdet/models/detectors/single_stage.py", line 78, in loss
losses = self.bbox_head.loss(x, batch_data_samples)
File "/media/amax/partion2/lsy/workspace/mmlab_V2/mmdetection3.x/mmdet/models/dense_heads/base_dense_head.py", line 123, in loss
losses = self.loss_by_feat(*loss_inputs)
File "/media/amax/partion2/lsy/workspace/mmlab_V2/mmrotate/mmrotate/models/dense_heads/cfa_head.py", line 186, in loss_by_feat
pos_losses_list, = multi_apply(self.get_pos_loss, cls_scores,
File "/media/amax/partion2/lsy/workspace/mmlab_V2/mmdetection3.x/mmdet/models/utils/misc.py", line 218, in multi_apply
return tuple(map(list, zip(*map_results)))
File "/media/amax/partion2/lsy/workspace/mmlab_V2/mmrotate/mmrotate/models/dense_heads/cfa_head.py", line 379, in get_pos_loss
loss_bbox = self.loss_bbox_refine(
File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/media/amax/partion2/lsy/workspace/mmlab_V2/mmrotate/mmrotate/models/losses/convex_giou_loss.py", line 111, in forward
loss = self.loss_weight * convex_giou_loss(
File "/media/amax/partion2/lsy/workspace/mmlab_V2/mmrotate/mmrotate/models/losses/convex_giou_loss.py", line 37, in forward
convex_gious, grad = convex_giou(pred, target)
File "/home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmcv/ops/convex_iou.py", line 28, in convex_giou
ext_module.convex_giou(pointsets, polygons, output)
RuntimeError: CUDA error: an illegal memory access was encountered
Exception raised from ConvexGIoUCUDAKernelLauncher at /tmp/mmcv/mmcv/ops/csrc/pytorch/cuda/convex_iou.cu:40 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f6faf2928b2 in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: ConvexGIoUCUDAKernelLauncher(at::Tensor, at::Tensor, at::Tensor) + 0x1d6 (0x7f6f8ccb85c3 in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmcv/_ext.cpython-38-x86_64-linux-gnu.so)
frame #2: convex_giou_cuda(at::Tensor, at::Tensor, at::Tensor) + 0x67 (0x7f6f8ccc8417 in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmcv/_ext.cpython-38-x86_64-linux-gnu.so)
frame #3: auto Dispatch<DeviceRegistry<void (*)(at::Tensor, at::Tensor, at::Tensor), &(convex_giou_impl(at::Tensor, at::Tensor, at::Tensor))>, at::Tensor const&, at::Tensor const&, at::Tensor&>(DeviceRegistry<void (*)(at::Tensor, at::Tensor, at::Tensor), &(convex_giou_impl(at::Tensor, at::Tensor, at::Tensor))> const&, char const*, at::Tensor const&, at::Tensor const&, at::Tensor&) + 0x802 (0x7f6f8cbfe182 in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmcv/_ext.cpython-38-x86_64-linux-gnu.so)
frame #4: convex_giou(at::Tensor, at::Tensor, at::Tensor) + 0x67 (0x7f6f8cbfbdd7 in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmcv/_ext.cpython-38-x86_64-linux-gnu.so)
frame #5: + 0x330513 (0x7f6f8ce8a513 in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmcv/_ext.cpython-38-x86_64-linux-gnu.so)
frame #6: + 0x350190 (0x7f6f8ceaa190 in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmcv/_ext.cpython-38-x86_64-linux-gnu.so)
frame #7: + 0x34de4e (0x7f6f8cea7e4e in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/mmcv/_ext.cpython-38-x86_64-linux-gnu.so)
frame #16: THPFunction_apply(_object*, _object*) + 0x93d (0x7f6ff9a732dd in /home/amax/anaconda3/envs/openmmlab2/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
Additional information
I'm using the DOTA v1.0 dataset, and this error occurs during training.