Key transformers code (to be completed)

1. Configuring the training arguments

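    from transformers import Seq2SeqTrainingArguments

    # epochNo, evaluation_strategy, save_total_limit, lr, userSeed, batchsize,
    # outputPath, do_eval and constants are defined elsewhere in the project.
    training_args = Seq2SeqTrainingArguments(
        # dataloader_num_workers=4,
        num_train_epochs=epochNo,
        save_strategy='epoch',
        # full-data training (no held-out set): 'no' if constants.ifFullData else 'epoch'
        # (newer transformers releases rename this argument to eval_strategy)
        evaluation_strategy=evaluation_strategy,
        logging_steps=50,
        save_total_limit=save_total_limit,  # maximum number of checkpoints to keep
        # select checkpoints by CIDEr instead of the default eval loss; compute_metrics
        # must return a 'cider' key, which the Trainer reports as 'eval_cider'
        metric_for_best_model='eval_cider',
        greater_is_better=True,  # a higher CIDEr score is better
        learning_rate=lr,
        warmup_ratio=0.03,  # linear warmup over the first 3% of optimizer steps
        seed=userSeed,
        overwrite_output_dir=True,
        per_device_eval_batch_size=batchsize,
        per_device_train_batch_size=batchsize,
        output_dir=outputPath,
        do_train=True,
        do_eval=do_eval,  # full-data training: False if constants.ifFullData else True
        predict_with_generate=True,  # use generate() during evaluation so text metrics can be computed
        label_smoothing_factor=0.1 if constants.isSMOOTH else 0.0,
    )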
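
With `warmup_ratio=0.03` and the default linear learning-rate schedule, the number of warmup steps depends on the total number of optimizer steps. The following sketch re-implements that arithmetic in plain Python so it can be checked without a GPU. It assumes a single device, `lr_scheduler_type` left at its default `"linear"`, and no explicit `warmup_steps`. The function names and the 10,000-example run are hypothetical and do not come from the transformers API.

```python
import math


def total_training_steps(num_examples, batch_size, num_epochs, grad_accum_steps=1):
    """Approximate optimizer steps on one device: len(dataloader) // grad_accum * epochs."""
    batches_per_epoch = math.ceil(num_examples / batch_size)  # the last partial batch is kept
    updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    return math.ceil(num_epochs * updates_per_epoch)


def warmup_steps_from_ratio(total_steps, warmup_ratio):
    """warmup_ratio is turned into a whole number of steps, rounding up."""
    return math.ceil(total_steps * warmup_ratio)


def linear_schedule_lr(step, base_lr, total_steps, warmup_steps):
    """Learning rate at a step: linear ramp from 0 to base_lr, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = (total_steps - step) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, remaining)


# Hypothetical run: 10,000 training examples, batchsize=16, epochNo=3, lr=5e-5
total = total_training_steps(10_000, 16, 3)
warmup = warmup_steps_from_ratio(total, 0.03)
print(total, warmup)  # 1875 57
for step in (0, warmup // 2, warmup, total // 2, total):
    print(step, linear_schedule_lr(step, 5e-5, total, warmup))
```

The learning rate peaks at `lr` exactly when warmup ends and reaches 0 at the final step. A small dataset combined with a small `warmup_ratio` can therefore produce only a handful of warmup steps.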

2. Building the dataset with Datasets

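First define a dict whose keys are the column names and whose values are lists, with one entry per example:

    import torch
    from datasets import Dataset

    results = {'summarization': [], 'article': []}

Then convert it into a Dataset:

    results = Dataset.from_dict(results)
    # A map-style datasets.Dataset is not a torch.utils.data.IterableDataset,
    # so this should print False
    print(isinstance(results, torch.utils.data.IterableDataset))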
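
The `from_dict` layout is column-oriented. Each key is a column name and each value holds that column's entries, so every list must have the same length. When the data arrives as one record per example, a small pure-Python helper can build that dict. The sketch below does this without calling the datasets library itself. The helper names and sample records are hypothetical.

```python
def rows_to_columns(rows, columns=("summarization", "article")):
    """Collect row records into the {column: list} layout used by Dataset.from_dict."""
    results = {name: [] for name in columns}
    for row in rows:
        for name in columns:
            results[name].append(row[name])
    return results


def num_rows(results):
    """Return the shared column length, or raise if the columns disagree."""
    lengths = {name: len(values) for name, values in results.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(f"columns have different lengths: {lengths}")
    return next(iter(lengths.values()), 0)


def get_row(results, i):
    """Rebuild example i as a dict, the same shape that Dataset[i] returns."""
    return {name: values[i] for name, values in results.items()}


# Hypothetical records; in practice they would come from the project's data files
rows = [
    {"article": "first article text", "summarization": "first summary"},
    {"article": "second article text", "summarization": "second summary"},
]
results = rows_to_columns(rows)
print(num_rows(results))  # 2
print(get_row(results, 1))
# results can now be passed to Dataset.from_dict(results)
```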

  

posted @ 2018-01-26 23:26  随遇而安jason