From 2179ae303976465557bdf9625a50aab9d05a2fd4 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Tue, 10 May 2022 07:58:53 -0700
Subject: [PATCH] [trainer] sharded _load_best_model (#17150)

* [trainer] sharded _load_best_model

probably needs a test?

* undo delete
---
 src/transformers/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index aa54f2af1bb566..dda2784718112d 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1705,7 +1705,7 @@ def _load_best_model(self):
                 # If the model is on the GPU, it still works!
                 load_result = self.model.load_state_dict(state_dict, strict=False)
                 self._issue_warnings_after_load(load_result)
-        elif os.path.exists(best_model_path, os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
+        elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
            # Best model is a sharded checkpoint
             load_result = load_sharded_checkpoint(self.model, self.state.best_model_checkpoint, strict=False)
             self._issue_warnings_after_load(load_result)
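
Note on the fix: os.path.exists() takes exactly one path argument, so the
removed line raised "TypeError: exists() takes 1 positional argument but 2
were given" whenever the best model was not a single-file checkpoint, i.e.
precisely in the sharded case this branch is meant to handle. The patch
keeps only the second path: the presence of the shard index file is what
identifies a sharded checkpoint.

Below is a minimal standalone sketch of the corrected dispatch, under these
assumptions: pick_checkpoint_kind is a hypothetical helper for illustration
(not Trainer API), and the constant values are assumed to mirror the
transformers defaults at the time of this patch.

    import os

    # Assumed values of the transformers file-name constants at the time
    # of this patch; the real ones live in the transformers package.
    WEIGHTS_NAME = "pytorch_model.bin"
    WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"

    def pick_checkpoint_kind(best_model_checkpoint: str) -> str:
        # Hypothetical helper: returns a label for the checkpoint layout
        # instead of loading weights, mirroring the patched elif chain.
        best_model_path = os.path.join(best_model_checkpoint, WEIGHTS_NAME)
        if os.path.exists(best_model_path):
            # Single-file checkpoint: loaded via torch.load followed by
            # model.load_state_dict(state_dict, strict=False).
            return "single-file"
        elif os.path.exists(os.path.join(best_model_checkpoint, WEIGHTS_INDEX_NAME)):
            # Sharded checkpoint: an index JSON maps parameter names to
            # shard files; transformers loads it with load_sharded_checkpoint.
            return "sharded"
        return "missing"

The pre-patch call os.path.exists(best_model_path, os.path.join(...)) was
presumably a leftover from an edit that merged two checks; since exists()
is single-argument, only dropping the first argument makes the branch
reachable.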