Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[Distributed Dataloader] change process new_group creation #9438

Merged
merged 1 commit into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions paddlenlp/data/dist_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@

eval = kwargs.pop("eval", False)
is_iterable_dataset = kwargs.pop("is_iterable_dataset", False)
self._pp_data_group = kwargs.pop("pp_data_group", None)

Check warning on line 69 in paddlenlp/data/dist_dataloader.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/data/dist_dataloader.py#L69

Added line #L69 was not covered by tests

if dataset is None:
dataset = DummyDataset() if not is_iterable_dataset else IterableDummyDataset()
Expand All @@ -78,10 +79,8 @@

# Init pp data comm group.
if self._hcg.get_pipe_parallel_world_size() > 1:
self._pp_data_group = self._init_dataloader_comm_group()
self._pp_group = self._hcg.get_pipe_parallel_group()
else:
self._pp_data_group = None
self._pp_group = None

self.mp_group = self._hcg.get_model_parallel_group()
Expand Down Expand Up @@ -132,18 +131,6 @@
else:
raise ValueError("raise error for `paddlenlp.trainer.trainer_utils.has_length`")

def _init_dataloader_comm_group(self):
topo = self._hcg._topo
parallel_comm_group = None
parallel_groups = topo.get_comm_list("pipe")

for group in parallel_groups:
ranks = [group[0], group[-1]]
comm_group = paddle.distributed.new_group(ranks=ranks)
if paddle.distributed.get_rank() in ranks:
parallel_comm_group = comm_group
return parallel_comm_group

def __iter__(self):
return self

Expand Down Expand Up @@ -212,3 +199,16 @@
logger.debug(e)
data = self._broadcast_data(data)
return data


def init_dataloader_comm_group():
hcg = fleet.get_hybrid_communicate_group()
topo = hcg._topo
parallel_groups = topo.get_comm_list("pipe")

Check warning on line 207 in paddlenlp/data/dist_dataloader.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/data/dist_dataloader.py#L205-L207

Added lines #L205 - L207 were not covered by tests

for group in parallel_groups:
ranks = [group[0], group[-1]]
comm_group = paddle.distributed.new_group(ranks=ranks)
if paddle.distributed.get_rank() in ranks:
parallel_comm_group = comm_group
return parallel_comm_group

Check warning on line 214 in paddlenlp/data/dist_dataloader.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/data/dist_dataloader.py#L209-L214

Added lines #L209 - L214 were not covered by tests
31 changes: 16 additions & 15 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
DataCollatorWithPadding,
DistDataLoader,
default_data_collator,
init_dataloader_comm_group,
)
from ..peft import LoRAModel, PrefixModelForCausalLM, VeRAModel

Expand Down Expand Up @@ -440,6 +441,10 @@

model.apply(fn)

self._pp_data_group = None
if self.args.pipeline_parallel_degree > 1 and self.args.distributed_dataloader:
self._pp_data_group = init_dataloader_comm_group()

Check warning on line 446 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L446

Added line #L446 was not covered by tests

default_label_names = (
["start_positions", "end_positions"]
if "QusetionAnswering" in type(self.model).__name__ or "UIE" in type(self.model).__name__
Expand Down Expand Up @@ -1537,6 +1542,7 @@
train_dataset = self._remove_unused_columns(train_dataset, description="training")
_DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

additional_configs = {}
if is_iterable_dataset: # For iterable dataset
if self.args.dataset_world_size > 1 and train_dataset is not None:
train_dataset = IterableDatasetShard(
Expand All @@ -1549,9 +1555,7 @@

if self.args.distributed_dataloader:
logger.info("Training using DistDataLoader.")
additional_configs = {"is_iterable_dataset": True}
else:
additional_configs = {}
additional_configs = {"is_iterable_dataset": True, "pp_data_group": self._pp_data_group}

Check warning on line 1558 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1558

Added line #L1558 was not covered by tests
return _DataLoader(
train_dataset,
batch_size=self.args.per_device_train_batch_size,
Expand All @@ -1563,11 +1567,13 @@
train_sampler = self._get_train_sampler()
if self.args.distributed_dataloader:
logger.info("Training using DistDataLoader.")
additional_configs = {"pp_data_group": self._pp_data_group}

Check warning on line 1570 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1570

Added line #L1570 was not covered by tests
return _DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
**additional_configs,
)

def _get_eval_sampler(self, eval_dataset: Dataset):
Expand Down Expand Up @@ -1623,6 +1629,7 @@
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
_DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

additional_configs = {}
if is_iterable_dataset:
if self.args.dataset_world_size > 1 and eval_dataset is not None:
eval_dataset = IterableDatasetShard(
Expand All @@ -1632,11 +1639,10 @@
num_processes=self.args.dataset_world_size,
process_index=self.args.dataset_rank,
)

if self.args.distributed_dataloader:
logger.info("Eval using DistDataLoader.")
additional_configs = {"eval": True, "is_iterable_dataset": True}
else:
additional_configs = {}
additional_configs = {"eval": True, "is_iterable_dataset": True, "pp_data_group": self._pp_data_group}

Check warning on line 1645 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1645

Added line #L1645 was not covered by tests
return _DataLoader(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
Expand All @@ -1648,9 +1654,7 @@
eval_sampler = self._get_eval_sampler(eval_dataset)
if self.args.distributed_dataloader:
logger.info("Eval using DistDataLoader.")
additional_configs = {"eval": True}
else:
additional_configs = {}
additional_configs = {"eval": True, "pp_data_group": self._pp_data_group}

Check warning on line 1657 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1657

Added line #L1657 was not covered by tests
return _DataLoader(
eval_dataset,
batch_sampler=eval_sampler,
Expand Down Expand Up @@ -1683,6 +1687,7 @@
test_dataset = self._remove_unused_columns(test_dataset, description="test")
_DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

additional_config = {}

Check warning on line 1690 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1690

Added line #L1690 was not covered by tests
if is_iterable_dataset:
if self.args.dataset_world_size > 1 and test_dataset is not None:
test_dataset = IterableDatasetShard(
Expand All @@ -1695,9 +1700,7 @@

if self.args.distributed_dataloader:
logger.info("Test using DistDataLoader.")
additional_config = {"eval": True, "is_iterable_dataset": True}
else:
additional_config = {}
additional_config = {"eval": True, "is_iterable_dataset": True, "pp_data_group": self._pp_data_group}

Check warning on line 1703 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1703

Added line #L1703 was not covered by tests
return _DataLoader(
test_dataset,
batch_size=self.args.per_device_eval_batch_size * self.world_size,
Expand All @@ -1709,9 +1712,7 @@
test_sampler = self._get_eval_sampler(test_dataset)
if self.args.distributed_dataloader:
logger.info("Test using DistDataLoader.")
additional_config = {"eval": True}
else:
additional_config = {}
additional_config = {"eval": True, "pp_data_group": self._pp_data_group}

Check warning on line 1715 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1715

Added line #L1715 was not covered by tests
# We use the same batch_size as for eval.
return _DataLoader(
test_dataset,
Expand Down
Loading