Skip to content

kakakaaka/TrainerDemo

Repository files navigation

一、main函数数据示例

train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels']

处理后

{'labels': tensor([1]), 
'input_ids': tensor([[  101,  2572,  3217,  5831,  5496,  2010,  2567,  1010,  3183,  2002,
          2170,  1000,  1996,  7409,  1000,  1010,  1997,  9969,  4487, 23809,
          3436,  2010,  3350,  1012,   102,  7727,  2000,  2032,  2004,  2069,
          1000,  1996,  7409,  1000,  1010,  2572,  3217,  5831,  5496,  2010,
          2567,  1997,  9969,  4487, 23809,  3436,  2010,  3350,  1012,   102,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0]]), 

'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]])
}

二、sft_main函数数据示例

preprocess_dataset处理后

{
'input_ids': tensor([[21106,   318,   281, 12064,   326,  8477,   257,  4876,    13, 19430,
           257,  2882,   326, 20431, 32543,   262,  2581,    13,   198,   198,
         21017, 46486,    25,   198, 23318,  1115,  9040,   329, 10589,  5448,
            13,   198,   198, 21017, 18261,    25,   198,    16,    13, 47659,
           257, 12974,  5496,   290,   787,  1654,   284,  2291,  6088,   286,
         15921,   290, 13701,    13,   220,   198,    17,    13, 32900,  7987,
           284,  1394,   534,  1767,  4075,   290,  1913,    13,   220,   198,
            18,    13,  3497,  1576,  3993,   290,  5529,   257,  6414,  3993,
          7269,    13, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257]]), 

'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 

'labels': tensor([[21106,   318,   281, 12064,   326,  8477,   257,  4876,    13, 19430,
           257,  2882,   326, 20431, 32543,   262,  2581,    13,   198,   198,
         21017, 46486,    25,   198, 23318,  1115,  9040,   329, 10589,  5448,
            13,   198,   198, 21017, 18261,    25,   198,    16,    13, 47659,
           257, 12974,  5496,   290,   787,  1654,   284,  2291,  6088,   286,
         15921,   290, 13701,    13,   220,   198,    17,    13, 32900,  7987,
           284,  1394,   534,  1767,  4075,   290,  1913,    13,   220,   198,
            18,    13,  3497,  1576,  3993,   290,  5529,   257,  6414,  3993,
          7269,    13, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257]])
}

三、SFTTrainer的整体架构

SFTTrainer(Supervised Fine-Tuning Trainer)是一个用于监督微调语言模型的训练器。其主要职责包括数据预处理、模型训练、评估以及保存模型等。以下是SFTTrainer的主要组成部分和工作流程:

  1. 初始化(Initialization)

    • 加载预训练模型和分词器。
    • 配置训练参数(如学习率、批次大小、训练轮数等)。
    • 准备训练和验证数据集。
    • 定义优化器和学习率调度器。
    • 设置回调函数以在训练过程中执行特定操作(如日志记录、提前停止等)。
  2. 数据预处理(Data Preprocessing)

    • 对原始数据进行分词,将文本转换为模型可接受的输入格式。
    • 添加标签(labels)以用于计算损失。
    • 进行必要的数据打包和填充。
  3. 训练过程(Training Process)

    • 迭代训练数据,通过前向传播计算损失。
    • 反向传播优化模型参数。
    • 根据设定的步数间隔进行日志记录、模型保存和评估。
  4. 评估与保存(Evaluation and Saving)

    • 在训练过程中定期评估模型性能。
    • 根据评估结果决定是否提前停止训练。
    • 保存最优模型和相关参数。

接下来,我们将详细探讨SFTTrainer的初始化过程。

二、SFTTrainer的初始化过程

1. 初始化函数

SFTTrainer的初始化函数定义如下:

def __init__(
    self,
    model: torch.nn.Module,
    args,
    data_collator=None,
    train_dataset=None,
    eval_dataset=None,
    tokenizer=None,
    callbacks=None,
    max_seq_length: int = 512,
    dataset_text_field: str = "text",
    packing: bool = False,
):
    ...

参数说明:

  • model: 要训练的预训练模型(如GPT-2)。
  • args: 训练参数,类型为TrainingArguments,包含训练配置如学习率、批次大小等。
  • data_collator: 数据整理函数,用于将单个样本整理成批次。
  • train_dataset: 训练数据集。
  • eval_dataset: 验证数据集。
  • tokenizer: 分词器,用于将文本转换为模型输入的张量。
  • callbacks: 回调函数列表,用于在训练过程中执行自定义操作。
  • max_seq_length: 输入序列的最大长度。
  • dataset_text_field: 数据集中包含文本的字段名称。
  • packing: 是否启用数据打包。

