Skip to content

Commit 917b8f2

Browse files
committed
optimizer updated
1 parent 8397eb6 commit 917b8f2

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/train_eval_egnn.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,10 @@ def pose_loss(pred_quaternion, pred_translation, gt_pose, delta=1.5):
692692
gt_translation = gt_pose[:3, 3]
693693
gt_rotation = gt_pose[:3, :3]
694694
gt_quaternion = rotation_matrix_to_quaternion(gt_rotation) # Convert [3x3] to [4]
695-
695+
# Normalize the ground truth quaternion
696+
gt_quaternion = F.normalize(gt_quaternion, p=2, dim=-1)
697+
# Normalize the predicted quaternion
698+
pred_quaternion = F.normalize(pred_quaternion, p=2, dim=-1)
696699
# Convert predicted quaternion to rotation matrix
697700
pred_rotation = quaternion_to_matrix(pred_quaternion, device=pred_quaternion.device)
698701

@@ -792,6 +795,7 @@ def train_one_epoch(model, dataloader, optimizer, device, epoch, writer, use_poi
792795

793796
# # Backward pass and optimization step
794797
loss.backward()
798+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
795799
optimizer.step()
796800

797801
running_loss += loss.item()
@@ -1017,7 +1021,9 @@ def train_model(model, train_loader, val_loader, num_epochs, learning_rate, devi
10171021
log_interval (int): Interval for logging the training progress.
10181022
save_path (str): Path to save the model checkpoints.
10191023
"""
1020-
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
1024+
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
1025+
# optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
1026+
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
10211027
best_val_loss = float('inf') # Track the best validation loss
10221028

10231029
# If using PointNet encoder, initialize it separately
@@ -1157,7 +1163,7 @@ def get_args():
11571163
# Add arguments with default values
11581164
parser.add_argument('--base_dir', type=str, default='/home/eavise3d/3DMatch_FCGF_Feature_32_transform', help='Path to the dataset')
11591165
parser.add_argument('--batch_size', type=int, default=1, help='Batch size for training')
1160-
parser.add_argument('--learning_rate', type=float, default=0.0001, help='Learning rate for the optimizer')
1166+
parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate for the optimizer')
11611167
parser.add_argument('--num_epochs', type=int, default=500, help='Number of epochs for training')
11621168
parser.add_argument('--num_node', type=int, default=2048, help='Number of nodes in the graph')
11631169
parser.add_argument('--k', type=int, default=12, help='Number of nearest neighbors in KNN graph')

0 commit comments

Comments
 (0)