BERT Fine-Tuning (1)
BERT fine-tuning steps:
Start with the main function: copy run_classifier.py and rename the copy, e.g. my_classifier.py.
First, look at the main function:
```python
if __name__ == "__main__":
  flags.mark_flag_as_required("data_dir")
  flags.mark_flag_as_required("task_name")
  flags.mark_flag_as_required("vocab_file")
  flags.mark_flag_as_required("bert_config_file")
  flags.mark_flag_as_required("output_dir")
  tf.app.run()
```
1. data_dir
flags.mark_flag_as_required("data_dir") marks data_dir, the folder containing the input data, as required. The expected data format is already defined:
```python
class InputExample(object):
  """A single training/test example for simple sequence classification."""

  def __init__(self, guid, text_a, text_b=None, label=None):
    """Constructs a InputExample.

    Args:
      guid: Unique id for the example.
      text_a: string. The untokenized text of the first sequence. For single
        sequence tasks, only this sequence must be specified.
      text_b: (Optional) string. The untokenized text of the second sequence.
        Only must be specified for sequence pair tasks.
      label: (Optional) string. The label of the example. This should be
        specified for train and dev examples, but not for test examples.
    """
    self.guid = guid
    self.text_a = text_a
    self.text_b = text_b
    self.label = label
```
The required fields are guid and text_a; text_b and label are optional. Single-sentence classification tasks do not need text_b, and test examples do not need a label.
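As a minimal, self-contained sketch (InputExample is re-declared locally so the snippet runs on its own, and the sentences are made up), this is how examples for a single-sentence task are constructed:

```python
class InputExample(object):
    """Same fields as the InputExample in run_classifier.py."""
    def __init__(self, guid, text_a, text_b=None, label=None):
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label

# Train/dev examples carry a label; single-sentence tasks leave text_b as None.
train_example = InputExample(guid="train-1", text_a="the movie was great", label="1")
# Test examples carry no label.
test_example = InputExample(guid="test-1", text_a="the plot was thin")

print(train_example.label)   # 1
print(test_example.label)    # None
```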
2. task_name
```python
processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mrpc": MrpcProcessor,
    "xnli": XnliProcessor,
}
```
task_name is a key into the processors dict. BERT ships with four keys: "cola", "mnli", "mrpc", and "xnli"; any other task must be added yourself.
Note the following:

```python
task_name = FLAGS.task_name.lower()
if task_name not in processors:
  raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()
label_list = processor.get_labels()
```
task_name selects the processor. The BERT source provides four processors; to fine-tune on our own data, we define our own processor, for example:
```python
class MrpcProcessor(DataProcessor):
  """Processor for the MRPC data set (GLUE version)."""

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

  def get_labels(self):
    """See base class."""
    return ["0", "1"]  # TODO: set the label set for your task

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      if i == 0:
        continue
      guid = "%s-%s" % (set_type, i)
      text_a = tokenization.convert_to_unicode(line[3])
      text_b = tokenization.convert_to_unicode(line[4])
      if set_type == "test":
        label = "0"
      else:
        label = tokenization.convert_to_unicode(line[0])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples
```
A processor is the class that preprocesses the input data; it inherits from DataProcessor. The files in data_dir are in .tsv format, and since this is set up as binary classification, the labels are "0" and "1".
_create_examples() defines how the guid is generated and how text_a, text_b, and label are assigned.
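To see _create_examples() in action without the rest of the pipeline, here is a stripped-down re-implementation run on two in-memory rows (plain str stands in for tokenization.convert_to_unicode, a dict stands in for InputExample, and the row contents are made up):

```python
def create_examples(lines, set_type):
    """Mirrors MrpcProcessor._create_examples: skip header, build guid, map columns."""
    examples = []
    for i, line in enumerate(lines):
        if i == 0:  # first row is the header
            continue
        guid = "%s-%s" % (set_type, i)
        label = "0" if set_type == "test" else line[0]  # test rows get a dummy label
        examples.append({"guid": guid, "text_a": line[3],
                         "text_b": line[4], "label": label})
    return examples

rows = [
    ["Quality", "#1 ID", "#2 ID", "#1 String", "#2 String"],   # header row
    ["1", "a", "b", "He said hello .", "He greeted them ."],
]
print(create_examples(rows, "train"))
```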
With the first two flags covered, back to the main function:
```python
if __name__ == "__main__":
  flags.mark_flag_as_required("data_dir")
  flags.mark_flag_as_required("task_name")
  flags.mark_flag_as_required("vocab_file")
  flags.mark_flag_as_required("bert_config_file")
  flags.mark_flag_as_required("output_dir")
  tf.app.run()
```
3. vocab_file, bert_config_file, output_dir
vocab_file and bert_config_file are files from the downloaded pre-trained model; output_dir is where the fine-tuned model is written.
As mentioned above, a .tsv file is similar to a .csv file. The train.tsv and dev.tsv format is:

label + "\t" (tab) + sentence

The test.tsv format is:

sentence
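The two layouts above can be parsed with a small helper. This is only an illustrative sketch (read_tsv_examples is not part of run_classifier.py, and it assumes no quoting inside sentences):

```python
import csv
import io

def read_tsv_examples(tsv_text, set_type):
    """Parses label<TAB>sentence lines for train/dev, bare sentences for test."""
    examples = []
    for i, line in enumerate(csv.reader(io.StringIO(tsv_text), delimiter="\t")):
        guid = "%s-%s" % (set_type, i)
        if set_type == "test":
            text_a, label = line[0], None   # test rows have no label
        else:
            label, text_a = line[0], line[1]
        examples.append({"guid": guid, "text_a": text_a, "label": label})
    return examples

train = read_tsv_examples("1\tgreat acting\n0\tweak script\n", "train")
test = read_tsv_examples("a sentence with no label\n", "test")
print(train[0]["label"], test[0]["label"])   # 1 None
```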
4. Modify the processors dict to register your own task
```python
processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mrpc": MrpcProcessor,
    "xnli": XnliProcessor,
    "mytask": MyTaskProcessor,  # placeholder key/class: use your own task name and processor
}
```
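Putting the registration together with the lookup from step 2, a tiny standalone sketch (MyTaskProcessor and the task name "mytask" are illustrative placeholders; the four stock processors are stubbed out):

```python
class MyTaskProcessor(object):
    """Placeholder for your custom DataProcessor subclass."""
    def get_labels(self):
        return ["0", "1"]

# Stock entries stubbed; only the custom entry matters for this sketch.
processors = {"cola": object, "mnli": object, "mrpc": object,
              "xnli": object, "mytask": MyTaskProcessor}

task_name = "MyTask".lower()            # FLAGS.task_name is lower-cased first
if task_name not in processors:
    raise ValueError("Task not found: %s" % task_name)
processor = processors[task_name]()
print(processor.get_labels())           # ['0', '1']
```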
5. Set the flags and fine-tune
```shell
python my_classifier.py \
  --task_name=mrpc \
  --do_train=true \
  --do_eval=true \
  --data_dir=$GLUE_DIR/MRPC \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=/tmp/mrpc_output/
```