pytorch transformers ....

huggingface模型下载

transformers的预训练模型下载到本地特定位置,默认是在~/.cache/huggingface/transformers

model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir="...")

想知道transformers的模型都是什么结构的,比如bert模型:

transformers/models/bert/__init__.py

这里可以看到导入了

from .modeling_bert import (
            BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            BertForMaskedLM,
            BertForMultipleChoice,
            BertForNextSentencePrediction,
            BertForPreTraining,
            BertForQuestionAnswering,
            BertForSequenceClassification,
            BertForTokenClassification,
            BertLayer,
            BertLMHeadModel,
            BertModel,
            BertPreTrainedModel,
            load_tf_weights_in_bert,
        )

然后点进去就可以看了,可以看他们的forward函数等

Trainer

Trainer提供了训练、验证、预测的功能,可以通过继承Trainer并覆写其中一些方法来自定义。

  • compute_loss()
    计算损失的函数,compute_loss(self, model, inputs, return_outputs=False),该函数执行forward和loss计算。
  • 参数 compute_metrics(eval_pred: EvalPrediction)
    该参数指定的函数在tranier.evaluate()时会调用,该函数的参数是EvalPrediction类型,必须返回一个字典,类型是string-> value。比如
{
	'accuracy': 0.98, 
	'sensitivity': 0.65
}

EvalPrediction包括self.predictionsself.label_ids,以及可能有的self.inputs

  • 参数metric_for_best_model
    这是个字符串,并且必须是compute_metrics返回的字典的一个key,对应上例就是只能是'accuracy'或者'sensitivity'
posted @ 2022-05-19 10:12  王冰冰  阅读(699)  评论(0编辑  收藏  举报