A Beginner's Guide to Huggingface transformers
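Earlier posts introduced the Transformer. Thanks to its strong performance it is used more and more widely in industry, and, combined with transfer learning, a growing number of pretrained Transformer models and source libraries have been open-sourced. Huggingface is one of the most prominent organizations doing this. It is a New York startup that has contributed a great deal to the NLP community, and the pretrained models and code it provides are widely used in academic research. Its open-source Transformers library offers thousands of pretrained models for a wide range of tasks. Developers can pick a model to train or fine-tune as needed, or read the API documentation and source code to build new models quickly.
This post introduces Huggingface's open-source Transformers library. Before starting, install the transformers library with the following command: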
pip install transformers
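1 Starting with AutoClass
The transformers library implements more than a hundred model architectures: BERT has its BertModel class, BART has its BartModel class, and so on. If we had to look up the matching class name and instantiate it every time we used a pretrained model, that would be tedious.
For this reason the library provides a unified entry point, the "AutoClass" family of high-level objects. Calling an AutoClass's from_pretrained() method with the name of a pretrained model, or the directory that holds it, creates the model quickly and conveniently. With AutoClass you only need to know the model's name or have the model downloaded: the program matches on the model_type field in the model's configuration file, or on the model name or path, and decides automatically which model class to instantiate, so you no longer need to look up the model's class name in the transformers library. None of the AutoClass classes can be instantiated through __init__(); they can only instantiate the target class through from_pretrained().
In the example below, we have downloaded a Chinese BERT pretrained model from the Huggingface website and stored all of its files under "models/bert-base-chinese" in the current directory. Passing this path to from_pretrained() creates the model, and the result is an instance of the BertModel class.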
from transformers import AutoModel
model = AutoModel.from_pretrained("./models/bert-base-chinese")
print(type(model))
Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight'] - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
<class 'transformers.models.bert.modeling_bert.BertModel'>
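As the output shows, the created model is an instance of BertModel. The message printed during loading appears because some weights in the checkpoint were not used (the cls.* pre-training heads, which BertModel does not contain), and this is normal here.
AutoClass is called a family because it includes not only AutoModel, which conveniently creates pretrained models, but also AutoTokenizer for the matching tokenizer, AutoConfig for managing a pretrained model's configuration, and various other specialized classes. They are not all listed here; see the AutoClass section of the official Huggingface documentation.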
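To make the dispatch idea from this section concrete, here is a minimal sketch in plain Python. It is not the transformers implementation: ToyAutoModel, ToyBertModel, ToyBartModel and MODEL_REGISTRY are hypothetical names invented for illustration. It only shows the pattern of reading model_type from a config.json, picking a class from a registry, and refusing direct instantiation.

```python
import json
import os
import tempfile


class ToyBertModel:
    """Stand-in for a concrete model class (hypothetical, not from transformers)."""

    def __init__(self, config):
        self.config = config


class ToyBartModel:
    """Second stand-in class so the registry has more than one entry."""

    def __init__(self, config):
        self.config = config


# Hypothetical registry: maps the "model_type" field of config.json to a class.
MODEL_REGISTRY = {"bert": ToyBertModel, "bart": ToyBartModel}


class ToyAutoModel:
    """Factory that cannot be instantiated directly, only via from_pretrained()."""

    def __init__(self):
        raise EnvironmentError("Use ToyAutoModel.from_pretrained(path) instead.")

    @classmethod
    def from_pretrained(cls, model_dir):
        # Read the configuration stored next to the weights.
        with open(os.path.join(model_dir, "config.json"), encoding="utf-8") as f:
            config = json.load(f)
        # Pick the concrete class from the model_type field.
        model_cls = MODEL_REGISTRY[config["model_type"]]
        return model_cls(config)


def demo():
    # Build a throwaway directory with a minimal config.json for the demo.
    with tempfile.TemporaryDirectory() as model_dir:
        with open(os.path.join(model_dir, "config.json"), "w", encoding="utf-8") as f:
            json.dump({"model_type": "bert", "num_hidden_layers": 12}, f)
        return ToyAutoModel.from_pretrained(model_dir)


model = demo()
print(type(model).__name__)
```

The caller never names ToyBertModel; the model_type field alone selects it, which is the convenience the AutoClass family offers.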
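2 Vector representation of words: AutoTokenizer
Almost every natural language processing task begins with tokenization and a numeric representation of the tokens, and Transformer models are no exception. The transformers library therefore provides a high-level API object, AutoTokenizer, which loads a pretrained tokenizer for this step.
AutoTokenizer belongs to the AutoClass family and makes it easy to load a pretrained tokenizer backed by the tokenizers library (Huggingface's dedicated library for tokenization and related operations).
Passing the name of a tokenizer to AutoTokenizer's from_pretrained method downloads it automatically and instantiates the corresponding tokenizer from the tokenizers library. These tokenizer objects offer a rich set of features, such as defining a vocabulary, loading a vocabulary, truncation, padding, and specifying special tokens.
Note that in most cases a pretrained tokenizer and a pretrained model are used together as a matched pair. For example, if the pretrained model is "bert-base-chinese", the tokenizer must also be loaded with the "bert-base-chinese" vocabulary; otherwise the pretrained model will perform much worse.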
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
sentence = "床前明月光"
tokenizer(sentence)
{'input_ids': [101, 2414, 1184, 3209, 3299, 1045, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}
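The three fields in this output can be reproduced with a toy encoder, which may help in reading them. This is a sketch, not the real tokenizer: TOY_VOCAB and toy_encode are hypothetical, the character IDs are simply read off the output above, and 101 and 102 are the [CLS] and [SEP] special tokens that BERT places around a sentence.

```python
# Toy vocabulary: the character IDs are read off the output above; 101 and 102
# are the [CLS] and [SEP] special tokens. The rest of the vocabulary is omitted.
TOY_VOCAB = {
    "[CLS]": 101,
    "[SEP]": 102,
    "床": 2414,
    "前": 1184,
    "明": 3209,
    "月": 3299,
    "光": 1045,
}


def toy_encode(sentence):
    """Encode one sentence character by character (hypothetical helper)."""
    tokens = ["[CLS]"] + list(sentence) + ["[SEP]"]
    input_ids = [TOY_VOCAB[token] for token in tokens]
    return {
        "input_ids": input_ids,
        # A single sentence belongs entirely to segment 0.
        "token_type_ids": [0] * len(input_ids),
        # No padding yet, so every position is a real token.
        "attention_mask": [1] * len(input_ids),
    }


encoded = toy_encode("床前明月光")
print(encoded)
```

Five characters plus the two special tokens give the seven IDs seen above.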
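- Loading the tokenizer from a local directory
When the network is unreliable, you can also download the model manually from the Huggingface website and pass the local directory to from_pretrained: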
tokenizer = AutoTokenizer.from_pretrained("./models/bert-base-chinese")
sentence = "床前明月光"
tokenizer(sentence)
{'input_ids': [101, 2414, 1184, 3209, 3299, 1045, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}
- Encoding several sentences at once
sentence = ["床前明月光", "床前明月光,疑是地上霜。"]
tokenizer(sentence)
{'input_ids': [[101, 2414, 1184, 3209, 3299, 1045, 102], [101, 2414, 1184, 3209, 3299, 1045, 8024, 4542, 3221, 1765, 677, 7458, 511, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
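- Other parameters
The tokenizer call also accepts many other parameters for a variety of functions:
tokenizer(
    ["床前明月光", "床前明月光,疑是地上霜。"],
    padding=True,  # pad shorter sequences to the longest sequence in the batch
    truncation=True,  # truncate sequences longer than max_length
    max_length=10,
    return_tensors="pt",  # return type: "pt" for PyTorch tensors, "tf" for TensorFlow tensors
)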
{'input_ids': tensor([[ 101, 2414, 1184, 3209, 3299, 1045, 102, 0, 0, 0], [ 101, 2414, 1184, 3209, 3299, 1045, 8024, 4542, 3221, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
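The padding and truncation seen in this output can be sketched in plain Python. The helper below is hypothetical (pad_and_truncate is not a transformers function) and assumes every ID list starts with [CLS] and ends with [SEP]; it only mirrors the behaviour visible above, where the long sentence is cut to 10 IDs with its final 102 kept and the short one is filled with 0.

```python
def pad_and_truncate(batch, max_length, pad_id=0):
    """Truncate each ID list to max_length, then pad to the longest one left.

    Hypothetical helper: assumes every list starts with [CLS] and ends with
    [SEP], and keeps that final [SEP] when cutting.
    """
    truncated = []
    for ids in batch:
        if len(ids) > max_length:
            # Drop tokens from the end of the text but keep the closing [SEP].
            ids = ids[: max_length - 1] + [ids[-1]]
        truncated.append(list(ids))

    longest = max(len(ids) for ids in truncated)
    input_ids, attention_mask = [], []
    for ids in truncated:
        n_pad = longest - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        # 1 marks real tokens, 0 marks padding the model should ignore.
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = [
    [101, 2414, 1184, 3209, 3299, 1045, 102],
    [101, 2414, 1184, 3209, 3299, 1045, 8024, 4542, 3221, 1765, 677, 7458, 511, 102],
]
result = pad_and_truncate(batch, max_length=10)
print(result["input_ids"])
print(result["attention_mask"])
```

The attention mask is what lets the model tell the three padded positions of the first sentence apart from real tokens.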
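3 Configuration: AutoConfig
Each kind of model has its own architecture, so its hyperparameters differ too. Having to track down the right configuration items and configuration class by hand every time a pretrained model is loaded would be very inconvenient, so the AutoClass family also provides a dedicated entry point for configuration management: AutoConfig.
Even pretrained models of the same algorithm can have different network structures, so a downloaded pretrained model ships with its own configuration file. A model downloaded from the Huggingface website, for example, includes a config.json file, and AutoConfig loads the model-specific settings from it, overriding the defaults.
Taking BERT as an example, let's first look at the default configuration: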
from transformers import BertConfig
config = BertConfig()
config
BertConfig { "attention_probs_dropout_prob": 0.1, "classifier_dropout": null, "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 768, "initializer_range": 0.02, "intermediate_size": 3072, "layer_norm_eps": 1e-12, "max_position_embeddings": 512, "model_type": "bert", "num_attention_heads": 12, "num_hidden_layers": 12, "pad_token_id": 0, "position_embedding_type": "absolute", "transformers_version": "4.24.0", "type_vocab_size": 2, "use_cache": true, "vocab_size": 30522 }
from transformers import AutoConfig
config = AutoConfig.from_pretrained("./models/bert-base-chinese")
config
BertConfig { "_name_or_path": "./models/bert-base-chinese", "attention_probs_dropout_prob": 0.1, "classifier_dropout": null, "directionality": "bidi", "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 768, "initializer_range": 0.02, "intermediate_size": 3072, "layer_norm_eps": 1e-12, "lstm_dropout_prob": 0.5, "lstm_embedding_size": 768, "max_position_embeddings": 512, "model_type": "bert", "num_attention_heads": 12, "num_hidden_layers": 12, "pad_token_id": 0, "pooler_fc_size": 768, "pooler_num_attention_heads": 12, "pooler_num_fc_layers": 3, "pooler_size_per_head": 128, "pooler_type": "first_token_transform", "position_embedding_type": "absolute", "transformers_version": "4.24.0", "type_vocab_size": 2, "use_cache": true, "vocab_size": 21128 }
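As you can see, the configuration loaded from the pretrained model differs slightly from the defaults above, for example in vocab_size (21128 versus 30522). The resulting object is still an instance of BertConfig, as shown below: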
type(config)
transformers.models.bert.configuration_bert.BertConfig
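Spotting the differences between two long printouts by eye is error-prone, so a small dictionary diff can help. The sketch below uses plain dictionaries holding only a few entries copied from the two printouts above; diff_configs is a hypothetical helper, not part of transformers.

```python
def diff_configs(default, loaded):
    """Return keys whose values differ, and keys only present in `loaded`."""
    changed = {
        key: (default[key], loaded[key])
        for key in default
        if key in loaded and default[key] != loaded[key]
    }
    added = sorted(key for key in loaded if key not in default)
    return changed, added


# A few entries copied from the two printouts above (not the full configs).
default_cfg = {"hidden_size": 768, "num_hidden_layers": 12, "vocab_size": 30522}
loaded_cfg = {
    "hidden_size": 768,
    "num_hidden_layers": 12,
    "vocab_size": 21128,
    "directionality": "bidi",
    "pooler_type": "first_token_transform",
}

changed, added = diff_configs(default_cfg, loaded_cfg)
print(changed)
print(added)
```

With real configuration objects, calling to_dict() on each should give dictionaries that can be passed to the same helper.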
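Through the config instance we can modify settings. In the configuration above, the encoder has 12 layers; below we change this to 5. After the change, a model created with this config has only 5 encoder layers, only those first 5 layers load pretrained weights, and the remaining weights are discarded.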
config.num_hidden_layers=5
print(config)
BertConfig { "_name_or_path": "./models/bert-base-chinese", "attention_probs_dropout_prob": 0.1, "classifier_dropout": null, "directionality": "bidi", "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 768, "initializer_range": 0.02, "intermediate_size": 3072, "layer_norm_eps": 1e-12, "lstm_dropout_prob": 0.5, "lstm_embedding_size": 768, "max_position_embeddings": 512, "model_type": "bert", "num_attention_heads": 12, "num_hidden_layers": 5, "pad_token_id": 0, "pooler_fc_size": 768, "pooler_num_attention_heads": 12, "pooler_num_fc_layers": 3, "pooler_size_per_head": 128, "pooler_type": "first_token_transform", "position_embedding_type": "absolute", "transformers_version": "4.24.0", "type_vocab_size": 2, "use_cache": true, "vocab_size": 21128 }
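If the modified settings will be needed again, they can be saved locally: pass a directory to save_pretrained and the configuration is written there as config.json. Note that saving into the model's own directory, as below, overwrites the config.json that was downloaded with the model.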
config.save_pretrained("./models/bert-base-chinese")
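4 Creating a pretrained model: AutoModel
Huggingface provides a large number of pretrained models, which are easy to find on its website. With AutoModel, the simplest way to create a pretrained model is to pass the model name or a local path directly. Because of network conditions in mainland China, it is advisable to download the model first and load it from a local directory: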
from transformers import AutoModel
model = AutoModel.from_pretrained("./models/bert-base-chinese")
Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight'] - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
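Loaded this way, the model uses the settings from the pretrained model's config.json directly. You can also pass a configuration instance when loading, which lets you customize the pretrained model. For example, pass in the config instance modified in the previous section: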
model = AutoModel.from_pretrained("./models/bert-base-chinese", config=config)
Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.weight', 'bert.encoder.layer.8.attention.self.value.bias', 'bert.encoder.layer.8.attention.output.dense.bias', 'bert.encoder.layer.10.attention.self.query.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'bert.encoder.layer.7.attention.output.dense.bias', 'bert.encoder.layer.9.attention.self.value.bias', 'bert.encoder.layer.7.attention.self.query.weight', 'bert.encoder.layer.10.output.LayerNorm.weight', 'bert.encoder.layer.9.attention.output.LayerNorm.bias', 'bert.encoder.layer.10.attention.output.LayerNorm.weight', 'bert.encoder.layer.9.attention.self.key.weight', 'cls.predictions.transform.dense.weight', 'bert.encoder.layer.5.attention.self.query.weight', 'bert.encoder.layer.11.output.LayerNorm.weight', 'bert.encoder.layer.6.attention.self.query.weight', 'bert.encoder.layer.10.output.dense.weight', 'bert.encoder.layer.6.attention.output.dense.bias', 'bert.encoder.layer.5.attention.self.key.weight', 'bert.encoder.layer.7.attention.self.value.bias', 'bert.encoder.layer.5.intermediate.dense.weight', 'bert.encoder.layer.9.intermediate.dense.weight', 'bert.encoder.layer.5.attention.self.query.bias', 'bert.encoder.layer.7.attention.self.key.weight', 'bert.encoder.layer.11.output.dense.weight', 'bert.encoder.layer.8.attention.self.key.weight', 'bert.encoder.layer.10.output.dense.bias', 'bert.encoder.layer.10.attention.output.dense.weight', 'bert.encoder.layer.11.intermediate.dense.bias', 'bert.encoder.layer.9.output.LayerNorm.weight', 'bert.encoder.layer.9.output.LayerNorm.bias', 'bert.encoder.layer.6.attention.self.value.weight', 'bert.encoder.layer.10.attention.output.dense.bias', 'bert.encoder.layer.11.attention.output.dense.bias', 'bert.encoder.layer.10.intermediate.dense.weight', 'bert.encoder.layer.10.attention.self.value.weight', 'bert.encoder.layer.6.attention.self.value.bias', 
'bert.encoder.layer.6.attention.self.query.bias', 'bert.encoder.layer.11.intermediate.dense.weight', 'bert.encoder.layer.5.attention.output.LayerNorm.bias', 'bert.encoder.layer.8.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.weight', 'bert.encoder.layer.7.intermediate.dense.bias', 'bert.encoder.layer.9.output.dense.bias', 'bert.encoder.layer.11.attention.self.value.bias', 'bert.encoder.layer.5.attention.self.value.bias', 'bert.encoder.layer.11.attention.output.LayerNorm.weight', 'bert.encoder.layer.10.attention.self.key.bias', 'bert.encoder.layer.9.attention.output.LayerNorm.weight', 'bert.encoder.layer.6.output.dense.weight', 'bert.encoder.layer.6.output.LayerNorm.bias', 'bert.encoder.layer.7.attention.self.key.bias', 'bert.encoder.layer.11.output.LayerNorm.bias', 'bert.encoder.layer.8.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.value.weight', 'bert.encoder.layer.8.attention.output.dense.weight', 'bert.encoder.layer.9.attention.output.dense.weight', 'cls.predictions.bias', 'bert.encoder.layer.9.output.dense.weight', 'bert.encoder.layer.8.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.self.key.bias', 'bert.encoder.layer.6.attention.self.key.bias', 'bert.encoder.layer.9.attention.self.query.bias', 'bert.encoder.layer.6.intermediate.dense.bias', 'bert.encoder.layer.6.attention.self.key.weight', 'bert.encoder.layer.8.attention.self.key.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.output.dense.bias', 'bert.encoder.layer.6.output.dense.bias', 'bert.encoder.layer.11.attention.self.key.weight', 'bert.encoder.layer.7.output.dense.weight', 'bert.encoder.layer.8.intermediate.dense.bias', 'bert.encoder.layer.5.attention.self.value.weight', 'bert.encoder.layer.7.output.LayerNorm.weight', 'bert.encoder.layer.5.output.dense.weight', 'bert.encoder.layer.11.output.dense.bias', 'bert.encoder.layer.8.output.dense.bias', 'bert.encoder.layer.10.attention.self.query.weight', 
'bert.encoder.layer.9.intermediate.dense.bias', 'bert.encoder.layer.9.attention.self.value.weight', 'bert.encoder.layer.11.attention.output.LayerNorm.bias', 'bert.encoder.layer.7.intermediate.dense.weight', 'bert.encoder.layer.7.attention.self.query.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.weight', 'bert.encoder.layer.5.output.LayerNorm.weight', 'cls.predictions.decoder.weight', 'bert.encoder.layer.9.attention.self.query.weight', 'bert.encoder.layer.6.intermediate.dense.weight', 'bert.encoder.layer.10.output.LayerNorm.bias', 'bert.encoder.layer.11.attention.output.dense.weight', 'bert.encoder.layer.10.intermediate.dense.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.key.bias', 'bert.encoder.layer.8.intermediate.dense.weight', 'bert.encoder.layer.5.output.dense.bias', 'bert.encoder.layer.5.attention.output.dense.bias', 'bert.encoder.layer.8.attention.output.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'bert.encoder.layer.8.attention.self.value.weight', 'cls.predictions.transform.LayerNorm.bias', 'bert.encoder.layer.8.attention.self.query.bias', 'bert.encoder.layer.5.attention.output.dense.weight', 'bert.encoder.layer.7.output.dense.bias', 'cls.seq_relationship.bias', 'bert.encoder.layer.6.attention.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.bias', 'bert.encoder.layer.7.attention.self.value.weight', 'bert.encoder.layer.8.attention.output.LayerNorm.weight', 'bert.encoder.layer.6.output.LayerNorm.weight', 'bert.encoder.layer.5.attention.self.key.bias', 'bert.encoder.layer.10.attention.self.value.bias', 'bert.encoder.layer.5.attention.output.LayerNorm.weight', 'bert.encoder.layer.7.output.LayerNorm.bias', 'bert.encoder.layer.5.intermediate.dense.bias', 'bert.encoder.layer.7.attention.output.dense.weight', 'bert.encoder.layer.10.attention.output.LayerNorm.bias', 'bert.encoder.layer.8.attention.self.query.weight', 
'bert.encoder.layer.10.attention.self.key.weight', 'bert.encoder.layer.5.output.LayerNorm.bias'] - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
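Because we changed the number of encoder layers to 5 in the previous section, the warning now lists many more unused weights.
We can also pass configuration items directly as keyword arguments to from_pretrained(); for example, here we set the number of encoder layers to 3. Note that this approach no longer takes effect when the config argument is also given.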
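Which weights end up in such a warning can be predicted from the key names alone. The sketch below is a simplified illustration, not how transformers computes the list: unused_keys and LAYER_KEY are hypothetical names, and the rule is just "encoder layer index at or beyond the configured depth, or a cls.* pre-training head".

```python
import re

# Matches keys such as "bert.encoder.layer.7.output.dense.weight".
LAYER_KEY = re.compile(r"^bert\.encoder\.layer\.(\d+)\.")


def unused_keys(checkpoint_keys, num_hidden_layers):
    """Return checkpoint keys a truncated BertModel has no slot for (sketch)."""
    unused = []
    for key in checkpoint_keys:
        match = LAYER_KEY.match(key)
        if match and int(match.group(1)) >= num_hidden_layers:
            # Encoder layer index beyond the configured depth.
            unused.append(key)
        elif key.startswith("cls."):
            # Pre-training heads, which the bare BertModel never contains.
            unused.append(key)
    return unused


# A handful of key names of the kind that appear in these warnings.
keys = [
    "bert.encoder.layer.4.output.dense.weight",
    "bert.encoder.layer.5.output.dense.weight",
    "bert.encoder.layer.11.output.dense.bias",
    "cls.predictions.bias",
]
print(unused_keys(keys, num_hidden_layers=5))
```

Layer indices start at 0, so with 5 layers the indices 0 to 4 are kept and index 5 is the first one dropped.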
model = AutoModel.from_pretrained("./models/bert-base-chinese", num_hidden_layers=3)
Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['bert.encoder.layer.4.attention.self.value.bias', 'cls.seq_relationship.weight', 'bert.encoder.layer.8.attention.self.value.bias', 'bert.encoder.layer.8.attention.output.dense.bias', 'bert.encoder.layer.10.attention.self.query.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'bert.encoder.layer.7.attention.output.dense.bias', 'bert.encoder.layer.9.attention.self.value.bias', 'bert.encoder.layer.7.attention.self.query.weight', 'bert.encoder.layer.10.output.LayerNorm.weight', 'bert.encoder.layer.9.attention.output.LayerNorm.bias', 'bert.encoder.layer.4.intermediate.dense.weight', 'bert.encoder.layer.4.attention.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.self.key.weight', 'cls.predictions.transform.dense.weight', 'bert.encoder.layer.10.attention.output.LayerNorm.weight', 'bert.encoder.layer.4.attention.self.key.weight', 'bert.encoder.layer.3.intermediate.dense.bias', 'bert.encoder.layer.5.attention.self.query.weight', 'bert.encoder.layer.11.output.LayerNorm.weight', 'bert.encoder.layer.6.attention.self.query.weight', 'bert.encoder.layer.10.output.dense.weight', 'bert.encoder.layer.6.attention.output.dense.bias', 'bert.encoder.layer.5.attention.self.key.weight', 'bert.encoder.layer.7.attention.self.value.bias', 'bert.encoder.layer.5.intermediate.dense.weight', 'bert.encoder.layer.9.intermediate.dense.weight', 'bert.encoder.layer.5.attention.self.query.bias', 'bert.encoder.layer.7.attention.self.key.weight', 'bert.encoder.layer.4.output.dense.weight', 'bert.encoder.layer.8.attention.self.key.weight', 'bert.encoder.layer.11.output.dense.weight', 'bert.encoder.layer.10.output.dense.bias', 'bert.encoder.layer.10.attention.output.dense.weight', 'bert.encoder.layer.11.intermediate.dense.bias', 'bert.encoder.layer.3.attention.self.key.weight', 'bert.encoder.layer.9.output.LayerNorm.weight', 
'bert.encoder.layer.9.output.LayerNorm.bias', 'bert.encoder.layer.3.attention.self.value.bias', 'bert.encoder.layer.6.attention.self.value.weight', 'bert.encoder.layer.10.attention.output.dense.bias', 'bert.encoder.layer.11.attention.output.dense.bias', 'bert.encoder.layer.10.intermediate.dense.weight', 'bert.encoder.layer.10.attention.self.value.weight', 'bert.encoder.layer.6.attention.self.value.bias', 'bert.encoder.layer.6.attention.self.query.bias', 'bert.encoder.layer.4.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.output.dense.bias', 'bert.encoder.layer.11.intermediate.dense.weight', 'bert.encoder.layer.5.attention.output.LayerNorm.bias', 'bert.encoder.layer.4.attention.self.query.weight', 'bert.encoder.layer.8.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.weight', 'bert.encoder.layer.3.intermediate.dense.weight', 'bert.encoder.layer.4.attention.output.dense.bias', 'bert.encoder.layer.7.intermediate.dense.bias', 'bert.encoder.layer.9.output.dense.bias', 'bert.encoder.layer.11.attention.self.value.bias', 'bert.encoder.layer.5.attention.self.value.bias', 'bert.encoder.layer.11.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.attention.output.dense.weight', 'bert.encoder.layer.10.attention.self.key.bias', 'bert.encoder.layer.9.attention.output.LayerNorm.weight', 'bert.encoder.layer.4.output.LayerNorm.bias', 'bert.encoder.layer.6.output.dense.weight', 'bert.encoder.layer.6.output.LayerNorm.bias', 'bert.encoder.layer.7.attention.self.key.bias', 'bert.encoder.layer.11.output.LayerNorm.bias', 'bert.encoder.layer.8.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.value.weight', 'bert.encoder.layer.8.attention.output.dense.weight', 'bert.encoder.layer.9.attention.output.dense.weight', 'cls.predictions.bias', 'bert.encoder.layer.9.output.dense.weight', 'bert.encoder.layer.8.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.self.key.bias', 'bert.encoder.layer.6.attention.self.key.bias', 
'bert.encoder.layer.3.output.dense.weight', 'bert.encoder.layer.9.attention.self.query.bias', 'bert.encoder.layer.6.intermediate.dense.bias', 'bert.encoder.layer.6.attention.self.key.weight', 'bert.encoder.layer.8.attention.self.key.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.bias', 'bert.encoder.layer.3.output.LayerNorm.bias', 'bert.encoder.layer.3.attention.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.output.dense.bias', 'bert.encoder.layer.6.output.dense.bias', 'bert.encoder.layer.11.attention.self.key.weight', 'bert.encoder.layer.7.output.dense.weight', 'bert.encoder.layer.8.intermediate.dense.bias', 'bert.encoder.layer.5.attention.self.value.weight', 'bert.encoder.layer.7.output.LayerNorm.weight', 'bert.encoder.layer.5.output.dense.weight', 'bert.encoder.layer.3.attention.self.query.weight', 'bert.encoder.layer.11.output.dense.bias', 'bert.encoder.layer.4.output.LayerNorm.weight', 'bert.encoder.layer.3.attention.self.key.bias', 'bert.encoder.layer.8.output.dense.bias', 'bert.encoder.layer.9.intermediate.dense.bias', 'bert.encoder.layer.10.attention.self.query.weight', 'bert.encoder.layer.9.attention.self.value.weight', 'bert.encoder.layer.11.attention.output.LayerNorm.bias', 'bert.encoder.layer.4.attention.self.key.bias', 'bert.encoder.layer.7.intermediate.dense.weight', 'bert.encoder.layer.7.attention.self.query.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.weight', 'bert.encoder.layer.5.output.LayerNorm.weight', 'cls.predictions.decoder.weight', 'bert.encoder.layer.9.attention.self.query.weight', 'bert.encoder.layer.6.intermediate.dense.weight', 'bert.encoder.layer.10.output.LayerNorm.bias', 'bert.encoder.layer.11.attention.output.dense.weight', 'bert.encoder.layer.10.intermediate.dense.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.key.bias', 'bert.encoder.layer.8.intermediate.dense.weight', 'bert.encoder.layer.5.output.dense.bias', 
'bert.encoder.layer.5.attention.output.dense.bias', 'bert.encoder.layer.8.attention.output.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'bert.encoder.layer.8.attention.self.value.weight', 'cls.predictions.transform.LayerNorm.bias', 'bert.encoder.layer.3.attention.output.LayerNorm.weight', 'bert.encoder.layer.8.attention.self.query.bias', 'bert.encoder.layer.5.attention.output.dense.weight', 'bert.encoder.layer.3.attention.output.dense.bias', 'bert.encoder.layer.7.output.dense.bias', 'bert.encoder.layer.4.attention.output.dense.weight', 'bert.encoder.layer.6.attention.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.bias', 'cls.seq_relationship.bias', 'bert.encoder.layer.4.intermediate.dense.bias', 'bert.encoder.layer.7.attention.self.value.weight', 'bert.encoder.layer.8.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.output.LayerNorm.weight', 'bert.encoder.layer.4.attention.self.query.bias', 'bert.encoder.layer.6.output.LayerNorm.weight', 'bert.encoder.layer.5.attention.self.key.bias', 'bert.encoder.layer.10.attention.self.value.bias', 'bert.encoder.layer.4.output.dense.bias', 'bert.encoder.layer.5.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.attention.self.value.weight', 'bert.encoder.layer.7.output.LayerNorm.bias', 'bert.encoder.layer.5.intermediate.dense.bias', 'bert.encoder.layer.7.attention.output.dense.weight', 'bert.encoder.layer.10.attention.output.LayerNorm.bias', 'bert.encoder.layer.8.attention.self.query.weight', 'bert.encoder.layer.3.attention.self.query.bias', 'bert.encoder.layer.10.attention.self.key.weight', 'bert.encoder.layer.5.output.LayerNorm.bias', 'bert.encoder.layer.4.attention.self.value.weight'] - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). 
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
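The precedence between the three ways of supplying settings described in this section can be summarized in a few lines. This is a simplified sketch of the rule as stated in the text, using plain dictionaries; resolve_config is a hypothetical function and does not reproduce the library's internals.

```python
def resolve_config(file_config, config=None, **overrides):
    """Decide which settings a model is built with (simplified sketch)."""
    if config is not None:
        # An explicit config wins; keyword overrides are not applied to it.
        return dict(config)
    resolved = dict(file_config)
    # Without an explicit config, keyword arguments override config.json.
    resolved.update(overrides)
    return resolved


file_cfg = {"num_hidden_layers": 12, "hidden_size": 768}
custom_cfg = {"num_hidden_layers": 5, "hidden_size": 768}

print(resolve_config(file_cfg)["num_hidden_layers"])
print(resolve_config(file_cfg, num_hidden_layers=3)["num_hidden_layers"])
print(resolve_config(file_cfg, config=custom_cfg, num_hidden_layers=3)["num_hidden_layers"])
```

The three calls correspond to the three loading styles shown in this section: path only, path plus keyword argument, and path plus config instance.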
Let's run the tensors produced by the tokenizer through the model in a forward pass:
tens = model(**tokenizer("床前明月光,疑是地上霜。", return_tensors="pt"))
tens.last_hidden_state.shape
torch.Size([1, 14, 768])
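The three numbers in this shape are batch size, sequence length and hidden size, and they can be worked out by hand. The helper below is a hypothetical sketch; it assumes one token per character, which matches the tokenizer outputs shown earlier for this text, plus the [CLS] and [SEP] tokens.

```python
def expected_output_shape(sentences, hidden_size, special_tokens=2):
    """Shape of last_hidden_state for a padded batch (sketch).

    Assumes one token per character, which holds for the Chinese text and
    punctuation used in this post, plus [CLS] and [SEP] around each sentence.
    """
    seq_len = max(len(sentence) for sentence in sentences) + special_tokens
    return (len(sentences), seq_len, hidden_size)


shape = expected_output_shape(["床前明月光,疑是地上霜。"], hidden_size=768)
print(shape)
```

The sentence has 12 characters including punctuation, so the sequence length is 14. The number of encoder layers does not appear in the shape, which is why the 3-layer model still returns 768-dimensional vectors.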
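After a model has been modified or retrained, it can be saved again with model.save_pretrained(). In the version used here, this writes two files to the target directory: the configuration file (config.json) and the weight file (pytorch_model.bin).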
model.save_pretrained("./new_model/bert-base-chinese")
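5 Using ready-made task models
The transformers library also provides many complete network models for all kinds of AI tasks, such as image classification, text classification, audio classification, translation, and question answering. Most of these APIs start with AutoModelFor. Let's print them: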
import transformers
for api in dir(transformers):
    if api.startswith('AutoModelFor'):
        print(api)
AutoModelForAudioClassification AutoModelForAudioFrameClassification AutoModelForAudioXVector AutoModelForCTC AutoModelForCausalLM AutoModelForDepthEstimation AutoModelForDocumentQuestionAnswering AutoModelForImageClassification AutoModelForImageSegmentation AutoModelForInstanceSegmentation AutoModelForMaskedImageModeling AutoModelForMaskedLM AutoModelForMultipleChoice AutoModelForNextSentencePrediction AutoModelForObjectDetection AutoModelForPreTraining AutoModelForQuestionAnswering AutoModelForSemanticSegmentation AutoModelForSeq2SeqLM AutoModelForSequenceClassification AutoModelForSpeechSeq2Seq AutoModelForTableQuestionAnswering AutoModelForTokenClassification AutoModelForVideoClassification AutoModelForVision2Seq AutoModelForVisualQuestionAnswering AutoModelForZeroShotObjectDetection
Taking AutoModelForSequenceClassification as an example, here is how to use one:
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("./models/bert-base-chinese", num_labels=2)
Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight'] - This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./models/bert-base-chinese and are newly initialized: ['classifier.weight', 'classifier.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
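The last part of this message says that classifier.weight and classifier.bias are newly initialized. In BertForSequenceClassification that classifier is a linear layer from hidden_size to num_labels, so its size follows directly from the two numbers. The sketch below uses only the standard library; head_parameter_count and softmax are illustrative helpers, and the logits are made up.

```python
import math


def head_parameter_count(hidden_size, num_labels):
    """Weights plus biases of a linear layer hidden_size -> num_labels."""
    return hidden_size * num_labels + num_labels


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # subtract max for stability
    total = sum(exps)
    return [e / total for e in exps]


print(head_parameter_count(hidden_size=768, num_labels=2))
probs = softmax([0.0, 0.0])  # made-up logits standing in for an untrained head
print(probs)
```

Because these head parameters start out random, the model's predictions are not meaningful until it has been fine-tuned, which is what the "You should probably TRAIN this model" line is pointing out.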
num_labels means that the task has two final labels, that is, this is a binary classification model. Let's look at the model structure:
model
BertForSequenceClassification( (bert): BertModel( (embeddings): BertEmbeddings( (word_embeddings): Embedding(21128, 768, padding_idx=0) (position_embeddings): Embedding(512, 768) (token_type_embeddings): Embedding(2, 768) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (encoder): BertEncoder( (layer): ModuleList( (0): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (1): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) 
(dropout): Dropout(p=0.1, inplace=False) ) ) (2): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (3): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (4): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, 
inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (5): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (6): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) 
(intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (7): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (8): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (9): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): 
Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (10): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (11): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): 
LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) (pooler): BertPooler( (dense): Linear(in_features=768, out_features=768, bias=True) (activation): Tanh() ) ) (dropout): Dropout(p=0.1, inplace=False) (classifier): Linear(in_features=768, out_features=2, bias=True) )
tokenizer("床前明月光,疑是地上霜。", return_tensors="pt")
{'input_ids': tensor([[ 101, 2414, 1184, 3209, 3299, 1045, 8024, 4542, 3221, 1765, 677, 7458, 511, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
tens = model(**tokenizer("床前明月光,疑是地上霜。", return_tensors="pt"))
tens
SequenceClassifierOutput(loss=None, logits=tensor([[-0.4371, -0.1223]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)
The output contains two logits, one for each of the two classes.
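The two values are raw logits rather than probabilities. As a minimal sketch in plain Python (the helper name `softmax` is ours and not part of transformers; on tensors, `torch.softmax(logits, dim=-1)` does the same job), they can be turned into class probabilities and a predicted class index like this:

```python
import math

def softmax(logits):
    """Convert a list of raw logits into probabilities that sum to 1."""
    m = max(logits)                              # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# logits copied from the SequenceClassifierOutput printed above
logits = [-0.4371, -0.1223]
probs = softmax(logits)
pred = max(range(len(probs)), key=probs.__getitem__)  # index of the largest probability

print(probs, pred)
```

If the classification head has not been fine-tuned (as is the case when it is freshly initialized on top of a pretrained BERT), these probabilities are not yet meaningful.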
6 Custom model
from chb import *  # the author's utility package; bar, Tableprint and set_device used below come from it
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
import torch
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from transformers import AutoConfig, AutoModel, AutoTokenizer, get_linear_schedule_with_warmup, logging
import warnings
warnings.filterwarnings('ignore')
RANDOM_SEED = 1000
MAX_LEN = 50
BATCH_SIZE = 64
6.1 Custom dataset
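x_lst = []
y_lst = []
with open('./data/中文文本-新闻分类数据集', 'r') as f:
    # read the file line by line; each valid line is "text<TAB>label"
    for line in tqdm(f, desc='load dataset'):
        line = line.replace('\u3000\u3000', '').replace('\n', '')
        try:
            x, y = line.split('\t')
        except ValueError:
            # skip lines that do not have exactly two tab-separated fields
            continue
        if y == 'label':
            # skip the header row
            continue
        x_lst.append(x)
        y_lst.append(y)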
load dataset: 5902it [00:00, 55378.83it/s]
len(x_lst), len(y_lst)
(5900, 5900)
x_lst[0], y_lst[0]
('昌平京基鹭府10月29日推别墅1200万套起享97折新浪房产讯(编辑郭彪)京基鹭府(论坛相册户型样板间点评地图搜索)售楼处位于昌平区京承高速北七家出口向西南公里路南。项目预计10月29日开盘,总价1200万元/套起,2012年年底入住。待售户型为联排户型面积为410-522平方米,独栋户型面积为938平方米,双拼户型面积为522平方米。京基鹭府项目位于昌平定泗路与东北路交界处。项目周边配套齐全,幼儿园:伊顿双语幼儿园、温莎双语幼儿园;中学:北师大亚太实验学校、潞河中学(北京市重点);大学:王府语言学校、北京邮电大学、现代音乐学院;医院:王府中西医结合医院(三级甲等)、潞河医院、解放军263医院、安贞医院昌平分院;购物:龙德广场、中联万家商厦、世纪华联超市、瑰宝购物中心、家乐福超市;酒店:拉斐特城堡、鲍鱼岛;休闲娱乐设施:九华山庄、温都温泉度假村、小汤山疗养院、龙脉温泉度假村、小汤山文化广场、皇港高尔夫、高地高尔夫、北鸿高尔夫球场;银行:工商银行、建设银行、中国银行、北京农村商业银行;邮局:中国邮政储蓄;其它:北七家建材城、百安居建材超市、北七家镇武装部、北京宏翔鸿企业孵化基地等,享受便捷生活。京基鹭府坐守定泗路,立汤路交汇处。连接京昌、八达岭、机场高速,南至5环,北上6环,紧邻立汤路,一路向南,直抵鸟巢、水立方、长安街,距北京唯一不堵车的京承高速出口仅1公里,节约出行时间成本,形成了三横、三纵的立体式交通网络项目周边多为别墅项目,人口密度低,交通出行舒适度高。>>报名参加“乐动银十”10月22日大型抄底看房团以上信息仅供参考,最终以开发商公布为准。订阅会员置业刊我们将直接把最新的热盘动向发送到您的邮箱更多热盘推荐:新锐白领淘低价1-2居网罗2万内轨道精装房手握20万咋买板楼2居网罗城南沿轨优质盘关注娃娃教育网罗学区房不足百万元住上通透2居', '房产')
label2id = dict()
id2label = dict()
# sort the labels so that the label-to-id mapping is identical on every run
for i, label in enumerate(sorted(set(y_lst))):
    label2id[label] = i
    id2label[i] = label
tokenizer = AutoTokenizer.from_pretrained("./models/bert-base-chinese")
The idea is to tokenize the texts up front rather than repeatedly during training. The loop below encodes every text once and records its token count, which helps in choosing MAX_LEN. The encodings themselves are not stored, so the NewsDataset class that follows still tokenizes each sample when it is accessed:
token_lens = []
for txt in tqdm(x_lst):
    # truncation=True makes the cut-off at max_length explicit
    tokens = tokenizer.encode(txt, max_length=512, truncation=True)
    token_lens.append(len(tokens))
100%|██████████| 5900/5900 [00:07<00:00, 739.64it/s]
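One use of the recorded lengths is to pick a sensible `MAX_LEN`. The following is a rough sketch, assuming a hypothetical helper `coverage` and a small made-up list `sample_lens` standing in for the real `token_lens`; it reports what fraction of texts fit within a given length without truncation:

```python
def coverage(token_lens, max_len):
    """Fraction of samples whose token count is at most max_len."""
    if not token_lens:
        return 0.0
    return sum(1 for n in token_lens if n <= max_len) / len(token_lens)

# made-up lengths for illustration only; in the article this would be token_lens
sample_lens = [32, 48, 50, 120, 300, 512, 512, 512]

for max_len in (50, 128, 512):
    print(max_len, f"{coverage(sample_lens, max_len):.2f}")
```

Texts longer than the chosen `MAX_LEN` are truncated, so a small value such as 50 trades some information for faster training.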
class NewsDataset(Dataset):
    def __init__(self, x_lst, y_lst, tokenizer, max_len):
        self.x_lst = x_lst
        self.y_lst = y_lst
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.x_lst)

    def __getitem__(self, index):
        """
        Return the sample at position `index`.
        """
        text = str(self.x_lst[index])
        label = label2id[self.y_lst[index]]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            truncation=True,         # cut texts longer than max_len
            padding='max_length',    # pad shorter texts up to max_len
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'texts': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }
x_train, x_val, y_train, y_val = train_test_split(x_lst, y_lst, test_size=0.15, random_state=RANDOM_SEED)  # split into training and validation sets
# dataset
train_dataset = NewsDataset(x_train, y_train, tokenizer, MAX_LEN)
val_dataset = NewsDataset(x_val, y_val, tokenizer, MAX_LEN)
# dataloader
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
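Since `NewsDataset` tokenizes inside `__getitem__`, every epoch repeats the same work. A possible alternative, sketched here under assumptions, is to encode each text once in the constructor and serve the cached result. `CachedDataset` and `toy_encode` are hypothetical names; `toy_encode` is a stand-in for a real tokenizer call so that the sketch runs without downloading a model:

```python
class CachedDataset:
    """Map-style dataset that encodes every text once, up front."""

    def __init__(self, texts, labels, encode_fn):
        assert len(texts) == len(labels)
        # encode_fn is called exactly once per text here, not once per epoch
        self.encodings = [encode_fn(t) for t in texts]
        self.labels = list(labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        item = dict(self.encodings[index])  # shallow copy of the cached encoding
        item["labels"] = self.labels[index]
        return item


def toy_encode(text, max_len=8):
    """Stand-in tokenizer: one id per character, truncated and padded to max_len."""
    ids = [ord(ch) for ch in text][:max_len]
    mask = [1] * len(ids)
    pad = max_len - len(ids)
    return {"input_ids": ids + [0] * pad, "attention_mask": mask + [0] * pad}


ds = CachedDataset(["床前明月光", "疑是地上霜"], [0, 1], toy_encode)
print(len(ds), ds[0]["attention_mask"], ds[1]["labels"])
```

In practice `encode_fn` would wrap the real tokenizer and the lists would be converted to tensors; the trade-off is higher memory use in exchange for less CPU work per epoch.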
6.2 Custom network
Here we take the pretrained BERT model and attach a Dropout layer followed by a single linear layer to form the custom network:
class CustomBERTModel(nn.Module):
    def __init__(self, n_classes):
        super(CustomBERTModel, self).__init__()
        self.bert = AutoModel.from_pretrained("./models/bert-base-chinese")
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        # with return_dict=False, BERT returns a tuple: (last_hidden_state, pooler_output)
        _, pooled_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False
        )
        output = self.drop(pooled_output)  # dropout
        return self.out(output)            # logits of shape (batch_size, n_classes)
device = set_device(cuda_index=1)
2022-12-20 16:12:39 set_device line 11 out: cuda:1
n_classes = len(label2id)
model = CustomBERTModel(n_classes)
model = model.to(device)
Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight'] - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
6.3 Training
def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler, n_examples):
    model.train()  # training mode (dropout enabled)
    losses = []
    correct_predictions = 0
    for i, d in bar(data_loader):
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        targets = d["labels"].to(device)
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        _, preds = torch.max(outputs, dim=1)  # predicted class = index of the largest logit
        loss = loss_fn(outputs, targets)
        correct_predictions += torch.sum(preds == targets)
        losses.append(loss.item())
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
        optimizer.step()
        scheduler.step()  # the learning rate is updated once per batch
        optimizer.zero_grad()
    return correct_predictions.double() / n_examples, np.mean(losses)
def eval_model(model, data_loader, loss_fn, device, n_examples):
    model.eval()  # evaluation mode (dropout disabled)
    losses = []
    correct_predictions = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1)
            loss = loss_fn(outputs, targets)
            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())
    return correct_predictions.double() / n_examples, np.mean(losses)
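The call to `nn.utils.clip_grad_norm_` in `train_epoch` rescales all gradients when their combined L2 norm exceeds `max_norm`. The plain-Python sketch below illustrates that idea on a flat list of numbers; `clip_by_global_norm` is a hypothetical helper written for illustration, not PyTorch's implementation:

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Scale grads so their L2 norm is at most max_norm; return (clipped, original_norm)."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads), total_norm
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm

clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in clipped))
print(norm_before, clipped, norm_after)
```

Clipping keeps the direction of the update and only limits its size, which tends to make fine-tuning more stable.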
EPOCHS = 5  # number of training epochs
optimizer = AdamW(model.parameters(), lr=2e-5)
total_steps = len(train_dataloader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)
loss_fn = nn.CrossEntropyLoss().to(device)
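With `num_warmup_steps=0`, the scheduler decays the learning rate linearly from its initial value to zero over `total_steps`. The sketch below reproduces that multiplier in plain Python as we understand it; `linear_warmup_decay` is a hypothetical name and the step counts are made up for illustration:

```python
def linear_warmup_decay(step, num_warmup_steps, num_training_steps):
    """Learning-rate multiplier: linear ramp-up during warmup, then linear decay to 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))

base_lr = 2e-5       # same initial learning rate as in the article
total_steps = 400    # illustrative value, not the real number of steps

for step in (0, 100, 200, 400):
    print(step, base_lr * linear_warmup_decay(step, 0, total_steps))
```

Because `scheduler.step()` is called once per batch in `train_epoch`, the total must count batches, which is why it is computed as `len(train_dataloader) * EPOCHS`.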
best_accuracy = 0
is_best = False
t = Tableprint(['epoch', 'train_accuracy', 'train_loss', 'test_accuracy', 'test_loss', 'is_best'])
t.print_header()
for epoch in range(EPOCHS):
    train_acc, train_loss = train_epoch(model, train_dataloader, loss_fn, optimizer, device, scheduler, len(x_train))
    val_acc, val_loss = eval_model(
        model,
        val_dataloader,
        loss_fn,
        device,
        len(x_val)
    )
    if val_acc > best_accuracy:
        # keep the weights of the best model seen so far
        is_best = True
        torch.save(model.state_dict(), './models/news_classification/best_model_state.bin')
        best_accuracy = val_acc
    else:
        is_best = False
    t.print_row(epoch, f"{train_acc:.4f}", f"{train_loss:.4f}", f"{val_acc:.4f}", f"{val_loss:.4f}", is_best)
+======+===========+====================+================+===================+===============+=============+
|      |   epoch   |   train_accuracy   |   train_loss   |   test_accuracy   |   test_loss   |   is_best   |
+======+===========+====================+================+===================+===============+=============+
|  1   |     0     |       0.6080       |     1.4608     |      0.8893       |    0.5278     |    True     |
+------+-----------+--------------------+----------------+-------------------+---------------+-------------+
|  2   |     1     |       0.9196       |     0.3766     |      0.9096       |    0.3583     |    True     |
+------+-----------+--------------------+----------------+-------------------+---------------+-------------+
|  3   |     2     |       0.9589       |     0.2015     |      0.9153       |    0.3413     |    True     |
+------+-----------+--------------------+----------------+-------------------+---------------+-------------+
|  4   |     3     |       0.9765       |     0.1272     |      0.9153       |    0.3286     |    False    |
+------+-----------+--------------------+----------------+-------------------+---------------+-------------+
|  5   |     4     |       0.9836       |     0.0919     |      0.9220       |    0.3239     |    True     |
+------+-----------+--------------------+----------------+-------------------+---------------+-------------+
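With the pretrained BERT model plus the custom network, the model reaches high accuracy from the very start: about 0.89 on the held-out set after the first epoch, rising to about 0.92 by the fifth.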
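Author: 奥辰
WeChat: chb1137796095
Github: https://github.com/ChenHuabin321
Feel free to add me on WeChat to exchange ideas, learn together, and improve together!
The copyright of this article belongs to the author and 博客园 (cnblogs). Reposting is welcome, but unless the author agrees otherwise, this notice must be retained and a link to the original article must be given in a prominent position on the page; otherwise the author reserves the right to pursue legal liability.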