[Unified Checkpoint] update non-merge checkpoint loading, move async_save_info.json location #9321

Merged · 25 commits · Oct 28, 2024

Commits
5451d31
[Unified checkpoint] update optimizer async save signal
DesmonDay Aug 21, 2024
4f0b61a
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
gongel Sep 4, 2024
68470aa
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
DesmonDay Sep 6, 2024
15e83e2
update paddlepaddle
DesmonDay Sep 6, 2024
6837b2f
split param
DesmonDay Sep 10, 2024
633d742
add save for split param
DesmonDay Sep 24, 2024
55186d7
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
DesmonDay Oct 9, 2024
b6aa309
fix save split_param
DesmonDay Oct 10, 2024
bf5d72b
add load uc split_param
DesmonDay Oct 11, 2024
9fdaae2
update uc files
DesmonDay Oct 14, 2024
9a210db
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
DesmonDay Oct 14, 2024
19071ef
update uc files
DesmonDay Oct 14, 2024
223e089
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
DesmonDay Oct 15, 2024
ae9ddce
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
DesmonDay Oct 16, 2024
4ab0df1
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
DesmonDay Oct 22, 2024
cbbc074
update split_param loading
DesmonDay Oct 24, 2024
7678fad
mkdir unified_checkpoint directory
DesmonDay Oct 25, 2024
238888d
rename file
DesmonDay Oct 25, 2024
780040e
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
DesmonDay Oct 25, 2024
b219ba6
update async handler
DesmonDay Oct 25, 2024
dbd13df
update files
DesmonDay Oct 25, 2024
c758d96
update async_save_info.json file place
DesmonDay Oct 28, 2024
e6db62e
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
DesmonDay Oct 28, 2024
2dd22ca
update load non-merge
DesmonDay Oct 28, 2024
fd5dea0
fix
DesmonDay Oct 28, 2024
12 changes: 7 additions & 5 deletions paddlenlp/trainer/trainer.py
@@ -2308,7 +2308,7 @@
         if output_dir is None:
             output_dir = self.args.output_dir

-        if PREFIX_CHECKPOINT_DIR in output_dir:
+        if PREFIX_CHECKPOINT_DIR in os.path.split(output_dir)[-1]:
             signal_dir = os.path.join(self.args.output_signal_dir, os.path.split(output_dir)[-1])
         else:
             signal_dir = self.args.output_signal_dir
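Note on this fix: the old check matched PREFIX_CHECKPOINT_DIR anywhere in the path, so a parent directory that merely contained the prefix would be treated as a checkpoint directory. A minimal sketch of the difference (directory names are hypothetical, and it assumes the prefix constant is the usual "checkpoint"):

import os

PREFIX_CHECKPOINT_DIR = "checkpoint"  # assumed value of the trainer's prefix constant

# Hypothetical output dir whose *parent* happens to contain the prefix.
output_dir = "/data/checkpoint_runs/run1"

assert PREFIX_CHECKPOINT_DIR in output_dir                         # old check: false positive
assert PREFIX_CHECKPOINT_DIR not in os.path.split(output_dir)[-1]  # new check: correct

# A genuine checkpoint directory still matches under the new check.
assert PREFIX_CHECKPOINT_DIR in os.path.split("/data/out/checkpoint-500")[-1]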
@@ -2606,7 +2606,7 @@
         # signal_dir is used for asynchronous saving situations.
         signal_dir = self.args.output_signal_dir
         if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
-            if PREFIX_CHECKPOINT_DIR in output_dir:
+            if PREFIX_CHECKPOINT_DIR in os.path.split(output_dir)[-1]:
                 signal_dir = os.path.join(signal_dir, os.path.split(output_dir)[-1])
             os.makedirs(signal_dir, exist_ok=True)
             logger.info(f"Saving model checkpoint finish signal to {signal_dir}")
@@ -2626,9 +2626,11 @@
                 "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim,
                 "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config,
             }
-            if os.path.exists(os.path.join(signal_dir, "async_save_info.json")):  # afs cannot overwrite
-                os.remove(os.path.join(signal_dir, "async_save_info.json"))
-            with open(os.path.join(signal_dir, "async_save_info.json"), "w") as f:
+            if os.path.exists(
+                os.path.join(self.args.output_signal_dir, "async_save_info.json")
+            ):  # afs cannot overwrite
+                os.remove(os.path.join(self.args.output_signal_dir, "async_save_info.json"))
+            with open(os.path.join(self.args.output_signal_dir, "async_save_info.json"), "w") as f:
                 json.dump(save_info, f)

         if self.args.should_save:
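This hunk is the "move async_save_info.json location" part of the PR: the marker file now always lives at the root output_signal_dir instead of inside each per-checkpoint signal_dir, so asynchronous saves overwrite a single file rather than leaving one per checkpoint. A rough sketch of the resulting behavior (paths and the save_info payload are illustrative, not the trainer's actual values):

import json
import os

# Illustrative stand-ins for self.args.output_signal_dir and a checkpoint signal dir.
output_signal_dir = "./signals"
signal_dir = os.path.join(output_signal_dir, "checkpoint-500")
os.makedirs(signal_dir, exist_ok=True)

save_info = {"world_size": 8, "ignore_save_lr_and_optim": False}  # hypothetical payload

# Before: ./signals/checkpoint-500/async_save_info.json (one per checkpoint)
# After:  ./signals/async_save_info.json (single file at the root)
info_path = os.path.join(output_signal_dir, "async_save_info.json")
if os.path.exists(info_path):  # some filesystems (e.g. AFS) cannot overwrite in place
    os.remove(info_path)
with open(info_path, "w") as f:
    json.dump(save_info, f)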
28 changes: 6 additions & 22 deletions paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -15,7 +15,6 @@
 import copy
 import json
 import os
-import sys

 import paddle
 from paddle.distributed import fleet
@@ -31,13 +30,10 @@
 from paddlenlp.transformers.model_utils import (
     PretrainedModel,
     _add_variant,
+    load_state_dict,
     unwrap_model,
 )
-from paddlenlp.transformers.utils import (
-    device_guard,
-    dtype_byte_size,
-    is_safetensors_available,
-)
+from paddlenlp.transformers.utils import dtype_byte_size
 from paddlenlp.utils.env import (
     LORA_WEIGHTS_NAME,
     PADDLE_MASTER_WEIGHTS_NAME,
@@ -56,12 +52,6 @@
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.nested import nested_copy

-if is_safetensors_available():
-    if sys.platform.startswith("win"):
-        from safetensors.numpy import load_file
-    else:
-        from paddlenlp.utils.safetensors import fast_load_file as load_file
-
 from .async_handler import AsyncCheckpointHandler
 from .check_completion import check_unified_checkpoint, check_unified_optimizer
 from .load_dynamic import (
@@ -282,9 +272,9 @@

     model_state_dict = get_expected_state_dict(model)
     struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}  # get optimizer param mappings
-    optimizer_state_dict = load_file(optimizer_path)
+    optimizer_state_dict = load_state_dict(optimizer_path, None, None, device="expected")
     if has_master_weights:
-        master_weights = load_file(master_weights_path)
+        master_weights = load_state_dict(master_weights_path, None, None, device="expected")

     # rename and move to paddle.Tensor
     for key in list(optimizer_state_dict.keys()):
@@ -297,20 +287,14 @@
             key_name = "_".join([static_name, key_name[1]])
         else:
             key_name = "_".join([static_name, key_name[1]])
-        with device_guard():
-            weight = paddle.Tensor(optimizer_state_dict.pop(key), zero_copy=True)
-        weight = weight._copy_to(paddle.framework._current_expected_place(), False)
-        returned_optim_state_dict[key_name] = weight
+        returned_optim_state_dict[key_name] = optimizer_state_dict.pop(key)
         returned_optim_state_dict[key_name].name = key_name

     if has_master_weights:
         returned_optim_state_dict["master_weights"] = {}
         for key in list(master_weights.keys()):
             static_name = struct2static_name_mappings[key]
-            with device_guard():
-                weight = paddle.Tensor(master_weights.pop(key), zero_copy=True)
-            weight = weight._copy_to(paddle.framework._current_expected_place(), False)
-            returned_optim_state_dict["master_weights"][static_name] = weight
+            returned_optim_state_dict["master_weights"][static_name] = master_weights.pop(key)
             returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

     return returned_optim_state_dict
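This is the "update non-merge checkpoint loading" half of the PR: the raw safetensors load_file (which returns numpy arrays that then had to be wrapped zero-copy under device_guard and copied to the right place by hand) is replaced with load_state_dict(..., device="expected"), which hands back paddle.Tensor objects already on the expected device. A hedged sketch of the old-vs-new path — the shard name is hypothetical, and the two None arguments simply mirror the call in the diff above rather than documenting the full signature:

import paddle
from paddlenlp.transformers.model_utils import load_state_dict

optimizer_path = "optimizer-00001-of-00002.safetensors"  # hypothetical shard name

# Old path (removed): numpy arrays wrapped by hand, then copied to the device.
#   state_dict = load_file(optimizer_path)  # numpy arrays
#   with device_guard():
#       weight = paddle.Tensor(state_dict.pop(key), zero_copy=True)
#   weight = weight._copy_to(paddle.framework._current_expected_place(), False)

# New path: tensors arrive as paddle.Tensor, already on the expected device.
state_dict = load_state_dict(optimizer_path, None, None, device="expected")
for name, tensor in state_dict.items():
    assert isinstance(tensor, paddle.Tensor)  # no manual conversion needed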