2. 设置设备和模型

在初始化过程中,首先确定训练设备(GPU或CPU)并将模型移动到该设备上:

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)

3. 其他初始化操作

  • 优化器和调度器: 初始化时并未创建优化器和调度器,而是在训练开始前根据训练步数创建。
  • 随机种子: 为了保证训练的可重复性,设置了随机种子(在训练函数中)。
  • 回调函数: 如果提供了回调函数,则在训练过程中会调用这些函数来执行特定操作,如日志记录和提前停止。

4. 示例代码解析

以下是初始化函数的核心代码片段及其解释:

self.model = model
self.args = args
self.data_collator = data_collator
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.tokenizer = tokenizer
self.callbacks = callbacks if callbacks is not None else []
self.max_seq_length = max_seq_length
self.dataset_text_field = dataset_text_field
self.packing = packing
self.optimizer = None
self.lr_scheduler = None
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
self.should_stop = False  # 用于提前停止训练的标志
  • 属性赋值: 将传入的参数赋值给类的属性,以便在其他方法中使用。
  • 设备设置: 确定训练设备并将模型移动到该设备。
  • 停止标志: self.should_stop用于在满足提前停止条件时中断训练。

三、数据预处理流程

数据预处理是训练过程中的关键步骤,主要包括分词和标签的处理。下面详细介绍SFTTrainer中数据预处理的步骤和数学描述。

1. 分词(Tokenization)

分词器将文本数据转换为模型可接受的输入格式(通常是整数ID序列)。在SFTTrainer中,分词函数如下:

def tokenize_function(examples):
    inputs = self.tokenizer(
        examples[self.dataset_text_field],
        truncation=True,
        max_length=self.max_seq_length,
        padding='max_length',  # 为了对齐序列
    )
    inputs['labels'] = inputs['input_ids'].copy()
    return inputs

步骤说明:

  1. 文本提取: 从数据集中提取包含文本的字段(如"text")。
  2. 分词: 使用分词器将文本转换为input_idsattention_mask
    • 截断: 如果文本长度超过max_seq_length,则截断。
    • 填充: 将序列填充到max_seq_length,以便在批处理时对齐。
  3. 标签处理: 将input_ids复制为labels,用于计算语言模型的损失。

2. 标签的数学描述

在语言模型中,标签通常是输入序列的下一个词。因此,labels的定义为:

$$ \text{labels} = \text{input_ids} $$

这意味着模型的目标是预测每个位置的下一个词,损失函数会在input_idslabels之间计算。

3. 数据打包(Packing)

数据打包是为了充分利用序列长度,提高训练效率。在SFTTrainer中,pack_dataset函数的实现较为简单,仅作为占位符:

def pack_dataset(self, dataset):
    # 这里实现数据打包的逻辑
    return dataset  # 返回打包后的数据集

实际应用中,数据打包可能涉及将多个样本拼接成一个长序列,并相应地调整labelsattention_mask

4. 数据流示例

假设原始数据集中的一个样本如下:

{
    "text": "今天天气很好,我们去公园玩耍吧。"
}

经过分词和预处理后,转换为如下格式:

{
    "input_ids": [101, 2592, 3007, 2042, 3651, 2153, 3793, 1400, 1808, 2640, 102, 0, 0, ..., 0],
    "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, ..., 0],
    "labels": [101, 2592, 3007, 2042, 3651, 2153, 3793, 1400, 1808, 2640, 102, 0, 0, ..., 0]
}

其中:

  • input_ids:文本的分词ID序列,长度为max_seq_length(如512),不足部分填充0
  • attention_mask:表示哪些位置是有效的输入,1表示有效,0表示填充。
  • labels:与input_ids相同,用于计算损失。

5. 数学表示

假设一个批次中有$N$个样本,每个样本的输入序列长度为$L$。数据整理后的批次数据可以表示为:

$$ \mathbf{X} \in \mathbb{N}^{N \times L} \quad \text{(input_ids)} $$

$$ \mathbf{M} \in {0,1}^{N \times L} \quad \text{(attention_mask)} $$

$$ \mathbf{Y} \in \mathbb{N}^{N \times L} \quad \text{(labels)} $$

其中:

  • $\mathbf{X}$:输入的词ID矩阵。
  • $\mathbf{M}$:注意力掩码矩阵。
  • $\mathbf{Y}$:标签词ID矩阵。

这些矩阵将作为模型的输入,用于前向传播和损失计算。

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages