diff --git a/common/checkpointing/snapshot.py b/common/checkpointing/snapshot.py
index 2703efd..f20b25d 100644
--- a/common/checkpointing/snapshot.py
+++ b/common/checkpointing/snapshot.py
@@ -73,8 +73,8 @@ def restore(self, checkpoint: str) -> None:
       snapshot.restore(self.state)
       # we still need to ensure that extra_state has walltime in it
       self.state["extra_state"] = torchsnapshot.StateDict(step=self.step, walltime=0.0)
-
-    logging.info(f"Restored snapshot from {snapshot.path}. ({time.time() - start_time:.05}s")
+    else:
+      logging.info(f"Restored snapshot from {snapshot.path}. ({time.time() - start_time:.05}s")
 
   @classmethod
   def get_torch_snapshot(