train.py
import os

import torch
import torch.nn as nn
import config
from tqdm import tqdm


def optimize_batch(model, batch, criterion, optimizer, device, scheduler=None):
    """Run one forward/backward pass and return the batch loss as a float."""
    ids = batch['ids'].to(device)
    mask = batch['mask'].to(device)
    token_type_ids = batch['token_type_ids'].to(device)
    target_tags = batch['target_tags'].to(device)

    output = model(ids, mask, token_type_ids)

    # Compute the loss only on real tokens: padded positions (mask == 0) are
    # replaced with criterion.ignore_index so CrossEntropyLoss skips them.
    active_loss = mask.view(-1) == 1
    active_logits = output.view(-1, len(config.LABEL2IDX))
    active_labels = torch.where(active_loss, target_tags.view(-1),
                                torch.tensor(criterion.ignore_index).type_as(target_tags))
    loss = criterion(active_logits, active_labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if scheduler:
        scheduler.step()
    # .item() detaches the scalar so accumulating it does not retain the graph.
    return loss.item()
def train_epoch(model, dataloader, criterion, optimizer, device, scheduler=None):
    """Train for one epoch and return the average loss over all batches."""
    model.train()
    total_loss = 0.0
    num_batches = len(dataloader)
    for batch in tqdm(dataloader):
        total_loss += optimize_batch(model, batch, criterion, optimizer, device, scheduler)
    return total_loss / num_batches
def optimization(model, dataloader, criterion, optimizer, device, scheduler=None):
    """Run the full training loop, checkpointing the model after every epoch."""
    os.makedirs('logs', exist_ok=True)  # ensure the checkpoint directory exists
    for epoch in range(config.EPOCHS):
        avg_loss = train_epoch(model, dataloader, criterion, optimizer, device, scheduler)
        # state_dict() (not state_dicts) returns the parameter tensors to save
        torch.save(model.state_dict(), 'logs/BERTNer.pth')
        print(f"Epoch {epoch + 1}/{config.EPOCHS}, Average Loss: {avg_loss:.4f}")