Skip to content

Commit

Permalink
Fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
yezhen17 committed Oct 16, 2021
1 parent 8af21a4 commit c01bc22
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
5 changes: 3 additions & 2 deletions sunrgbd/sunrgbd_ssl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def __getitem__(self, idx):

class SunrgbdSSLUnlabeledDataset(Dataset):
def __init__(self, labeled_sample_list=None, num_points=20000, use_color=False, use_height=False, use_v1=False,
aug_num=1, scan_idx_list=None, load_labels=None):
aug_num=1, scan_idx_list=None, load_labels=None, augment=True):
print('----------------Sunrgbd Unlabeled Dataset Initialization----------------')
if use_v1:
self.data_path = os.path.join(ROOT_DIR, 'sunrgbd/sunrgbd_pc_bbox_votes_50k_v1_train')
Expand Down Expand Up @@ -211,6 +211,7 @@ def __init__(self, labeled_sample_list=None, num_points=20000, use_color=False,
self.use_height = use_height
self.aug_num = aug_num
self.load_labels = load_labels
self.augment = augment
if load_labels:
print('Warning! Loading labels for analysis')

Expand Down Expand Up @@ -308,4 +309,4 @@ def __getitem__(self, idx):
ret_dict['scale'] = np.array(scale_ratio).astype(np.float32)
ret_dict['scan_idx'] = np.array(idx).astype(np.int64)
ret_dict['supervised_mask'] = np.array(0).astype(np.int64)
return ret_dict
return ret_dict
6 changes: 4 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def my_worker_init_fn(worker_id):
use_color=FLAGS.use_color,
use_height=(not FLAGS.no_height),
use_v1=(not FLAGS.use_sunrgbd_v2),
load_labels=FLAGS.view_stats)
load_labels=FLAGS.view_stats,
augment=True)
TEST_DATASET = SunrgbdDetectionVotesDataset('val',
num_points=NUM_POINT, augment=False,
use_color=FLAGS.use_color, use_height=(not FLAGS.no_height),
Expand All @@ -142,7 +143,8 @@ def my_worker_init_fn(worker_id):
num_points=NUM_POINT,
use_color=FLAGS.use_color,
use_height=(not FLAGS.no_height),
load_labels=FLAGS.view_stats)
load_labels=FLAGS.view_stats,
augment=True)
TEST_DATASET = ScannetDetectionDataset('val',
num_points=NUM_POINT, augment=False,
use_color=FLAGS.use_color, use_height=(not FLAGS.no_height))
Expand Down

0 comments on commit c01bc22

Please # to comment.