Cannot find pseudo label for frame #9

Closed
hughjazzman opened this issue Jul 21, 2021 · 17 comments

Comments

@hughjazzman

I am getting an error when running train.py; it seems to be related to PSEUDO_LABEL not being updated. The traceback repeats for multiple frames, not just 002080 as shown below. I've also put the full output in this gist in case the information below is not enough. Am I missing something? Thanks for any help!

Commands Run

$ NUM_GPUS=8
$ CONFIG_FILE=cfgs/da-waymo-kitti_models/pvrcnn_st3d/pvrcnn_st3d.yaml
$ bash scripts/dist_train.sh ${NUM_GPUS} --cfg_file ${CONFIG_FILE}

Error

[2021-07-21 15:05:09,022  train.py 168  INFO]  **********************Start training da-waymo-kitti_models/pvrcnn_st3d/pvrcnn_st3d(default)**********************
generate_ps_e0: 100%|████████████████████| 232/232 [03:14<00:00,  1.19it/s, pos_ps_box=0.000(0.000), ign_ps_box=15.000(14.899)]
Traceback (most recent call last):                                                                                             
  File "train.py", line 199, in <module>
    main()
  File "train.py", line 191, in main
    ema_model=None
  File "/home/user5/open-mmlab/ST3D/tools/train_utils/train_st_utils.py", line 157, in train_model_st
    dataloader_iter=dataloader_iter, ema_model=ema_model
  File "/home/user5/open-mmlab/ST3D/tools/train_utils/train_st_utils.py", line 42, in train_one_epoch_st
    target_batch = next(dataloader_iter)
  File "/home/user5/anaconda3/envs/st3d7/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 582, in __next__
    return self._process_next_batch(batch)
  File "/home/user5/anaconda3/envs/st3d7/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 608, in _process_next_batch
    raise batch.exc_type(batch.exc_msg)
ValueError: Traceback (most recent call last):
  File "/home/user5/anaconda3/envs/st3d7/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/user5/anaconda3/envs/st3d7/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 99, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/user5/open-mmlab/ST3D/tools/../pcdet/datasets/kitti/kitti_dataset.py", line 413, in __getitem__
    self.fill_pseudo_labels(input_dict)
  File "/home/user5/open-mmlab/ST3D/tools/../pcdet/datasets/dataset.py", line 146, in fill_pseudo_labels
    gt_boxes = self_training_utils.load_ps_label(input_dict['frame_id'])
  File "/home/user5/open-mmlab/ST3D/tools/../pcdet/utils/self_training_utils.py", line 221, in load_ps_label
    raise ValueError('Cannot find pseudo label for frame: %s' % frame_id)
ValueError: Cannot find pseudo label for frame: 002080

epochs:   0%|                                                                                           | 0/30 [04:05<?, ?it/s]

Environment

Python 3.7
CUDA 10.0
PyTorch 1.1
spconv 1.0 (commit 8da6f96)
pcdet 0.2.0+73dda8c

@jihanyang
Member

Please follow GETTING_START.md. You need to pretrain a detector first.

@hughjazzman
Author

Thank you.

@hughjazzman
Author

hughjazzman commented Aug 2, 2021

I am getting the exact same error even when using the PRETRAINED_MODEL. How can I fix this?

$ CONFIG_FILE=cfgs/da-waymo-kitti_models/pvrcnn_st3d/pvrcnn_st3d.yaml
$ PRETRAINED_MODEL=../output/da-waymo-kitti_models/pvrcnn/pvrcnn_old_anchor/default/ckpt/checkpoint_epoch_25.pth
$ bash scripts/dist_train.sh ${NUM_GPUS} --cfg_file ${CONFIG_FILE} --pretrained_model ${PRETRAINED_MODEL}
Traceback (most recent call last):                                                                                          
  File "train.py", line 199, in <module>
    main()
  File "train.py", line 169, in main
    train_func(
  File "/home/user5/open-mmlab/ST3D/tools/train_utils/train_st_utils.py", line 150, in train_model_st
    accumulated_iter = train_one_epoch_st(
  File "/home/user5/open-mmlab/ST3D/tools/train_utils/train_st_utils.py", line 42, in train_one_epoch_st
    target_batch = next(dataloader_iter)
  File "/home/user5/anaconda3/envs/st3d/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/home/user5/anaconda3/envs/st3d/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 856, in _next_data
    return self._process_data(data)
  File "/home/user5/anaconda3/envs/st3d/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
    data.reraise()
  File "/home/user5/anaconda3/envs/st3d/lib/python3.8/site-packages/torch/_utils.py", line 394, in reraise
    raise self.exc_type(msg)
ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/user5/anaconda3/envs/st3d/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/user5/anaconda3/envs/st3d/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/user5/anaconda3/envs/st3d/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/user5/open-mmlab/ST3D/tools/../pcdet/datasets/kitti/kitti_dataset.py", line 413, in __getitem__
    self.fill_pseudo_labels(input_dict)
  File "/home/user5/open-mmlab/ST3D/tools/../pcdet/datasets/dataset.py", line 146, in fill_pseudo_labels
    gt_boxes = self_training_utils.load_ps_label(input_dict['frame_id'])
  File "/home/user5/open-mmlab/ST3D/tools/../pcdet/utils/self_training_utils.py", line 221, in load_ps_label
    raise ValueError('Cannot find pseudo label for frame: %s' % frame_id)
ValueError: Cannot find pseudo label for frame: 002829

hughjazzman reopened this Aug 2, 2021
@jihanyang
Member

Could you provide the training log that shows your config?

@hughjazzman
Author

Here is the logfile:
log_train_20210802-131316.txt

@jihanyang
Member

Everything looks fine. Can you find the ps_label folder under the corresponding output folder? Also, have you checked how many pseudo labels were generated?

@hughjazzman
Author

Only one pseudo-label file was generated, ps_label_e0.pkl. Loading it:

>>> import pickle
>>> with open('ps_label_e0.pkl', 'rb') as f:
...     data = pickle.load(f)
...
>>> len(data)
3712
>>> data['002080'].keys()
dict_keys(['gt_boxes', 'cls_scores', 'iou_scores', 'memory_counter'])
>>> [len(data['002080'][x]) for x in data['002080'].keys()]
[43, 43, 43, 43]
>>> [data['002080'][x][0] for x in data['002080'].keys()]
[array([7.82266045, 5.05001545, 0.72913325, 4.59020233, 2.00834894,
       1.51216364, 4.59793329, 1.        , 0.84709507]), 0.9126919, 0.8470951, 0.0]
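A quick way to answer the "how many pseudo labels" question is to compare the keys of the pkl with the frame IDs the target dataset will request. The sketch below assumes, as the session above shows, that the pkl is a dict keyed by frame ID; `find_missing_frames` is a hypothetical helper, not part of ST3D. For reference, 3712 is the size of the standard KITTI train split.

```python
import pickle


def find_missing_frames(ps_label_path, expected_frame_ids):
    """Return the frame IDs that have no entry in a pseudo-label pickle.

    Hypothetical helper. Assumes the pickle holds a dict keyed by frame ID
    (e.g. '002080'), like the ps_label_e0.pkl inspected above.
    """
    with open(ps_label_path, "rb") as f:
        pseudo_labels = pickle.load(f)
    # Frames the dataset will ask for but the file does not contain.
    return sorted(set(expected_frame_ids) - set(pseudo_labels.keys()))
```

For example, passing the IDs listed in the KITTI `ImageSets/train.txt` file (path assumed from the usual OpenPCDet data layout) should return an empty list if the file covers every training frame, which would point the investigation away from the pkl itself.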

@jihanyang
Copy link
Member

Can you load frame 002829?

@hughjazzman
Copy link
Author

Yes.

>>> len(data['002829']['gt_boxes'])
36
>>> data['002829']['gt_boxes'][0]
array([-12.67259693,   2.17882085,   0.4152911 ,   4.12861061,
         1.84807122,   1.557392  ,   3.2007041 ,   1.        ,
         0.81863701])
>>> data['002829']['cls_scores']
array([0.93731874, 0.92046964, 0.9030588 , 0.9078988 , 0.8054646 ,
       0.80939907, 0.88196886, 0.85936695, 0.86304134, 0.8979211 ,
       0.8379814 , 0.6361919 , 0.71232146, 0.8505274 , 0.70933247,
       0.7508229 , 0.6692671 , 0.4446245 , 0.46104938, 0.15859666,
       0.3290353 , 0.1491799 , 0.24889277, 0.13674273, 0.13740459,
       0.11474051, 0.14078389, 0.14664578, 0.11770795, 0.25883213,
       0.11457415, 0.12707922, 0.13901637, 0.12712085, 0.16028087,
       0.18077461], dtype=float32)
>>> data['002829']['iou_scores']
array([0.818637  , 0.8166671 , 0.8074603 , 0.8054869 , 0.79433286,
       0.7942598 , 0.7878299 , 0.78539443, 0.7752307 , 0.76748055,
       0.75878924, 0.7525175 , 0.7474871 , 0.74367374, 0.7227673 ,
       0.7154487 , 0.67805743, 0.64385206, 0.61908954, 0.55265254,
       0.5234711 , 0.4574411 , 0.34837985, 0.33632967, 0.32025087,
       0.2999313 , 0.26421076, 0.24447218, 0.24083728, 0.21878265,
       0.21566562, 0.19372286, 0.19200402, 0.16185848, 0.1382765 ,
       0.10244821], dtype=float32)
>>> data['002829']['memory_counter']
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0.])

@jihanyang
Member

Frame 002829 is present in your pkl, yet the loader reports that it cannot be found, which is quite strange. You could try to debug it. Besides, you can set the INIT_PS key in the config to the path of the generated pseudo-label file and set UPDATE_PSEUDO_LABEL to [-1], so the pseudo labels are loaded directly from the pkl instead of being regenerated.
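The suggested workaround can be sketched roughly as follows. This is a hypothetical illustration, not ST3D's actual code: `prepare_pseudo_labels` is an invented name, and it assumes UPDATE_PSEUDO_LABEL lists the epoch indices at which pseudo labels are regenerated, so that [-1] matches no epoch and the INIT_PS file is used throughout.

```python
import os
import pickle


def prepare_pseudo_labels(self_train_cfg, num_epochs):
    """Hypothetical sketch of the INIT_PS / UPDATE_PSEUDO_LABEL behaviour.

    ASSUMPTION: UPDATE_PSEUDO_LABEL holds explicit epoch indices at which
    pseudo labels are regenerated; the real training loop may apply
    further conditions.
    """
    pseudo_labels = {}
    init_ps = self_train_cfg.get("INIT_PS")
    if init_ps and os.path.exists(init_ps):
        # Reuse a previously generated pkl instead of running the detector again.
        with open(init_ps, "rb") as f:
            pseudo_labels.update(pickle.load(f))
    # -1 is never a real epoch index, so [-1] disables regeneration entirely.
    regenerate_at = [e for e in range(num_epochs)
                     if e in self_train_cfg["UPDATE_PSEUDO_LABEL"]]
    return pseudo_labels, regenerate_at
```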

@AndyYuan96

Training on one GPU works normally. For 8 GPUs, you need to change the code: use the multiprocessing library, replace dict() with Manager().dict(), and then carefully adapt the code that saves and loads the pseudo labels.
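Why Manager().dict() helps can be illustrated with a minimal standalone sketch. It is not ST3D code: the function names are hypothetical, and it assumes the fork start method, which on Linux is how DataLoader workers were started for the Python/PyTorch versions in this thread. A worker forked before the parent fills a plain dict keeps its own stale, empty copy, while a Manager dict is a proxy to a single dict held by a manager process, so the later insertion is visible to the worker.

```python
import multiprocessing as mp

FRAME_ID = "002829"


def worker_lookup(labels, ready, results):
    # Runs in the child: wait until the parent has stored the pseudo label.
    ready.wait()
    results.put(FRAME_ID in labels)


def frame_visible_in_worker(use_manager):
    """Return True if a worker started *before* the labels were stored sees them."""
    ctx = mp.get_context("fork")  # POSIX only
    manager = ctx.Manager() if use_manager else None
    labels = manager.dict() if use_manager else {}
    ready, results = ctx.Event(), ctx.Queue()
    worker = ctx.Process(target=worker_lookup, args=(labels, ready, results))
    worker.start()  # a plain dict is copied into the child here, still empty
    labels[FRAME_ID] = {"gt_boxes": []}  # parent "generates" the pseudo label
    ready.set()
    visible = results.get(timeout=30)
    worker.join()
    if manager is not None:
        manager.shutdown()
    return visible


if __name__ == "__main__":
    print("plain dict:", frame_visible_in_worker(use_manager=False))        # False
    print("Manager().dict():", frame_visible_in_worker(use_manager=True))  # True
```

The trade-off is that every lookup through the proxy is an IPC round trip to the manager process, which is slower than a local dict read but keeps all workers consistent.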

@jihanyang
Copy link
Member

We haven't run into this problem in many 8-GPU training runs, but this could help others who do.

@hughjazzman
Copy link
Author

@AndyYuan96 Thank you. I managed to train, but the results did not improve much. Should I use pvrcnn_old_anchor_ros.yaml instead?

This is the best pretrained result for pvrcnn_old_anchor.yaml, at epoch 25:

Car AP@0.70, 0.70, 0.70:
bbox AP:93.5211, 85.1041, 85.0105
bev  AP:77.4258, 70.3797, 64.6284
3d   AP:31.4607, 29.8805, 28.0660
aos  AP:93.41, 84.80, 84.62
Car AP_R40@0.70, 0.70, 0.70:
bbox AP:94.5887, 86.8419, 85.3183
bev  AP:77.5128, 69.9266, 66.3514
3d   AP:26.0079, 24.7598, 22.9933
aos  AP:94.47, 86.51, 84.94
Car AP@0.70, 0.50, 0.50:
bbox AP:93.5211, 85.1041, 85.0105
bev  AP:94.9880, 89.3720, 90.0146
3d   AP:94.8230, 86.9519, 86.9732
aos  AP:93.41, 84.80, 84.62
Car AP_R40@0.70, 0.50, 0.50:
bbox AP:94.5887, 86.8419, 85.3183
bev  AP:96.2490, 92.4525, 92.8184
3d   AP:96.0851, 90.6961, 90.9766
aos  AP:94.47, 86.51, 84.94

Then, after training with pvrcnn_st3d.yaml using the above checkpoint as PRETRAINED_MODEL, this is the best result, at epoch 7:

Car AP@0.70, 0.70, 0.70:
bbox AP:94.5726, 85.9425, 85.7265
bev  AP:74.7032, 68.8260, 63.5809
3d   AP:26.4151, 25.9272, 25.8578
aos  AP:94.42, 85.62, 85.35
Car AP_R40@0.70, 0.70, 0.70:
bbox AP:95.1976, 88.2716, 86.5882
bev  AP:74.0632, 68.5121, 65.2104
3d   AP:21.1435, 21.6394, 19.8873
aos  AP:95.03, 87.93, 86.20
Car AP@0.70, 0.50, 0.50:
bbox AP:94.5726, 85.9425, 85.7265
bev  AP:96.6629, 92.4132, 92.8342
3d   AP:95.5956, 87.3461, 87.3515
aos  AP:94.42, 85.62, 85.35
Car AP_R40@0.70, 0.50, 0.50:
bbox AP:95.1976, 88.2716, 86.5882
bev  AP:97.6557, 94.3580, 94.3803
3d   AP:96.2199, 91.8536, 91.9323
aos  AP:95.03, 87.93, 86.20
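To put "not a big improvement" in numbers, the 3D AP_R40 rows above can be diffed per difficulty (Easy, Moderate, Hard). The values are copied from the two evaluation outputs; the snippet only does the subtraction. At the strict 0.70 IoU threshold, 3D AP actually dropped by about 3 to 5 points, while at 0.50 it rose by up to about 1.2 points.

```python
# 3D AP_R40 (Easy, Moderate, Hard) copied from the two evaluation logs above.
PRETRAINED = {
    "3d AP_R40@0.70": (26.0079, 24.7598, 22.9933),
    "3d AP_R40@0.50": (96.0851, 90.6961, 90.9766),
}
ST3D_EPOCH_7 = {
    "3d AP_R40@0.70": (21.1435, 21.6394, 19.8873),
    "3d AP_R40@0.50": (96.2199, 91.8536, 91.9323),
}


def ap_deltas(before, after):
    """Per-difficulty change (after - before), rounded to 4 decimals."""
    return {
        metric: tuple(round(a - b, 4) for b, a in zip(before[metric], after[metric]))
        for metric in before
    }


for metric, delta in ap_deltas(PRETRAINED, ST3D_EPOCH_7).items():
    print(metric, delta)
# 3d AP_R40@0.70 (-4.8644, -3.1204, -3.106)
# 3d AP_R40@0.50 (0.1348, 1.1575, 0.9557)
```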

@Liz66666

Liz66666 commented Aug 5, 2021

@hughjazzman I hit the same "Cannot find pseudo label for frame" error after the epoch-0 pseudo labels were generated, and it only occurs with distributed training. How did you fix this problem?

@hughjazzman
Author

hughjazzman commented Aug 5, 2021

@Liz66666 I followed @AndyYuan96's advice on using Manager().dict() for PSEUDO_LABEL in self_training_utils.py.

from multiprocessing import Manager

# Shared via a manager process, so forked DataLoader workers see later updates
PSEUDO_LABEL = Manager().dict()

Before doing this, I got the training to run by copying the pkl.load code from check_already_exsit_pseudo_label into load_ps_label, but I switched to the solution above and reran the training, as it should be the better fix.
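The earlier workaround, reloading from the pkl inside load_ps_label, might look roughly like this. It is a hypothetical sketch: the function name, the `ps_label_path` parameter and the module-level `PSEUDO_LABELS` cache are illustrative assumptions, not the exact ST3D code. A process whose in-memory copy is stale refills it from the pkl saved in the ps_label folder before giving up.

```python
import os
import pickle

# Per-process cache of pseudo labels keyed by frame ID (illustrative name).
PSEUDO_LABELS = {}


def load_ps_label_with_fallback(frame_id, ps_label_path):
    """Return the pseudo-label boxes for frame_id (hypothetical sketch).

    If this process's cache lacks the frame (e.g. a DataLoader worker holding
    a stale copy), reload the cache from the pkl on disk first.
    """
    if frame_id not in PSEUDO_LABELS and os.path.exists(ps_label_path):
        with open(ps_label_path, "rb") as f:
            PSEUDO_LABELS.update(pickle.load(f))
    if frame_id not in PSEUDO_LABELS:
        raise ValueError("Cannot find pseudo label for frame: %s" % frame_id)
    return PSEUDO_LABELS[frame_id]["gt_boxes"]
```

The drawback is that every worker process rereads the whole file when its cache is stale, and `ps_label_path` must always point at the latest epoch's pkl; a shared Manager dict avoids both.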

@Liz66666

Liz66666 commented Aug 5, 2021

@hughjazzman Thanks for your reply! It really works for me!

@AndyYuan96

Use the ROS or SN config, not pvrcnn_old_anchor.yaml.


4 participants