diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index aa54f2af1bb566..dda2784718112d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1705,7 +1705,7 @@ def _load_best_model(self): # If the model is on the GPU, it still works! load_result = self.model.load_state_dict(state_dict, strict=False) self._issue_warnings_after_load(load_result) - elif os.path.exists(best_model_path, os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): + elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): # Best model is a sharded checkpoint load_result = load_sharded_checkpoint(self.model, self.state.best_model_checkpoint, strict=False) self._issue_warnings_after_load(load_result)