@@ -692,7 +692,10 @@ def pose_loss(pred_quaternion, pred_translation, gt_pose, delta=1.5):
692
692
gt_translation = gt_pose [:3 , 3 ]
693
693
gt_rotation = gt_pose [:3 , :3 ]
694
694
gt_quaternion = rotation_matrix_to_quaternion (gt_rotation ) # Convert [3x3] to [4]
695
-
695
+ # Normalize the ground truth quaternion
696
+ gt_quaternion = F .normalize (gt_quaternion , p = 2 , dim = - 1 )
697
+ # Normalize the predicted quaternion
698
+ pred_quaternion = F .normalize (pred_quaternion , p = 2 , dim = - 1 )
696
699
# Convert predicted quaternion to rotation matrix
697
700
pred_rotation = quaternion_to_matrix (pred_quaternion , device = pred_quaternion .device )
698
701
@@ -792,6 +795,7 @@ def train_one_epoch(model, dataloader, optimizer, device, epoch, writer, use_poi
792
795
793
796
# # Backward pass and optimization step
794
797
loss .backward ()
798
+ torch .nn .utils .clip_grad_norm_ (model .parameters (), max_norm = 1.0 )
795
799
optimizer .step ()
796
800
797
801
running_loss += loss .item ()
@@ -1017,7 +1021,9 @@ def train_model(model, train_loader, val_loader, num_epochs, learning_rate, devi
1017
1021
log_interval (int): Interval for logging the training progress.
1018
1022
save_path (str): Path to save the model checkpoints.
1019
1023
"""
1020
- optimizer = torch .optim .Adam (model .parameters (), lr = learning_rate )
1024
+ # optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
1025
+ # optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
1026
+ optimizer = torch .optim .SGD (model .parameters (), lr = learning_rate , momentum = 0.9 )
1021
1027
best_val_loss = float ('inf' ) # Track the best validation loss
1022
1028
1023
1029
# If using PointNet encoder, initialize it separately
@@ -1157,7 +1163,7 @@ def get_args():
1157
1163
# Add arguments with default values
1158
1164
parser .add_argument ('--base_dir' , type = str , default = '/home/eavise3d/3DMatch_FCGF_Feature_32_transform' , help = 'Path to the dataset' )
1159
1165
parser .add_argument ('--batch_size' , type = int , default = 1 , help = 'Batch size for training' )
1160
- parser .add_argument ('--learning_rate' , type = float , default = 0.0001 , help = 'Learning rate for the optimizer' )
1166
+ parser .add_argument ('--learning_rate' , type = float , default = 2e-5 , help = 'Learning rate for the optimizer' )
1161
1167
parser .add_argument ('--num_epochs' , type = int , default = 500 , help = 'Number of epochs for training' )
1162
1168
parser .add_argument ('--num_node' , type = int , default = 2048 , help = 'Number of nodes in the graph' )
1163
1169
parser .add_argument ('--k' , type = int , default = 12 , help = 'Number of nearest neighbors in KNN graph' )
0 commit comments