Incremental Training of a Large Model: Building a Chatbot Based on Transformers

Using the Douban "Kuakua" (praise) chitchat dataset, we train and evaluate a UniLM model to get hands-on experience with pretrained language models and complete a generative chitchat bot task.

The project layout is as follows:

  • data: folder holding the datasets
    • dirty_words.txt: sensitive-word list
    • douban_kuakua_qa.txt: raw corpus (roughly 20 MB) ==> used for incremental training
    • sample.json: example of the processed corpus
  • kuakua_robot_model: directory holding the already-trained model
    • config.json
    • pytorch_model.bin
    • vocab.txt
  • pretrain_model: directory holding the UniLM pretrained files
    • config.json
    • pytorch_model.bin
    • vocab.txt
  • chatbot.py: model inference script
  • configuration_unilm.py: UniLM configuration file
  • data_helper.py: data preprocessing script
  • data_set.py: dataset class file
  • modeling_unilm.py: UniLM model file
  • train.py: model training script
  • dirty_recognize.py: sensitive-word detection script
 

A sample of the incremental-training data:

Q:  要去打球赛了求表扬
A:  真棒好好打乒乓球!
Q:  要去打球赛了求表扬
A:  是篮球哈哈哈
Q:  要去打球赛了求表扬
A:  篮板王就是你!
Q:  要去打球赛了求表扬
A:  加油别把鞋踢脏喽
Q:  要去打球赛了求表扬
A:  多买点儿币!
Q:  要去打球赛了求表扬
A:  已经脏了
Q:  要去打球赛了求表扬
A:  好滴
Q:  要去打球赛了求表扬
A:  这个配色是是真心不太合我的胃口,还有为什么白鞋要配黑袜子
Q:  要去打球赛了求表扬
A:  这不是表扬组吗hhh你咋来拆台
Q:  要去打球赛了求表扬
A:  我不是,我没有,别瞎说哈
Q:  要去打球赛了求表扬
A:  全场最帅(・ัω・ั),卡胃踩脚拇指戳肋骨无毒神掌天下无敌,然后需要代打嘛
Q:  要去打球赛了求表扬
A:  你走!
Q:  要去打球赛了求表扬
A:  8要!
Q:  要去打球赛了求表扬
A:  我不,我还想问问什么鞋码,多高多重,打什么位置的

  

Note: since GitHub is not well suited to hosting model files, please download data/douban_kuakua_qa.txt, the kuakua_robot_model folder, and the model .bin files in pretrain_model from Baidu Cloud. [The BERT model is about 400 MB; it is the base model used for incremental training and appears to be the original BERT checkpoint from https://huggingface.co/bert-base-chinese/tree/main]

File name            Download link   Extraction code
pretrain_model       Baidu Cloud     7h4a
kuakua_robot_model   Baidu Cloud     j954
data                 Baidu Cloud     3sz3

Because the sensitive-word list contains a large number of sensitive words, the Baidu Cloud link for data tends to get taken down, so the sensitive-word file is placed directly in the project's data directory.
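Since pretrain_model appears to be nothing more than the stock bert-base-chinese checkpoint, you can also fetch it straight from Hugging Face instead of Baidu Cloud. A minimal sketch, assuming the project only needs config.json, vocab.txt and pytorch_model.bin from that checkpoint:

# Fetch the original bert-base-chinese weights from Hugging Face and save them
# into pretrain_model/ as an alternative to the Baidu Cloud download.
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
model = BertModel.from_pretrained("bert-base-chinese")

tokenizer.save_pretrained("pretrain_model/")     # writes vocab.txt
model.save_pretrained("pretrain_model/",         # writes config.json and pytorch_model.bin
                      safe_serialization=False)  # force the .bin format on newer transformers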

Environment setup

For the environment required for training or inference, see requirements.txt.

Data preprocessing

Run data_helper.py to preprocess the data; it writes the training and test set files into the data folder.

Command:

python3 data_helper.py
 

Note: to change the path or names of the generated files, edit lines 147-150 of data_helper.py as needed.
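For orientation, here is a hedged sketch of the kind of conversion data_helper.py performs: it walks the raw douban_kuakua_qa.txt Q/A pairs and writes one JSON sample per line. The field names src_text/tgt_text are assumptions for illustration only; check data/sample.json and data_set.py for the real schema.

# Hedged sketch of the Q/A -> JSON conversion (not the repo's exact code).
# Assumptions: the raw file alternates "Q:" / "A:" lines, and each sample is
# written as {"src_text": ..., "tgt_text": ...}; verify against data/sample.json.
import json

def build_samples(raw_path="data/douban_kuakua_qa.txt", out_path="data/train.json"):
    samples, question = [], None
    with open(raw_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line.startswith("Q:"):
                question = line[2:].strip()
            elif line.startswith("A:") and question:
                samples.append({"src_text": question, "tgt_text": line[2:].strip()})
    with open(out_path, "w", encoding="utf-8") as f:
        for s in samples:
            f.write(json.dumps(s, ensure_ascii=False) + "\n")
    print("total number of data:", len(samples))

if __name__ == "__main__":
    build_samples()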

Model training

Run train.py to train the model; an output_dir folder is created automatically and holds the model files saved after each epoch.

Command:

python3 train.py --device 0 \
                 --data_dir "data/" \
                 --src_file "train.json" \
                 --model_name_or_path "pretrain_model/" \
                 --max_seq_length 256 \
                 --train_batch_size 16 \
                 --num_train_epochs 10  
 

Note: if your server resources differ or you swap in your own data, adjust the corresponding parameters at training time; see the code or Section 3.5.4 of the book for detailed parameter descriptions.

An example training run is shown below:

img.png

The training loss curve is shown below: img.png

Model inference

To run inference, run chatbot.py; you can use the model provided with the project or a model you trained yourself.

Command:

python3 chatbot.py --device 0 --topk 3 --topp 0.95 --max_len 32
 

Note: to use a different model path, change the --model_path parameter.
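The --topk and --topp flags control how the next token is sampled during decoding: the next-token distribution is first restricted to the k most likely tokens, then truncated to the smallest nucleus whose cumulative probability exceeds p. The repo implements this inside chatbot.py; below is a generic, illustrative version of that filtering step for a 1-D logits tensor, not the project's exact code.

# Generic top-k / top-p (nucleus) filtering of next-token logits.
import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=3, top_p=0.95, filter_value=-float("inf")):
    if top_k > 0:
        # Drop everything below the k-th largest logit.
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth] = filter_value
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens past the nucleus, always keeping the most likely token.
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        logits[sorted_indices[remove]] = filter_value
    return logits

# Usage sketch: given logits of shape [vocab_size] for the next position,
# next_id = torch.multinomial(F.softmax(top_k_top_p_filtering(logits), dim=-1), 1)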

Example inference results are shown below: img.png

USER:工作不开心
CHAT-BOT:我也是,不过现在开始努力找工作啦!一起加油鸭~
USER:工作不开心
CHAT-BOT:那就祝福你,能找到喜欢的工作吧。
USER:工作不开心
CHAT-BOT:你也要加油鸭!
USER:工作不开心
CHAT-BOT:我也是,不过总有一天就会好起来的
Example 2
USER:我分手了
CHAT-BOT:分手快乐,你会遇到更好的。
USER:我分手了
CHAT-BOT:旧的不去新的不来,你会遇见更好的
USER:我分手了
CHAT-BOT:分的好,下一个更乖
Example 3
USER:好想去吃火锅,但是我在减肥
CHAT-BOT:你好棒!一定会瘦下来的!
USER:好像去吃火锅,但是我在减肥
CHAT-BOT:吃火锅是为了能更好的减肥,你真棒!


I don't have a GPU myself and trained for about a day; the run went as follows:
D:\book\ChatGPTBook-main\UniLMProj 的目录
2023-09-25  17:54    <DIR>          .
2023-09-25  17:22    <DIR>          ..
2023-09-25  17:22             5,530 chatbot.py
2023-09-25  17:22             2,153 configuration_unilm.py
2023-09-25  21:04    <DIR>          data
2023-09-25  17:22             4,375 data_helper.py
2023-09-25  17:22             9,173 data_set.py
2023-09-25  17:22             1,304 dirty_recognize.py
2023-09-25  17:22    <DIR>          images
2023-09-25  17:22    <DIR>          kuakua_robot_model
2023-09-25  17:22            13,452 modeling_unilm.py
2023-09-25  17:22    <DIR>          pretrain_model
2023-09-25  17:22             4,199 README.md
2023-09-25  17:22                88 requirements.txt
2023-09-25  17:22             8,337 train.py
2023-09-25  17:22             1,861 trie.py
2023-09-25  17:54    <DIR>          __pycache__
              10 个文件         50,472 字节
               7 个目录 175,152,689,152 可用字节
 
D:\book\ChatGPTBook-main\UniLMProj>python data_helper.py
total number of data: 121687
 
D:\book\ChatGPTBook-main\UniLMProj>dir data
 驱动器 D 中的卷是 Data
 卷的序列号是 CA99-555E
 
 D:\book\ChatGPTBook-main\UniLMProj\data 的目录
 
2023-09-25  21:06    <DIR>          .
2023-09-25  17:54    <DIR>          ..
2023-09-25  17:22           245,546 dirty_words.txt
2023-09-25  17:56        21,620,763 douban_kuakua_qa.txt
2023-09-25  17:22               446 sample.json
2023-09-25  21:06        14,272,447 train.json
               4 个文件     36,139,202 字节
               2 个目录 175,138,414,592 可用字节
 
D:\book\ChatGPTBook-main\UniLMProj>python train.py
Traceback (most recent call last):
  File "D:\book\ChatGPTBook-main\UniLMProj\train.py", line 18, in <module>
    import torch
ModuleNotFoundError: No module named 'torch'
 
D:\book\ChatGPTBook-main\UniLMProj>pip install torch
  
D:\book\ChatGPTBook-main\UniLMProj>pip install torch
  
D:\book\ChatGPTBook-main\UniLMProj>
D:\book\ChatGPTBook-main\UniLMProj>python train.py
Traceback (most recent call last):
  File "D:\book\ChatGPTBook-main\UniLMProj\train.py", line 20, in <module>
    from transformers import BertTokenizer
ModuleNotFoundError: No module named 'transformers'
 
D:\book\ChatGPTBook-main\UniLMProj>pip install transformers
  
D:\book\ChatGPTBook-main\UniLMProj>python train.py
Traceback (most recent call last):
  File "D:\book\ChatGPTBook-main\UniLMProj\train.py", line 170, in <module>
    main()
  File "D:\book\ChatGPTBook-main\UniLMProj\train.py", line 86, in main
    model = UnilmForSeq2Seq.from_pretrained(args.model_name_or_path, config=config)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\transformers\modeling_utils.py", line 2740, in from_pretrained
    raise EnvironmentError(
OSError: Error no file named pytorch_model.bin, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory pretrain_model/.
 
D:\book\ChatGPTBook-main\UniLMProj>python train.py
Loading Train Dataset data/
convert seq2seq example: 108730it [00:29, 3741.34it/s]
Load 108730 data
C:\Python311\Lib\site-packages\transformers\optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
09/25/2023 21:45:57 - INFO - __main__ -   ***** CUDA.empty_cache() *****
09/25/2023 21:45:57 - INFO - __main__ -   ***** Running training *****
09/25/2023 21:45:57 - INFO - __main__ -     Batch size = 16
09/25/2023 21:45:57 - INFO - __main__ -     Num steps = 67960
Epoch:   0%|                                                                                    | 0/10 [00:00<?, ?it/s]C:\Python311\Lib\site-packages\torch\optim\lr_scheduler.py:265: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.
  warnings.warn("To get the last learning rate computed by the scheduler, "
Iter (loss=3.279): 100%|████████████████████████████████████████████████████████| 6796/6796 [25:32:28<00:00, 13.53s/it]
09/26/2023 23:18:26 - INFO - __main__ -   ** ** * Saving fine-tuned model and optimizer ** ** * 32:28<00:009.24s/it]
Epoch:  10%|██████▌                                                           | 1/10 [25:32:29<229:52:22, 91949.11s/it]
Iter (loss=3.525):  31%|████████████████▎                                    | 2096/6796 [10:55:33<16:16:17, 12.46s/it]

 

After a day of training, the resulting chat quality:

D:\book\ChatGPTBook-main\UniLMProj>python chatbot.py --topk 3 --topp 0.95 --max_len 32 --model_name_or_path output_dir\checkpoint-6796\
cpu
Chitchat Robot Starting
user:你好啊
chat-bot:这么可爱的人真的很有缘分呢,你是最棒了!
user:终于训练好了一个大模型机器人了
chat-bot:好厉害!我也想学,但是还是没有做到
user:哈哈哈
chat-bot:你真的很可爱!我也是,但不想被表扬。
user:大模型是什么呢?
chat-bot:我们一起努力吧!
user:你是谁?
chat-bot:我是你!
user:你逗我呢
chat-bot:你也是个可爱的小可爱,我们一起夸

  

Now let's walk through the model source code:

import copy
import math
import logging
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from transformers.modeling_utils import PreTrainedModel
from configuration_unilm import UnilmConfig
from transformers.models.bert.modeling_bert import load_tf_weights_in_bert, BertPooler, BertIntermediate, BertOutput, \
    BertSelfOutput, BertOnlyMLMHead, BertEmbeddings
 
logger = logging.getLogger(__name__)
 
BertLayerNorm = torch.nn.LayerNorm
 
 
class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(
            config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
 
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
 
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
 
    def transpose_for_scores(self, x):
        sz = x.size()[:-1] + (self.num_attention_heads,
                              self.attention_head_size)
        x = x.view(*sz)
        return x.permute(0, 2, 1, 3)
 
    def forward(self, hidden_states, attention_mask, history_states=None):
        if history_states is None:
            mixed_query_layer = self.query(hidden_states)
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)
        else:
            x_states = torch.cat((history_states, hidden_states), dim=1)
            mixed_query_layer = self.query(hidden_states)
            mixed_key_layer = self.key(x_states)
            mixed_value_layer = self.value(x_states)
 
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
 
        attention_scores = torch.matmul(
            query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
        attention_scores = attention_scores + attention_mask
 
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
 
        attention_probs = self.dropout(attention_probs)
 
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[
                                  :-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer
 
 
class BertAttention(nn.Module):
    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
 
    def forward(self, input_tensor, attention_mask, history_states=None):
        self_output = self.self(
            input_tensor, attention_mask, history_states=history_states)
        attention_output = self.output(self_output, input_tensor)
        return attention_output
 
 
class BertLayer(nn.Module):
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
 
    def forward(self, hidden_states, attention_mask, history_states=None):
        attention_output = self.attention(
            hidden_states, attention_mask, history_states=history_states)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
 
 
class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        layer = BertLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer)
                                    for _ in range(config.num_hidden_layers)])
 
    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, prev_embedding=None,
                prev_encoded_layers=None):
        assert (prev_embedding is None) == (prev_encoded_layers is None)
 
        all_encoder_layers = []
        if (prev_embedding is not None) and (prev_encoded_layers is not None):
            history_states = prev_embedding
            for i, layer_module in enumerate(self.layer):
                hidden_states = layer_module(
                    hidden_states, attention_mask, history_states=history_states)
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)
                if prev_encoded_layers is not None:
                    history_states = prev_encoded_layers[i]
        else:
            for layer_module in self.layer:
                hidden_states = layer_module(
                    hidden_states, attention_mask)
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers
 
 
class UnilmPreTrainedModel(PreTrainedModel):
    config_class = UnilmConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "unilm"
 
    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
 
 
class UnilmModel(UnilmPreTrainedModel):
    def __init__(self, config):
        super(UnilmModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.init_weights()
 
    def get_extended_attention_mask(self, input_ids, token_type_ids, attention_mask):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
 
        if attention_mask.dim() == 2:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        elif attention_mask.dim() == 3:
            extended_attention_mask = attention_mask.unsqueeze(1)
        else:
            raise NotImplementedError
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask
 
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
        extended_attention_mask = self.get_extended_attention_mask(
            input_ids, token_type_ids, attention_mask)
 
        embedding_output = self.embeddings(
            input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output, extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output
 
 
class LabelSmoothingLoss(_Loss):
    def __init__(self, label_smoothing=0, tgt_vocab_size=0, ignore_index=0, size_average=None, reduce=None,
                 reduction='mean'):
        assert 0.0 < label_smoothing <= 1.0
        self.ignore_index = ignore_index
        super(LabelSmoothingLoss, self).__init__(
            size_average=size_average, reduce=reduce, reduction=reduction)
 
        assert label_smoothing > 0
        assert tgt_vocab_size > 0
 
        smoothing_value = label_smoothing / (tgt_vocab_size - 2)
        one_hot = torch.full((tgt_vocab_size,), smoothing_value)
        one_hot[self.ignore_index] = 0
        self.register_buffer('one_hot', one_hot.unsqueeze(0))
        self.confidence = 1.0 - label_smoothing
        self.tgt_vocab_size = tgt_vocab_size
 
    def forward(self, output, target):
        assert self.tgt_vocab_size == output.size(2)
        batch_size, num_pos = target.size(0), target.size(1)
        output = output.view(-1, self.tgt_vocab_size)
        target = target.view(-1)
        model_prob = self.one_hot.repeat(target.size(0), 1)
        model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
        model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)
 
        return F.kl_div(output, model_prob.type_as(output), reduction='none').view(batch_size, num_pos, -1).sum(2)
 
 
class UnilmForSeq2Seq(UnilmPreTrainedModel):
    """UniLM模型进行Seq2Seq的训练模型类"""
 
    def __init__(self, config):
        """模型初始化函数,定义模型训练所需的各个模块"""
        super(UnilmForSeq2Seq, self).__init__(config)
        self.bert = UnilmModel(config)
        self.cls = BertOnlyMLMHead(config)
        self.mask_lm = nn.CrossEntropyLoss(reduction='none')
        if hasattr(config, 'label_smoothing') and config.label_smoothing:
            self.mask_lm_smoothed = LabelSmoothingLoss(config.label_smoothing, config.vocab_size, ignore_index=0,
                                                       reduction='none')
        else:
            self.mask_lm_smoothed = None
        self.init_weights()
        self.tie_weights()
 
    def tie_weights(self):
        """权重加载,加载预训练模型的embeddings部分权重"""
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)
 
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, masked_pos=None,
                masked_weights=None):
        """模型forward,向前传递函数"""
        # 获取Encoder部分的序列输出,维度[bs,seq_len,hidden_size]
        sequence_output, __ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
 
        def gather_seq_out_by_pos(seq, pos):
            return torch.gather(seq, 1, pos.unsqueeze(2).expand(-1, -1, seq.size(-1)))
 
        def loss_mask_and_normalize(loss, mask):
            mask = mask.type_as(loss)
            loss = loss * mask
            denominator = torch.sum(mask) + 1e-5
            return (loss / denominator).sum()
 
        if masked_lm_labels is None:
            if masked_pos is None:
                prediction_scores = self.cls(sequence_output)
            else:
                sequence_output_masked = gather_seq_out_by_pos(sequence_output, masked_pos)
                prediction_scores = self.cls(sequence_output_masked)
            return prediction_scores
        # Gather the hidden vectors at the masked positions
        sequence_output_masked = gather_seq_out_by_pos(sequence_output, masked_pos)
        prediction_scores_masked = self.cls(sequence_output_masked)
        if self.mask_lm_smoothed:
            masked_lm_loss = self.mask_lm_smoothed(F.log_softmax(prediction_scores_masked.float(), dim=-1),
                                                   masked_lm_labels)
        else:
            masked_lm_loss = self.mask_lm(prediction_scores_masked.transpose(1, 2).float(), masked_lm_labels)
        # Compute the loss over the [MASK] positions
        masked_lm_loss = loss_mask_and_normalize(masked_lm_loss.float(), masked_weights)
 
        return masked_lm_loss
 
 
class UnilmForSeq2SeqDecodeSample(UnilmPreTrainedModel):
    """UniLM模型进行Seq2Seq的模型解码类"""
    def __init__(self, config):
        """模型初始化函数,定义模型训练所需的各个模块"""
        super(UnilmForSeq2SeqDecodeSample, self).__init__(config)
        self.bert = UnilmModel(config)
        self.cls = BertOnlyMLMHead(config)
        self.init_weights()
        self.tie_weights()
 
    def tie_weights(self):
        """权重加载,加载预训练模型的embeddings部分权重"""
        self._tie_or_clone_weights(self.cls.predictions.decoder, self.bert.embeddings.word_embeddings)
 
    def forward(self, input_ids, token_type_ids, attention_mask):
        # Get the encoder's sequence output, shape [bs, seq_len, hidden_size]
        sequence_output, __ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        # Take the output of the last position only
        last_hidden = sequence_output[:, -1:, :]
        # Project it onto the vocabulary to provide scores for the next decoding step
        prediction_scores = self.cls(last_hidden)
        return prediction_scores

  

This file defines a Seq2Seq model based on UniLM (Unified Language Model), mainly for sequence-generation tasks. UniLM is a pretrained language model that unifies bidirectional and unidirectional language modeling within a single architecture.

The file defines the following main classes:

1. BertSelfAttention: an implementation of the self-attention mechanism that computes attention scores over every element of the input sequence.

2. BertAttention, BertLayer, BertEncoder: the main building blocks of the BERT model; they process the input sequence and produce hidden states.

3. UnilmPreTrainedModel: the base class for the pretrained model, defining weight initialization and how pretrained weights are loaded.

4. UnilmModel: the core UniLM implementation, containing BERT's embedding layer, encoder, and pooler.

5. LabelSmoothingLoss: a loss function with label smoothing, used during training to keep the model from becoming overconfident about the labels.

6. UnilmForSeq2Seq: the UniLM model for sequence-to-sequence tasks; it adds a prediction head on top of UnilmModel to predict the masked (next) tokens.

7. UnilmForSeq2SeqDecodeSample: the decoder used for sequence-to-sequence inference; it feeds the hidden states produced by UnilmModel through the prediction head to score the next token.

Overall, the model defined in this file targets sequence-to-sequence tasks such as machine translation and text summarization; its goal is to generate a new sequence from the input sequence.
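What makes UniLM usable for generation is not a new layer type but the attention mask it feeds through get_extended_attention_mask: source tokens attend to the whole source bidirectionally, while each target token attends to the source plus only the target tokens to its left. The repo builds this mask in its data pipeline; the helper below is an illustrative sketch of that idea, not the project's code.

# Illustrative UniLM seq2seq attention mask: source positions are fully visible
# to everyone, target positions are visible only causally within the target.
import torch

def build_seq2seq_mask(src_len: int, tgt_len: int) -> torch.Tensor:
    total = src_len + tgt_len
    mask = torch.zeros(total, total)
    mask[:, :src_len] = 1.0                                               # everyone sees the source
    mask[src_len:, src_len:] = torch.tril(torch.ones(tgt_len, tgt_len))   # causal within the target
    return mask  # [total, total]; batched as [bs, total, total] it hits the dim()==3 branch above

print(build_seq2seq_mask(src_len=3, tgt_len=2))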

 

[Model training]

The training code uses UnilmForSeq2Seq, a sequence-to-sequence model based on UniLM (Unified Language Model) and aimed at sequence-generation tasks such as machine translation and text summarization. In the code, the model is loaded as follows:
 
config = UnilmConfig.from_pretrained(args.model_name_or_path)
tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path, do_lower_case=args.do_lower_case)
model = UnilmForSeq2Seq.from_pretrained(args.model_name_or_path, config=config)
model.to(device)
 
This snippet first loads the UniLM configuration and the BERT tokenizer from the pretrained-model path, then loads the UnilmForSeq2Seq model from the same path with that configuration, and finally moves the model to the target device (GPU if available, otherwise CPU).
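To make the rest of train.py easier to follow, here is a compressed, illustrative view of how that model object is then driven during fine-tuning. The real loop in train.py adds a learning-rate scheduler, gradient accumulation, logging and per-epoch checkpointing, and the batch field order below is an assumption; see data_set.py for the actual tensors.

# Compressed sketch of the fine-tuning loop (illustrative, not train.py itself).
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

model.train()
for batch in train_dataloader:  # built from the train.json samples by data_set.py
    input_ids, token_type_ids, attention_mask, lm_labels, masked_pos, masked_weights = (
        t.to(device) for t in batch)
    # UnilmForSeq2Seq.forward returns the masked-LM loss when labels are provided.
    loss = model(input_ids, token_type_ids, attention_mask,
                 masked_lm_labels=lm_labels, masked_pos=masked_pos,
                 masked_weights=masked_weights)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()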

 

class UnilmForSeq2Seq(UnilmPreTrainedModel):
    """UniLM模型进行Seq2Seq的训练模型类"""
 
    def __init__(self, config):
        """模型初始化函数,定义模型训练所需的各个模块"""
        super(UnilmForSeq2Seq, self).__init__(config)
        self.bert = UnilmModel(config)
        self.cls = BertOnlyMLMHead(config)
        self.mask_lm = nn.CrossEntropyLoss(reduction='none')
        if hasattr(config, 'label_smoothing') and config.label_smoothing:
            self.mask_lm_smoothed = LabelSmoothingLoss(config.label_smoothing, config.vocab_size, ignore_index=0,
                                                       reduction='none')
        else:
            self.mask_lm_smoothed = None
        self.init_weights()
        self.tie_weights()
 
    def tie_weights(self):
        """权重加载,加载预训练模型的embeddings部分权重"""
        self._tie_or_clone_weights(self.cls.predictions.decoder,
                                   self.bert.embeddings.word_embeddings)
 
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, masked_pos=None,
                masked_weights=None):
        """模型forward,向前传递函数"""
        # 获取Encoder部分的序列输出,维度[bs,seq_len,hidden_size]
        sequence_output, __ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
 
        def gather_seq_out_by_pos(seq, pos):
            return torch.gather(seq, 1, pos.unsqueeze(2).expand(-1, -1, seq.size(-1)))
 
        def loss_mask_and_normalize(loss, mask):
            mask = mask.type_as(loss)
            loss = loss * mask
            denominator = torch.sum(mask) + 1e-5
            return (loss / denominator).sum()
 
        if masked_lm_labels is None:
            if masked_pos is None:
                prediction_scores = self.cls(sequence_output)
            else:
                sequence_output_masked = gather_seq_out_by_pos(sequence_output, masked_pos)
                prediction_scores = self.cls(sequence_output_masked)
            return prediction_scores
        # Gather the hidden vectors at the masked positions
        sequence_output_masked = gather_seq_out_by_pos(sequence_output, masked_pos)
        prediction_scores_masked = self.cls(sequence_output_masked)
        if self.mask_lm_smoothed:
            masked_lm_loss = self.mask_lm_smoothed(F.log_softmax(prediction_scores_masked.float(), dim=-1),
                                                   masked_lm_labels)
        else:
            masked_lm_loss = self.mask_lm(prediction_scores_masked.transpose(1, 2).float(), masked_lm_labels)
        # Compute the loss over the [MASK] positions
        masked_lm_loss = loss_mask_and_normalize(masked_lm_loss.float(), masked_weights)
 
        return masked_lm_loss

 

Let's look more closely at this model class:

UnilmForSeq2Seq is a sequence-to-sequence model built on UniLM, mainly used for sequence-generation tasks such as machine translation and text summarization. Its main components and their roles are:

1. self.bert = UnilmModel(config): the body of the UniLM model, including BERT's embedding layer, encoder, and pooler; it processes the input sequence and produces hidden states.

2. self.cls = BertOnlyMLMHead(config): the prediction head; it takes the hidden states produced by UnilmModel and outputs a score over the vocabulary for each position.

3. self.mask_lm = nn.CrossEntropyLoss(reduction='none'): a cross-entropy loss that measures the gap between predictions and the true labels.

4. self.mask_lm_smoothed = LabelSmoothingLoss(config.label_smoothing, config.vocab_size, ignore_index=0, reduction='none'): a label-smoothing loss, used instead of plain cross-entropy when label smoothing is enabled, to keep the model from becoming overconfident about the labels.

5. forward: the forward-pass function; it takes the input sequence, attention mask, and labels, computes prediction scores through UnilmModel and the prediction head, and then computes the loss.

Overall, UnilmForSeq2Seq generates a new sequence from the input sequence and is trained on the loss between its predictions and the true labels.
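The two inner helpers of forward() are easy to sanity-check in isolation: gather_seq_out_by_pos uses torch.gather to pull out the hidden states at the masked target positions, and loss_mask_and_normalize zeroes out padded mask slots and divides by the number of real ones. A tiny standalone illustration (shapes and values are made up):

# Standalone illustration of gather_seq_out_by_pos and loss_mask_and_normalize.
import torch

bs, seq_len, hidden = 2, 6, 4
sequence_output = torch.randn(bs, seq_len, hidden)
masked_pos = torch.tensor([[1, 3], [2, 0]])               # masked positions (0 pads the shorter row)
masked_weights = torch.tensor([[1.0, 1.0], [1.0, 0.0]])   # 1 = real mask, 0 = padding

# gather_seq_out_by_pos -> [bs, n_mask, hidden]
gathered = torch.gather(sequence_output, 1,
                        masked_pos.unsqueeze(2).expand(-1, -1, hidden))

# loss_mask_and_normalize: zero padded slots, normalize by the number of real masks
per_pos_loss = torch.rand(bs, 2)                          # stand-in for the per-position CE loss
loss = (per_pos_loss * masked_weights / (masked_weights.sum() + 1e-5)).sum()
print(gathered.shape, loss)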

 
