feat(ddp): logger when rank == 0, or no ddp
mmmwhy committed Feb 13, 2022
1 parent 6270a0b commit 792709b
Showing 3 changed files with 45 additions and 32 deletions.
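
The feature here is that training logs are emitted only from the main process: rank 0 under DDP, or rank -1 when DDP is not used at all. Below is a minimal, self-contained sketch of that guard, assuming the rank is read from the LOCAL_RANK environment variable and using a stand-in for the repository's init_logger helper (neither detail is shown in this diff):

import logging
import os

# torchrun / torch.distributed.launch export LOCAL_RANK for each worker; -1 means "no DDP"
local_rank = int(os.environ.get("LOCAL_RANK", -1))


def init_logger(name: str) -> logging.Logger:
    # stand-in for pure_attention's init_logger helper, which is not part of this diff
    logger = logging.getLogger(name)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s"))
        logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    return logger


logger = init_logger("Runner")

if local_rank in [-1, 0]:
    # only the main process (or a plain single-GPU/CPU run) logs; other ranks stay silent
    logger.info("epoch finished")

This keeps multi-GPU runs from printing every log line once per process.
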
22 changes: 17 additions & 5 deletions CHANGELOG.md
@@ -1,18 +1,14 @@
## 0.0.25 (2022-02-12)
## [0.0.26](https://github.com/mmmwhy/pure_attention/compare/v0.0.25...v0.0.26) (2022-02-13)


### Bug Fixes

* **bert:** change LayerNorm to layer_norm ([a99831e](https://github.com/mmmwhy/pure_attention/commit/a99831ee3b4ad06cadbb0262720c0836717d7508))
* **type:** fix some code mistake ([3e1a81d](https://github.com/mmmwhy/pure_attention/commit/3e1a81dd351f2a31ca03fce7cf8ca80be2b94a6d))


### Features

* **bert:** add tokenizer part ([054df14](https://github.com/mmmwhy/pure_attention/commit/054df14c7dfefc0b2edb47824578b33f4a5c8539))
* **decode:** add some transformer decode code ([52b044b](https://github.com/mmmwhy/pure_attention/commit/52b044b0fa79dcb3b9ba8fcd2747f05bc43de808))
* **layers:** fix import for layerNorm ([eb61b31](https://github.com/mmmwhy/pure_attention/commit/eb61b313458ac18bf4b15271fee2cf7e39f8afde))
* **nlp:** init basic bert code ([f9cb13a](https://github.com/mmmwhy/pure_attention/commit/f9cb13a3e811eb8c44ba8ff1373d688311426927))
* **schedule:** cosine schedule with warmup ([87c8886](https://github.com/mmmwhy/pure_attention/commit/87c88865a685bcf5bb2806d15ae7a1a243676509))


@@ -22,3 +18,19 @@



## [0.0.20](https://github.com/mmmwhy/pure_attention/compare/eb61b313458ac18bf4b15271fee2cf7e39f8afde...v0.0.20) (2022-02-02)


### Bug Fixes

* **bert:** change LayerNorm to layer_norm ([a99831e](https://github.com/mmmwhy/pure_attention/commit/a99831ee3b4ad06cadbb0262720c0836717d7508))


### Features

* **bert:** add tokenizer part ([054df14](https://github.com/mmmwhy/pure_attention/commit/054df14c7dfefc0b2edb47824578b33f4a5c8539))
* **layers:** fix import for layerNorm ([eb61b31](https://github.com/mmmwhy/pure_attention/commit/eb61b313458ac18bf4b15271fee2cf7e39f8afde))
* **nlp:** init basic bert code ([f9cb13a](https://github.com/mmmwhy/pure_attention/commit/f9cb13a3e811eb8c44ba8ff1373d688311426927))



53 changes: 27 additions & 26 deletions examples/model_chineseNMT/runner.py
@@ -7,6 +7,7 @@
""""""
import os
import sys

# find the project root directory
sys.path.append(os.path.abspath(__file__).split("examples")[0]) # noqa E402

@@ -39,38 +40,38 @@
class Runner:
def __init__(self, config):
self.logger = init_logger(self.__class__.__name__)

self.train_epochs_num = 15
self.batch_size = 64

self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
self.gpu_list = list(range(torch.cuda.device_count()))

self.num_works = len(self.gpu_list) * 4

train_dataset = ChineseNMTDataset("train")
train_sampler = DistributedSampler(train_dataset)
self.train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size,
sampler=train_sampler, num_workers=self.num_works, pin_memory=True)

eval_dataset = ChineseNMTDataset("dev")
eval_sampler = DistributedSampler(eval_dataset)
self.eval_dataloader = DataLoader(eval_dataset, batch_size=self.batch_size,
sampler=eval_sampler, num_workers=self.num_works)

self.total_step = len(self.train_dataloader) * self.train_epochs_num
model = Seq2SeqModel(config).to(local_rank)
self.ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
self.optimizer = torch.optim.Adam(self.ddp_model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
self.scheduler = get_cosine_schedule_with_warmup(
optimizer=self.optimizer,
num_warmup_steps=int(0.1 * self.total_step),
num_training_steps=self.total_step
)
self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(local_rank)
self.scaler = GradScaler()
self.start_epoch = 0

def run_epoch(self, dataloader, now_epoch, all_epoch):
# let all processes sync up before starting with a new epoch of training
# not clear on the difference between dist.barrier() and trainloader.sampler.set_epoch, todo @mmmwhy
@@ -79,36 +80,36 @@ def run_epoch(self, dataloader, now_epoch, all_epoch):
self.optimizer.zero_grad()
with autocast():
result = self.ddp_model(
row["src_text"].to(local_rank), row["tgt_text"].to(local_rank),
row["src_mask"].unsqueeze(1).to(local_rank), row["tgt_mask"].unsqueeze(1).to(local_rank)
row["src_text"].to(local_rank), row["tgt_text"].to(local_rank),
row["src_mask"].unsqueeze(1).to(local_rank), row["tgt_mask"].unsqueeze(1).to(local_rank)
)
loss = self.criterion(result.view(-1, result.size(-1)), row["tgt_true"].view(-1).to(local_rank))
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()

self.scheduler.step()

- self.logger.info((
- "Epoch: {epoch:03d} / {all_epoch:03d},"
- "Step: {step:04d} / {all_step:04d},"
- "Loss: {loss:.04f},"
- "Lr: {lr:.08f}".format(epoch=now_epoch, all_epoch=all_epoch,
- step=step, all_step=len(dataloader),
- loss=np.mean(loss.item()),
- lr=self.optimizer.param_groups[0]['lr'])))


+ if local_rank in [-1, 0]:
+ self.logger.info((
+ "Epoch: {epoch:03d} / {all_epoch:03d},"
+ "Step: {step:04d} / {all_step:04d},"
+ "Loss: {loss:.04f},"
+ "Lr: {lr:.08f}".format(epoch=now_epoch, all_epoch=all_epoch,
+ step=step, all_step=len(dataloader),
+ loss=np.mean(loss.item()),
+ lr=self.optimizer.param_groups[0]['lr'])))

def train(self):
for now_epoch in range(self.start_epoch, self.train_epochs_num):

# train the model
self.ddp_model.train()
self.run_epoch(self.train_dataloader, now_epoch, self.train_epochs_num)

# evaluate the model
self.ddp_model.eval()
self.run_epoch(self.eval_dataloader, 1, 1)

def run(self):
self.train()
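
The runner relies on a module-level local_rank, a DistributedSampler per split, and the sampler's epoch being reset each epoch so the shuffle order changes; that reseeding is what distinguishes sampler.set_epoch from dist.barrier(), which only synchronizes processes (the question raised in the TODO above). The process-group setup itself is not part of this diff, so the following is a hypothetical sketch of how such a script is typically wired and launched, with a toy dataset and model standing in for ChineseNMTDataset and Seq2SeqModel:

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# torchrun sets LOCAL_RANK for every worker process; default to -1 for a non-DDP run
local_rank = int(os.environ.get("LOCAL_RANK", -1))

if local_rank != -1:
    dist.init_process_group(backend="nccl")  # one process per GPU
    torch.cuda.set_device(local_rank)

dataset = TensorDataset(torch.randn(256, 8), torch.randn(256, 1))  # toy data
model = torch.nn.Linear(8, 1)  # toy model

sampler = DistributedSampler(dataset) if local_rank != -1 else None
loader = DataLoader(dataset, batch_size=64, sampler=sampler, shuffle=(sampler is None))

if local_rank != -1:
    model = DDP(model.to(local_rank), device_ids=[local_rank], output_device=local_rank)

for epoch in range(3):
    if sampler is not None:
        # set_epoch reseeds the per-epoch shuffle; dist.barrier() would only make
        # processes wait for each other and has no effect on shuffling
        sampler.set_epoch(epoch)
    for batch in loader:
        pass  # forward / backward as in run_epoch above

Launched with something like torchrun --nproc_per_node=<num_gpus> examples/model_chineseNMT/runner.py; running the script directly leaves local_rank at -1 and skips the DDP wrapping.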

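run_epoch also combines mixed precision with a per-step learning-rate schedule: the forward pass runs under autocast, and GradScaler wraps the backward pass and optimizer step. A compact sketch of one such step with toy tensors; the enabled flags are only there so the sketch also runs on CPU, whereas the original calls use the defaults:

import torch
from torch.cuda.amp import GradScaler, autocast

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)
criterion = torch.nn.MSELoss()
scaler = GradScaler(enabled=torch.cuda.is_available())

x, y = torch.randn(64, 8, device=device), torch.randn(64, 1, device=device)

optimizer.zero_grad()
with autocast(enabled=torch.cuda.is_available()):
    loss = criterion(model(x), y)  # forward pass in mixed precision
scaler.scale(loss).backward()      # scale the loss so fp16 gradients do not underflow
scaler.step(optimizer)             # unscales gradients, then calls optimizer.step()
scaler.update()                    # adjust the scale factor for the next iteration

In the runner, self.scheduler.step() then advances the cosine-with-warmup schedule once per batch.
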
2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
{
"name": "pure_attention",
"version": "0.0.25",
"version": "0.0.26",
"description": "Generate a changelog from git metadata",
"repository": {
"type": "git",