pytorch transformers ....
huggingface模型下载
transformers的预训练模型下载到本地特定位置,默认是在~/.cache/huggingface/transformers
model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir="...")
想知道transformers的模型都是什么结构的,比如bert模型:
transformers/models/bert/__init__.py
这里可以看到导入了
from .modeling_bert import (
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
BertLayer,
BertLMHeadModel,
BertModel,
BertPreTrainedModel,
load_tf_weights_in_bert,
)
然后点进去就可以看了,可以看他们的forward函数等
Trainer
Trainer提供了训练、验证、预测的功能,可以通过继承Trainer并覆写其中一些方法来自定义。
- compute_loss()
计算损失的函数,compute_loss(self, model, inputs, return_outputs=False)
,该函数执行forward和loss计算。 - 参数
compute_metrics(eval_pred: EvalPrediction)
该参数指定的函数在tranier.evaluate()时会调用,该函数的参数是EvalPrediction
类型,必须返回一个字典,类型是string-> value
。比如
{
'accuracy': 0.98,
'sensitivity': 0.65
}
EvalPrediction
包括self.predictions
和self.label_ids
,以及可能有的self.inputs
。
- 参数metric_for_best_model
这是个字符串,并且必须是compute_metrics
返回的字典的一个key,对应上例就是只能是'accuracy'
或者'sensitivity'