Key transformers code (to be completed)
1. Configuring the training arguments
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    # dataloader_num_workers=4,
    num_train_epochs=epochNo,
    save_strategy='epoch',
    # depends on whether the full dataset is used: 'no' if constants.ifFullData else 'epoch'
    # (recent transformers releases name this argument eval_strategy)
    evaluation_strategy=evaluation_strategy,
    logging_steps=50,
    save_total_limit=save_total_limit,      # maximum number of checkpoints to keep
    metric_for_best_model='eval_cider',     # metric used to pick the best model
    greater_is_better=True,
    learning_rate=lr,
    warmup_ratio=0.03,
    seed=userSeed,
    overwrite_output_dir=True,
    per_device_eval_batch_size=batchsize,
    per_device_train_batch_size=batchsize,
    output_dir=outputPath,
    do_train=True,
    # likewise depends on the full-data switch: False if constants.ifFullData else True
    do_eval=do_eval,
    predict_with_generate=True,
    label_smoothing_factor=0.1 if constants.isSMOOTH else 0,
)
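Since metric_for_best_model is set to 'eval_cider', the compute_metrics function passed to the trainer has to return a key named 'cider' (the Trainer adds the 'eval_' prefix itself). Below is a minimal sketch of how that could be wired up. The helper names make_compute_metrics and overlap_f1 are hypothetical, and overlap_f1 is only a stand-in unigram-overlap score, NOT a real CIDEr implementation; it should be replaced by an actual CIDEr scorer.

```python
from collections import Counter


def overlap_f1(prediction: str, reference: str) -> float:
    """Placeholder score: unigram-overlap F1. NOT CIDEr; swap in a real scorer."""
    pred_tokens, ref_tokens = prediction.split(), reference.split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    common = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if common == 0:
        return 0.0
    precision = common / len(pred_tokens)
    recall = common / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def make_compute_metrics(tokenizer, score_fn=overlap_f1):
    """Build a compute_metrics callable (hypothetical helper, not a library API).

    `tokenizer` only needs `pad_token_id` and `batch_decode`.
    """

    def replace_ignore_index(rows):
        # -100 marks ignored positions and cannot be decoded, so map it to the pad id.
        pad_id = tokenizer.pad_token_id
        return [[int(t) if t != -100 else pad_id for t in row] for row in rows]

    def compute_metrics(eval_pred):
        predictions, label_ids = eval_pred
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        decoded_preds = tokenizer.batch_decode(
            replace_ignore_index(predictions), skip_special_tokens=True
        )
        decoded_labels = tokenizer.batch_decode(
            replace_ignore_index(label_ids), skip_special_tokens=True
        )
        scores = [
            score_fn(p.strip(), r.strip())
            for p, r in zip(decoded_preds, decoded_labels)
        ]
        # The key is 'cider'; it is reported as 'eval_cider' during evaluation.
        return {"cider": sum(scores) / max(len(scores), 1)}

    return compute_metrics
```

The result would be passed as compute_metrics=make_compute_metrics(tokenizer) when constructing the Seq2SeqTrainer. With predict_with_generate=True, the predictions arriving here are generated token ids, which is why they are decoded before scoring.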
2. Building the Datasets data
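First define a dict whose values are lists, then convert it into a Dataset:
import torch
from datasets import Dataset

results = {'summarization': [], 'article': []}
results = Dataset.from_dict(results)
# Check whether the result is a PyTorch IterableDataset
print(isinstance(results, torch.utils.data.IterableDataset))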
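The dict above starts out empty; in practice the lists are filled before the conversion. Here is a minimal sketch of that step, assuming the raw data comes as (article, summary) pairs. The helper names build_columns and to_dataset are hypothetical; only Dataset.from_dict comes from the datasets library.

```python
def build_columns(pairs):
    """Turn (article, summary) pairs into a column-oriented dict of lists."""
    results = {'summarization': [], 'article': []}
    for article, summary in pairs:
        results['article'].append(article)
        results['summarization'].append(summary)
    return results


def to_dataset(columns):
    """Convert the dict of lists into a datasets.Dataset."""
    # Imported lazily so build_columns works without the `datasets` package.
    from datasets import Dataset
    return Dataset.from_dict(columns)


if __name__ == "__main__":
    # Made-up sample rows, for illustration only.
    sample = [
        ("first article text", "first summary"),
        ("second article text", "second summary"),
    ]
    columns = build_columns(sample)
    print({name: len(values) for name, values in columns.items()})
```

Every column must have the same length, which building both lists from the same pairs guarantees.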
(I use this page as a notepad.)