BERT Fine-Tuning (1)

Steps for fine-tuning BERT:

Let's start by dissecting the entry point.

Copy run_classifier.py and rename it to something like my_classifier.py.

First, the main block:

if __name__ == "__main__":
  flags.mark_flag_as_required("data_dir")
  flags.mark_flag_as_required("task_name")
  flags.mark_flag_as_required("vocab_file")
  flags.mark_flag_as_required("bert_config_file")
  flags.mark_flag_as_required("output_dir")
  tf.app.run()
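These five flags are declared near the top of run_classifier.py with tf.flags, each defaulting to None, which is why they have to be marked as required before tf.app.run() dispatches to main(). Abridged from the source (help strings shortened):

flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string(
    "data_dir", None,
    "The input data dir. Should contain the .tsv files for the task.")

flags.DEFINE_string("task_name", None, "The name of the task to train.")

flags.DEFINE_string(
    "vocab_file", None,
    "The vocabulary file that the BERT model was trained on.")

flags.DEFINE_string(
    "bert_config_file", None,
    "The config json file corresponding to the pre-trained BERT model.")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")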

1. data_dir

flags.mark_flag_as_required("data_dir")中data_dir为数据的路径文件夹,数据格式已经定义好了:

class InputExample(object):
  """A single training/test example for simple sequence classification."""

  def __init__(self, guid, text_a, text_b=None, label=None):
    """Constructs a InputExample.

    Args:
      guid: Unique id for the example.
      text_a: string. The untokenized text of the first sequence. For single
        sequence tasks, only this sequence must be specified.
      text_b: (Optional) string. The untokenized text of the second sequence.
        Only must be specified for sequence pair tasks.
      label: (Optional) string. The label of the example. This should be
        specified for train and dev examples, but not for test examples.
    """
    self.guid = guid
    self.text_a = text_a
    self.text_b = text_b
    self.label = label

The required fields are guid and text_a; text_b and label are optional.

For single-sentence classification tasks text_b is not needed, and test examples do not carry a label.
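To make that concrete, here is a hypothetical pair of examples for a single-sentence task (the guid strings and text are made up):

# Training example: single sentence, so no text_b; label is required.
train_example = InputExample(
    guid="train-1",
    text_a="this movie was great",
    label="1")

# Test example: neither text_b nor label is needed.
test_example = InputExample(
    guid="test-1",
    text_a="this movie was terrible")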

2. task_name

  processors = {
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
  }

task_name must match one of the keys of this processors dict. BERT provides four out of the box: "cola", "mnli", "mrpc", and "xnli"; anything else has to be added by hand.

Note how it is used:

  task_name = FLAGS.task_name.lower()

  if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))

  processor = processors[task_name]()

  label_list = processor.get_labels()
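The label_list returned by get_labels() is more than informational: inside convert_single_example it becomes the label-to-id mapping used to produce each example's label_id (abridged from run_classifier.py):

# Inside convert_single_example (abridged):
label_map = {}
for (i, label) in enumerate(label_list):
  label_map[label] = i

# Later in the same function, the string label becomes an integer id:
label_id = label_map[example.label]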

task_name selects the processor. The BERT source ships with four processors, and to fine-tune on our own data we need to define our own. MrpcProcessor makes a good template:

class MrpcProcessor(DataProcessor):
  """Processor for the MRPC data set (GLUE version)."""

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

  def get_labels(self):
    """See base class."""
    return ["0", "1"]  #todo

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      if i == 0:
        continue
      guid = "%s-%s" % (set_type, i)
      text_a = tokenization.convert_to_unicode(line[3])
      text_b = tokenization.convert_to_unicode(line[4])
      if set_type == "test":
        label = "0"
      else:
        label = tokenization.convert_to_unicode(line[0])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

A processor is the class that prepares the data: it subclasses DataProcessor and turns the raw input into InputExample objects. The files in data_dir are .tsv files, and since we set up a binary classification task, get_labels() returns "0" and "1".

_create_examples() shows how guid is generated and how text_a, text_b, and label are filled in: line[3] and line[4] are the two sentence columns of the MRPC file, and line[0] is the label column (test examples get the dummy label "0").

That covers the first two required flags; back to the main block:

if __name__ == "__main__":
  flags.mark_flag_as_required("data_dir")
  flags.mark_flag_as_required("task_name")
  flags.mark_flag_as_required("vocab_file")
  flags.mark_flag_as_required("bert_config_file")
  flags.mark_flag_as_required("output_dir")
  tf.app.run()

3. vocab_file, bert_config_file, output_dir

vocab_file and bert_config_file come from the downloaded pre-trained model (vocab.txt and bert_config.json respectively), and output_dir is the directory where the fine-tuned model is written.

As an aside, the .tsv files mentioned earlier are like .csv files but tab-separated.

train.tsv and dev.tsv use the format:

label + "\t" (tab) + sentence

while test.tsv contains just:

sentence
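With that format in mind, a minimal custom processor could look like the sketch below. MyTaskProcessor is a made-up name, and the sketch assumes the .tsv files have no header row, that train.tsv/dev.tsv hold label then sentence, and that test.tsv holds only the sentence; _read_tsv is inherited from DataProcessor:

import os

import tokenization  # already imported at the top of run_classifier.py


class MyTaskProcessor(DataProcessor):
  """Sketch of a processor for a single-sentence binary task (label\tsentence)."""

  def get_train_examples(self, data_dir):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_dev_examples(self, data_dir):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

  def get_test_examples(self, data_dir):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

  def get_labels(self):
    return ["0", "1"]

  def _create_examples(self, lines, set_type):
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      if set_type == "test":
        # test.tsv has only the sentence column; use a dummy label.
        text_a = tokenization.convert_to_unicode(line[0])
        label = "0"
      else:
        # train.tsv / dev.tsv: column 0 is the label, column 1 the sentence.
        label = tokenization.convert_to_unicode(line[0])
        text_a = tokenization.convert_to_unicode(line[1])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
    return examples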

 

4. Register your own task in the processors dict

processors = {
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "mytask": MyTaskProcessor,  # your own task name -> your own processor class
}

5. Set the flags and fine-tune

python my_classifier.py \
  --task_name=mrpc \
  --do_train=true \
  --do_eval=true \
  --data_dir=$GLUE_DIR/MRPC \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=/tmp/mrpc_output/
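Here $BERT_BASE_DIR is assumed to point at the unzipped pre-trained model directory and $GLUE_DIR at the data root. When the run finishes, output_dir holds the fine-tuned checkpoints plus an eval_results.txt file with the dev-set metrics, which run_classifier.py writes at the end of evaluation.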

 
