Commit

improve heart_train.py
ChenS676 committed Jun 16, 2024
1 parent d874659 commit cbaaaaa
Showing 2 changed files with 2 additions and 2 deletions.
core/gcns/final_gnn_tune.py (2 changes: 1 addition & 1 deletion)
@@ -287,7 +287,7 @@ def project_main(): # sourcery skip: avoid-builtin-shadow, low-code-quality
 print_logger.info(f'Num parameters: {cfg.model.params}')

 optimizer = create_optimizer(model, cfg)
-scheduler = LinearDecayLR(optimizer, start_lr=0.01, end_lr=0.001, num_epochs=cfg.train.epochs)
+scheduler = LinearDecayLR(optimizer, start_lr=cfg.optimizer.base_lr, end_lr=cfg.optimizer.base_lr/10, num_epochs=cfg.train.epochs)

 if cfg.train.finetune:
     model = init_model_from_pretrained(model, cfg.train.finetune,
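The change above replaces the hardcoded start_lr=0.01 / end_lr=0.001 with values derived from cfg.optimizer.base_lr, so the decay range follows whatever base learning rate the config specifies and always ends at one tenth of it after cfg.train.epochs. The repository's LinearDecayLR implementation is not shown in this diff; the sketch below is only an assumption of how such a scheduler could behave (linear interpolation of the optimizer's learning rate from start_lr to end_lr, with a scalar get_lr() as the calling code expects), not the actual class.

# Hypothetical sketch, not the repository's LinearDecayLR implementation.
# Linearly interpolates the optimizer's learning rate from start_lr to
# end_lr over num_epochs, stepping once per epoch.
class LinearDecayLR:
    def __init__(self, optimizer, start_lr, end_lr, num_epochs):
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.end_lr = end_lr
        self.num_epochs = num_epochs
        self.epoch = 0
        self._set_lr(start_lr)

    def _set_lr(self, lr):
        for group in self.optimizer.param_groups:
            group["lr"] = lr

    def get_lr(self):
        # Single current value, matching how the training code logs it.
        return self.optimizer.param_groups[0]["lr"]

    def step(self):
        # Move one epoch along the linear schedule towards end_lr.
        self.epoch = min(self.epoch + 1, self.num_epochs)
        frac = self.epoch / max(self.num_epochs, 1)
        self._set_lr(self.start_lr + frac * (self.end_lr - self.start_lr))

With the new call, a config base_lr of 0.01 reproduces the old hardcoded 0.01 -> 0.001 schedule, while any other base_lr scales the whole decay range accordingly.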
core/graphgps/train/heart_train.py (2 changes: 1 addition & 1 deletion)
@@ -402,7 +402,7 @@ def train(self):
 wandb.log({"lr": self.scheduler.get_lr()}, step=self.step)

 self.tensorboard_writer.add_scalar("Loss/train", loss, epoch)
-
+self.tensorboard_writer.add_scalar("LR/train", self.scheduler.get_lr(), epoch)
 if epoch % int(self.report_step) == 0:

     self.results_rank = self.merge_result_rank()
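This change adds the current learning rate as a second TensorBoard scalar next to the training loss, mirroring what was already being sent to wandb. A self-contained sketch of the same logging pattern is below; the toy model, optimizer, and LambdaLR schedule are illustrative assumptions, not code from this repository, and note that the built-in torch schedulers expose get_last_lr() (a list) rather than the scalar get_lr() this repository's scheduler appears to return.

# Standalone sketch of per-epoch loss and learning-rate logging with
# TensorBoard. Model, data, and scheduler here are placeholders.
import torch
from torch.utils.tensorboard import SummaryWriter

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
num_epochs = 10
# Linear decay from lr to lr/10 over num_epochs, approximating the schedule above.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda e: 1.0 - 0.9 * min(e, num_epochs) / num_epochs
)
writer = SummaryWriter(log_dir="runs/lr_logging_demo")

for epoch in range(num_epochs):
    x, y = torch.randn(32, 4), torch.randn(32, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    writer.add_scalar("Loss/train", loss.item(), epoch)
    # get_last_lr() returns one value per parameter group; log the first.
    writer.add_scalar("LR/train", scheduler.get_last_lr()[0], epoch)
    scheduler.step()

writer.close()

Once logged, the LR/train curve shows up in TensorBoard alongside Loss/train, which makes it easy to check that the decay actually reaches the intended end value.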
