Paper: Transferable Multi-Domain State Generator for Task-Oriented Dialogue Systems
Source code: https://github.com/jasonwu0731/trade-dst#zero-shot-dst
References: https://www.jianshu.com/p/fcaf02fc025d
https://blog.csdn.net/weixin_44385551/article/details/103673408
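This paper combines transfer learning with multi-domain dialogue systems.
I. Training process
First, I trained on Colab. For the runtime type, only GPU works; TPU does not, presumably because of how the code is written (the logs below show USE_CUDA=True).
Background reading: Section 4.2 (Training Details) and Section 4.3 (Results) of the paper, and the author's notes on GitHub.
Training uses only five domains (restaurant, hotel, attraction, taxi, train); the other two (hospital, police) have too few dialogues.
There are two evaluation metrics: joint goal accuracy and slot accuracy. Joint goal accuracy requires the dialogue state B_t that the DST predicts at turn t, a set of (domain, slot, value) triplets, to match the label exactly: a turn counts as correct only if every slot value in it is predicted correctly. Slot accuracy instead scores each (domain, slot, value) triplet individually. In the logs below, 'Joint Acc' is joint goal accuracy, and 'Turn Acc' appears to correspond to slot accuracy averaged over turns.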
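A simplified sketch of both metrics, assuming each turn's state is a dict from a domain-slot name to its value and that a slot missing from the dict means "none". joint_and_slot_accuracy is a hypothetical helper for illustration, not the repo's evaluation code.

```python
from typing import Dict, List, Tuple


def joint_and_slot_accuracy(preds: List[Dict[str, str]],
                            golds: List[Dict[str, str]],
                            all_slots: List[str]) -> Tuple[float, float]:
    """Return (joint goal accuracy, slot accuracy) over a list of turns."""
    joint_correct = 0
    slot_correct = 0
    for pred, gold in zip(preds, golds):
        # Compare every slot; an unset slot on either side counts as "none".
        matches = [pred.get(s, "none") == gold.get(s, "none") for s in all_slots]
        slot_correct += sum(matches)
        # Joint goal accuracy: the turn counts only if *all* slots match.
        joint_correct += all(matches)
    n_turns = len(golds)
    return joint_correct / n_turns, slot_correct / (n_turns * len(all_slots))


if __name__ == "__main__":
    slots = ["hotel-area", "hotel-stars", "train-day"]
    golds = [{"hotel-area": "north"}, {"hotel-area": "north", "hotel-stars": "4"}]
    preds = [{"hotel-area": "north"}, {"hotel-area": "north", "hotel-stars": "3"}]
    print(joint_and_slot_accuracy(preds, golds, slots))
```

One wrong value in the second turn costs that whole turn under joint goal accuracy (1 of 2 turns correct), but only one of six slot decisions under slot accuracy, which is why 'Turn Acc' in the logs is always much higher than 'Joint Acc'.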
1. Multi-Domain DST (Multi-domain Joint Training)
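Multi-domain Joint Training: the model is trained end to end with the Adam optimizer and a batch size of 32. The learning rate is annealed within [0.001, 0.0001], and the dropout ratio is 0.2 (the -dr=0.2 flag below). Both loss weights in Eq. (7) are set to 1. The embeddings are initialized by concatenating GloVe (Pennington et al., 2014) and character embeddings (Hashimoto et al., 2016), for a dimension of 400. The state generator decodes with greedy search (beam search is more common for sequence-to-sequence outputs; the paper's reason for greedy search is that slot values are usually short). To improve generalization, word dropout (Bowman et al., 2016) is applied to the utterance encoder: similar to dropout, a small fraction of input words is randomly masked.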
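Word dropout can be sketched as follows: during training, each input token is independently replaced by an unknown-token symbol with probability p, which also simulates out-of-vocabulary words at test time. The token name UNK and the function word_dropout are illustrative assumptions. The logged arguments include 'unk_mask': 1, which presumably switches this behavior on in the repo, but the repo's exact masking code is not reproduced here.

```python
import random
from typing import List, Optional

UNK = "UNK"  # hypothetical unknown-token symbol


def word_dropout(tokens: List[str], p: float,
                 rng: Optional[random.Random] = None) -> List[str]:
    """Replace each token with UNK independently with probability p (training only)."""
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    rng = rng or random.Random()
    # rng.random() is in [0, 1), so p=0 never masks and p=1 always masks.
    return [UNK if rng.random() < p else tok for tok in tokens]


if __name__ == "__main__":
    utterance = ["i", "need", "a", "cheap", "hotel", "in", "the", "north"]
    print(word_dropout(utterance, 0.2, random.Random(0)))
```

At evaluation time the function is simply not called, mirroring how ordinary dropout is disabled outside training.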
(1) Commands
Training
❱❱❱ python3 myTrain.py -dec=TRADE -bsz=32 -dr=0.2 -lr=0.001 -le=1
e.g.: python3 myTrain.py -dec=TRADE -bsz=32 -dr=0.2 -lr=0.001 -le=1 --parallel_decode=1
Testing
❱❱❱ python3 myTest.py -path=${save_path}
e.g.: python3 myTest.py -path=
Parameters:
-bsz: batch size
-dr: drop out ratio
-lr: learning rate
-le: loading pretrained embeddings
-path: model saved path
--parallel_decode=1: speed up the decoding process
Note: --parallel_decode=1 can only be used for Multi-Domain DST training; it does not work for the Unseen Domain DST experiments below!
(2) Results
2. Unseen Domain DST (Domain Expanding)
2.1 Zero-Shot DST
(1) Commands
Training
❱❱❱ python3 myTrain.py -dec=TRADE -bsz=32 -dr=0.2 -lr=0.001 -le=1 -exceptd=${domain}
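e.g.: python3 myTrain.py -dec=TRADE -bsz=32 -dr=0.2 -lr=0.001 -le=1 -exceptd=attraction
#During training, passing -path resumes training from the latest model saved before the last interruption. [This run is the basis for the few-shot experiments below, but it is painfully slow: in two days it finished only 8 epochs, and it kept getting interrupted for various reasons, including my having set the runtime to TPU!]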
python3 myTrain.py -dec=TRADE -bsz=32 -dr=0.2 -lr=0.001 -le=1 -exceptd=attraction -path='/content/drive/My Drive/AI/trade-dst/save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398'
Testing
❱❱❱ python3 myTest.py -path=${save_path} -exceptd=${domain}
e.g.: python3 myTest.py -path='save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398' -exceptd=attraction
Note: -path must not be an absolute path here, or the script fails; the same applies to the commands below.
Test output:
/content/drive/My Drive/AI/trade-dst
{'dataset': 'multiwoz', 'task': 'dst', 'path': 'save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398', 'sample': None, 'patience': 6, 'earlyStop': 'BLEU', 'all_vocab': 1, 'imbalance_sampler': 0, 'data_ratio': 100, 'unk_mask': 1, 'batch': None, 'run_dev_testing': 0, 'vizualization': 0, 'genSample': 0, 'evalp': 1, 'addName': 'Exceptattraction', 'eval_batch': 0, 'use_gate': 1, 'load_embedding': 0, 'fix_embedding': 0, 'parallel_decode': 0, 'decoder': None, 'hidden': 400, 'learn': None, 'drop': None, 'limit': -10000, 'clip': 10, 'teacher_forcing_ratio': 0.5, 'lambda_ewc': 0.01, 'fisher_sample': 0, 'all_model': False, 'domain_as_task': False, 'run_except_4d': 1, 'strict_domain': False, 'except_domain': 'attraction', 'only_domain': ''}
['save', 'TRADE-Exceptattractionmultiwozdst', 'HDD400BSZ32DR0.2ACC-0.5398']
HDD 400 decoder TRADE BSZ 32
False folder_name save/
Reading from data/dev_dials.json
domain_counter {'hotel': 416, 'train': 484, 'attraction': 401, 'restaurant': 438, 'taxi': 207}
Reading from data/test_dials.json
domain_counter {'taxi': 195, 'attraction': 395, 'restaurant': 437, 'train': 494, 'hotel': 394}
Reading from data/test_dials.json
domain_counter {'taxi': 195, 'attraction': 395, 'restaurant': 437, 'train': 494, 'hotel': 394}
Read 0 pairs train
Read 7331 pairs dev
Read 3110 pairs test
Vocab_size: 18311
Vocab_size Training 0
Vocab_size Belief 983
Max. length of dialog words for RNN: 660
USE_CUDA=True
[Train Set & Dev Set Slots]: Number is 27 in total
['hotel-pricerange', 'hotel-type', 'hotel-parking', 'hotel-book stay', 'hotel-book day', 'hotel-book people', 'hotel-area', 'hotel-stars', 'hotel-internet', 'train-destination', 'train-day', 'train-departure', 'train-arriveby', 'train-book people', 'train-leaveat', 'restaurant-food', 'restaurant-pricerange', 'restaurant-area', 'restaurant-name', 'hotel-name', 'taxi-leaveat', 'taxi-destination', 'taxi-departure', 'restaurant-book time', 'restaurant-book day', 'restaurant-book people', 'taxi-arriveby']
[Test Set Slots]: Number is 3 in total
['attraction-area', 'attraction-name', 'attraction-type']
MODEL save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398 LOADED
Test Set on 4 domains...
STARTING EVALUATION
100% 229/229 [04:23<00:00, 1.13s/it]
{'Joint Acc': 0.5392491467576792, 'Turn Acc': 0.9687422576159594, 'Joint F1': 0.8946688407302734}
Test Set ...
STARTING EVALUATION
100% 98/98 [00:14<00:00, 6.94it/s]
{'Joint Acc': 0.19581993569131834, 'Turn Acc': 0.5506966773847822, 'Joint F1': 0.20336548767416934}
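The test log prints the -path value split on "/" (['save', 'TRADE-Exceptattractionmultiwozdst', 'HDD400BSZ32DR0.2ACC-0.5398']) followed by "HDD 400 decoder TRADE BSZ 32". That suggests myTest.py reads the decoder name, hidden size and batch size from fixed positions in the path, which would explain why an absolute path fails: its leading components shift every position. The sketch below illustrates this under that assumption; parse_save_path and to_repo_relative are hypothetical helpers, not functions from the repo.

```python
import os
from typing import Dict


def parse_save_path(path: str) -> Dict[str, object]:
    """Hypothetical helper: read settings from fixed positions of 'save/<run>/<ckpt>'."""
    parts = path.split("/")
    run_dir, ckpt = parts[1], parts[2]  # positional: breaks if the path has any other prefix
    return {
        "decoder": run_dir.split("-")[0],                     # 'TRADE'
        "hidden": int(ckpt.split("HDD")[1].split("BSZ")[0]),  # 400
        "batch": int(ckpt.split("BSZ")[1].split("DR")[0]),    # 32
    }


def to_repo_relative(abs_path: str, repo_root: str) -> str:
    """Hypothetical helper: turn an absolute Drive path into one relative to the repo root."""
    return os.path.relpath(abs_path, repo_root)


if __name__ == "__main__":
    root = "/content/drive/My Drive/AI/trade-dst"
    abs_path = root + "/save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398"
    rel = to_repo_relative(abs_path, root)
    print(rel, parse_save_path(rel))
    try:
        parse_save_path(abs_path)  # parts[2] is 'drive', which contains no 'HDD'
    except IndexError as err:
        print("absolute path fails:", err)
```

In practice this means running the scripts from the repo root (e.g. after changing into /content/drive/My Drive/AI/trade-dst, the directory printed at the top of the log) and passing the save path in its relative 'save/...' form.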
Parameters:
-exceptd: except domain selection, choose one from {hotel, train, attraction, restaurant, taxi}.
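The test log above shows the effect of -exceptd=attraction on the slot sets: training and validation use the 27 slots of the other four domains, and the zero-shot test covers only the 3 attraction slots. A minimal sketch of that split; split_slots is a hypothetical helper, and which dialogues the repo keeps for each split is not modeled here.

```python
from typing import List, Tuple

DOMAINS = {"hotel", "train", "attraction", "restaurant", "taxi"}

# The 30 domain-slot names printed in the logs: the 27 of the four source domains
# ("[Train Set & Dev Set Slots]") plus the 3 attraction slots ("[Test Set Slots]").
ALL_SLOTS: List[str] = [
    "hotel-pricerange", "hotel-type", "hotel-parking", "hotel-book stay", "hotel-book day",
    "hotel-book people", "hotel-area", "hotel-stars", "hotel-internet", "train-destination",
    "train-day", "train-departure", "train-arriveby", "train-book people", "train-leaveat",
    "restaurant-food", "restaurant-pricerange", "restaurant-area", "restaurant-name",
    "hotel-name", "taxi-leaveat", "taxi-destination", "taxi-departure",
    "restaurant-book time", "restaurant-book day", "restaurant-book people", "taxi-arriveby",
    "attraction-area", "attraction-name", "attraction-type",
]


def split_slots(all_slots: List[str], except_domain: str) -> Tuple[List[str], List[str]]:
    """Return (slots seen in training/dev, slots of the held-out zero-shot domain)."""
    if except_domain not in DOMAINS:
        raise ValueError(f"-exceptd must be one of {sorted(DOMAINS)}")
    train_dev = [s for s in all_slots if s.split("-")[0] != except_domain]
    held_out = [s for s in all_slots if s.split("-")[0] == except_domain]
    return train_dev, held_out


if __name__ == "__main__":
    seen, unseen = split_slots(ALL_SLOTS, "attraction")
    print(len(seen), unseen)
```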
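(2) Results
As the results show, only the taxi domain's zero-shot accuracy comes close to the left column (the model trained on that domain's own data). The authors attribute this to all four taxi slots sharing similar values with the corresponding slots in the train domain.
2.2 Few-Shot DST with CL (Expanding DST for Few-shot Domain)
(1) Commands
Training Naive (naive fine-tuning: the base model is further trained on 1% of the new domain's data, without EWC or GEM constraints)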
❱❱❱ python3 fine_tune.py -bsz=8 -dr=0.2 -lr=0.001 -path=${save_path_except_domain} -exceptd=${except_domain}
e.g.: python3 fine_tune.py -bsz=8 -dr=0.2 -lr=0.001 -path='save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398' -exceptd=attraction
Training output:
{'dataset': 'multiwoz', 'task': 'dst', 'path': 'save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398', 'sample': None, 'patience': 6, 'earlyStop': 'BLEU', 'all_vocab': 1, 'imbalance_sampler': 0, 'data_ratio': 100, 'unk_mask': 1, 'batch': 8, 'run_dev_testing': 0, 'vizualization': 0, 'genSample': 0, 'evalp': 1, 'addName': 'Exceptattraction', 'eval_batch': 0, 'use_gate': 1, 'load_embedding': 0, 'fix_embedding': 0, 'parallel_decode': 0, 'decoder': None, 'hidden': 400, 'learn': 0.001, 'drop': 0.2, 'limit': -10000, 'clip': 10, 'teacher_forcing_ratio': 0.5, 'lambda_ewc': 0.01, 'fisher_sample': 0, 'all_model': False, 'domain_as_task': False, 'run_except_4d': 1, 'strict_domain': False, 'except_domain': 'attraction', 'only_domain': ''}
True folder_name save/
Reading from data/train_dials.json
domain_counter {'hotel': 3381, 'train': 3103, 'attraction': 2717, 'restaurant': 3813, 'taxi': 1654}
Reading from data/dev_dials.json
domain_counter {'hotel': 416, 'train': 484, 'attraction': 401, 'restaurant': 438, 'taxi': 207}
Reading from data/test_dials.json
domain_counter {'taxi': 195, 'attraction': 395, 'restaurant': 437, 'train': 494, 'hotel': 394}
[Info] Loading saved lang files...
Reading from data/test_dials.json
domain_counter {'taxi': 195, 'attraction': 395, 'restaurant': 437, 'train': 494, 'hotel': 394}
Read 56177 pairs train
Read 7331 pairs dev
Read 3110 pairs test
Vocab_size: 18311
Vocab_size Training 15462
Vocab_size Belief 1005
Max. length of dialog words for RNN: 880
USE_CUDA=True
[Train Set & Dev Set Slots]: Number is 27 in total
['hotel-pricerange', 'hotel-type', 'hotel-parking', 'hotel-book stay', 'hotel-book day', 'hotel-book people', 'hotel-area', 'hotel-stars', 'hotel-internet', 'train-destination', 'train-day', 'train-departure', 'train-arriveby', 'train-book people', 'train-leaveat', 'restaurant-food', 'restaurant-pricerange', 'restaurant-area', 'restaurant-name', 'hotel-name', 'taxi-leaveat', 'taxi-destination', 'taxi-departure', 'restaurant-book time', 'restaurant-book day', 'restaurant-book people', 'taxi-arriveby']
[Test Set Slots]: Number is 3 in total
['attraction-area', 'attraction-name', 'attraction-type']
True folder_name save/
Reading from data/train_dials.json
domain_counter {'hotel': 33, 'train': 35, 'restaurant': 36, 'attraction': 29, 'taxi': 11}
Reading from data/dev_dials.json
domain_counter {'hotel': 416, 'train': 484, 'attraction': 401, 'restaurant': 438, 'taxi': 207}
Reading from data/test_dials.json
domain_counter {'taxi': 195, 'attraction': 395, 'restaurant': 437, 'train': 494, 'hotel': 394}
[Info] Loading saved lang files...
Read 232 pairs train
Read 3088 pairs dev
Read 3110 pairs test
Vocab_size: 18311
Vocab_size Training 15462
Vocab_size Belief 983
Max. length of dialog words for RNN: 660
USE_CUDA=True
[Train Set & Dev Set Slots]: Number is 3 in total
['attraction-area', 'attraction-name', 'attraction-type']
[Test Set Slots]: Number is 3 in total
['attraction-area', 'attraction-name', 'attraction-type']
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py:50: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.2 and num_layers=1
  "num_layers={}".format(dropout, num_layers))
MODEL save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398 LOADED
Epoch:0 L:1.82,LP:1.11,LG:0.71: 100% 29/29 [00:01<00:00, 15.90it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.41it/s]
{'Joint Acc': 0.3248056994818653, 'Turn Acc': 0.6430267702936108, 'Joint F1': 0.4696891191709857}
MODEL SAVED
Epoch:1 L:0.54,LP:0.27,LG:0.27: 100% 29/29 [00:01<00:00, 15.08it/s]
STARTING EVALUATION
100% 386/386 [00:25<00:00, 15.25it/s]
{'Joint Acc': 0.37176165803108807, 'Turn Acc': 0.6709844559585499, 'Joint F1': 0.5357620898100195}
MODEL SAVED
Epoch:2 L:0.31,LP:0.16,LG:0.16: 100% 29/29 [00:01<00:00, 15.61it/s]
STARTING EVALUATION
100% 386/386 [00:25<00:00, 15.25it/s]
{'Joint Acc': 0.37370466321243523, 'Turn Acc': 0.6797279792746118, 'Joint F1': 0.5280332469775502}
MODEL SAVED
Epoch:3 L:0.24,LP:0.15,LG:0.10: 100% 29/29 [00:01<00:00, 15.36it/s]
STARTING EVALUATION
100% 386/386 [00:25<00:00, 15.06it/s]
{'Joint Acc': 0.3636658031088083, 'Turn Acc': 0.6693652849740939, 'Joint F1': 0.5185772884283271}
Epoch:4 L:0.21,LP:0.10,LG:0.11: 100% 29/29 [00:01<00:00, 15.36it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.40it/s]
{'Joint Acc': 0.36075129533678757, 'Turn Acc': 0.6716321243523327, 'Joint F1': 0.531012521588949}
Epoch 5: reducing learning rate of group 0 to 5.0000e-04.
Epoch:5 L:0.15,LP:0.07,LG:0.08: 100% 29/29 [00:01<00:00, 15.16it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.34it/s]
{'Joint Acc': 0.37208549222797926, 'Turn Acc': 0.682210708117445, 'Joint F1': 0.5309693436960298}
Epoch:6 L:0.14,LP:0.07,LG:0.07: 100% 29/29 [00:01<00:00, 15.48it/s]
STARTING EVALUATION
100% 386/386 [00:25<00:00, 15.00it/s]
{'Joint Acc': 0.3873056994818653, 'Turn Acc': 0.6889032815198624, 'Joint F1': 0.5552677029360991}
MODEL SAVED
Epoch:7 L:0.15,LP:0.08,LG:0.07: 100% 29/29 [00:01<00:00, 15.43it/s]
STARTING EVALUATION
100% 386/386 [00:25<00:00, 15.38it/s]
{'Joint Acc': 0.3604274611398964, 'Turn Acc': 0.6707685664939559, 'Joint F1': 0.5084628670120919}
Epoch:8 L:0.09,LP:0.04,LG:0.05: 100% 29/29 [00:01<00:00, 15.52it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.48it/s]
{'Joint Acc': 0.3707901554404145, 'Turn Acc': 0.6788644214162353, 'Joint F1': 0.5320811744386896}
Epoch 9: reducing learning rate of group 0 to 2.5000e-04.
Epoch:9 L:0.09,LP:0.04,LG:0.04: 100% 29/29 [00:01<00:00, 15.57it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.45it/s]
{'Joint Acc': 0.37694300518134716, 'Turn Acc': 0.68199481865285, 'Joint F1': 0.5321783246977567}
Epoch:10 L:0.10,LP:0.04,LG:0.05: 100% 29/29 [00:01<00:00, 15.63it/s]
STARTING EVALUATION
100% 386/386 [00:25<00:00, 14.88it/s]
{'Joint Acc': 0.3753238341968912, 'Turn Acc': 0.6822107081174438, 'Joint F1': 0.5298035405872213}
Epoch 11: reducing learning rate of group 0 to 1.2500e-04.
Epoch:11 L:0.08,LP:0.04,LG:0.04: 100% 29/29 [00:01<00:00, 16.06it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.45it/s]
{'Joint Acc': 0.3711139896373057, 'Turn Acc': 0.6810233160621771, 'Joint F1': 0.5278065630397256}
Epoch:12 L:0.11,LP:0.05,LG:0.06: 100% 29/29 [00:01<00:00, 15.28it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.35it/s]
{'Joint Acc': 0.3756476683937824, 'Turn Acc': 0.6839378238341981, 'Joint F1': 0.5391731433506065}
Epoch 13: reducing learning rate of group 0 to 1.0000e-04.
Ran out of patient, early stop...
[Info] After Fine Tune ...
[Info] Test Set on 4 domains...
STARTING EVALUATION
100% 916/916 [07:43<00:00, 2.10it/s]
{'Joint Acc': 0.3971331058020478, 'Turn Acc': 0.9562432056629776, 'Joint F1': 0.8310067907423058}
[Info] Test Set on 1 domain attraction ...
STARTING EVALUATION
100% 389/389 [00:25<00:00, 15.38it/s]
{'Joint Acc': 0.367524115755627, 'Turn Acc': 0.6797427652733147, 'Joint F1': 0.5233440514469476}
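Each progress-bar total in these logs matches the number of examples divided by the batch size, rounded up: the zero-shot test ran with batch size 32 (read from the save path), while fine-tuning and its evaluation used -bsz=8. The check below takes the example counts from the "Read N pairs" lines and, for the four original domains, the 7,328 reported by the GEM run's "4 domains test set length used EVAL 7328" line. This is my reading of the logs, not something taken from the code.

```python
import math


def num_batches(num_examples: int, batch_size: int) -> int:
    """Mini-batches per pass over the data when the last, partial batch is kept."""
    return math.ceil(num_examples / batch_size)


# (examples, batch size) -> progress-bar total seen in the logs
OBSERVED = {
    (7328, 32): 229,  # zero-shot test, 4 original domains
    (3110, 32): 98,   # zero-shot test, attraction
    (232, 8): 29,     # few-shot training, 1% of attraction
    (3088, 8): 386,   # few-shot dev evaluation, attraction
    (7328, 8): 916,   # test after fine-tuning, 4 original domains
    (3110, 8): 389,   # test after fine-tuning, attraction
}

if __name__ == "__main__":
    for (n, bsz), total in OBSERVED.items():
        print(n, bsz, num_batches(n, bsz), total)
```

The 29 training steps per epoch also explain why a naive fine-tuning epoch finishes in about a second, while evaluating 386 dev batches takes about 25 seconds.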
EWC (fine-tuning method 1: elastic weight consolidation)
❱❱❱ python3 EWC_train.py -bsz=8 -dr=0.2 -lr=0.001 -path=${save_path_except_domain} -exceptd=${except_domain} -fisher_sample=10000 -l_ewc=${lambda}
e.g.: python3 EWC_train.py -bsz=8 -dr=0.2 -lr=0.001 -path='save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398' -exceptd=attraction -fisher_sample=10000
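Note: I did not know what to set -l_ewc=${lambda} to (neither the paper nor the GitHub page gives a value), so I dropped it and relied on the default; the logged arguments below show 'lambda_ewc': 0.01.
EWC training, first run: the script computes the Fisher matrix and saves it. Output: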
{'dataset': 'multiwoz', 'task': 'dst', 'path': 'save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398', 'sample': None, 'patience': 6, 'earlyStop': 'BLEU', 'all_vocab': 1, 'imbalance_sampler': 0, 'data_ratio': 100, 'unk_mask': 1, 'batch': 8, 'run_dev_testing': 0, 'vizualization': 0, 'genSample': 0, 'evalp': 1, 'addName': 'Exceptattraction', 'eval_batch': 0, 'use_gate': 1, 'load_embedding': 0, 'fix_embedding': 0, 'parallel_decode': 0, 'decoder': None, 'hidden': 400, 'learn': 0.001, 'drop': 0.2, 'limit': -10000, 'clip': 10, 'teacher_forcing_ratio': 0.5, 'lambda_ewc': 0.01, 'fisher_sample': 10000, 'all_model': False, 'domain_as_task': False, 'run_except_4d': 1, 'strict_domain': False, 'except_domain': 'attraction', 'only_domain': ''}
True folder_name save/
Reading from data/train_dials.json
domain_counter {'hotel': 3381, 'train': 3103, 'attraction': 2717, 'restaurant': 3813, 'taxi': 1654}
Reading from data/dev_dials.json
domain_counter {'hotel': 416, 'train': 484, 'attraction': 401, 'restaurant': 438, 'taxi': 207}
Reading from data/test_dials.json
domain_counter {'taxi': 195, 'attraction': 395, 'restaurant': 437, 'train': 494, 'hotel': 394}
[Info] Loading saved lang files...
Reading from data/test_dials.json
domain_counter {'taxi': 195, 'attraction': 395, 'restaurant': 437, 'train': 494, 'hotel': 394}
Read 56177 pairs train
Read 7331 pairs dev
Read 3110 pairs test
Vocab_size: 18311
Vocab_size Training 15462
Vocab_size Belief 1005
Max. length of dialog words for RNN: 880
USE_CUDA=True
[Train Set & Dev Set Slots]: Number is 27 in total
['hotel-pricerange', 'hotel-type', 'hotel-parking', 'hotel-book stay', 'hotel-book day', 'hotel-book people', 'hotel-area', 'hotel-stars', 'hotel-internet', 'train-destination', 'train-day', 'train-departure', 'train-arriveby', 'train-book people', 'train-leaveat', 'restaurant-food', 'restaurant-pricerange', 'restaurant-area', 'restaurant-name', 'hotel-name', 'taxi-leaveat', 'taxi-destination', 'taxi-departure', 'restaurant-book time', 'restaurant-book day', 'restaurant-book people', 'taxi-arriveby']
[Test Set Slots]: Number is 3 in total
['attraction-area', 'attraction-name', 'attraction-type']
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py:50: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.2 and num_layers=1
  "num_layers={}".format(dropout, num_layers))
MODEL save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398 LOADED
Computing Fisher Matrix
100% 10000/10000 [28:55<00:00, 6.65it/s]
Saving Fisher Matrix in save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398fisher10000
/usr/local/lib/python3.6/dist-packages/torch/storage.py:34: FutureWarning: pickle support for Storage will be removed in 1.5. Use `torch.save` instead
  warnings.warn("pickle support for Storage will be removed in 1.5. Use `torch.save` instead", FutureWarning)
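-fisher_sample=10000 sets how many samples are used to estimate the Fisher matrix (the progress bar counts 10000/10000), and the logged default is 'lambda_ewc': 0.01. EWC then adds a quadratic penalty that keeps parameters with large Fisher values close to the base model θ*: L_ewc(θ) = L(θ) + Σ_i (λ/2) F_i (θ_i − θ*_i)². The pure-Python sketch below computes the empirical diagonal Fisher as the mean of squared per-sample gradients and applies that penalty to toy parameter vectors. It follows the standard EWC formulation and is not the repo's implementation; whether EWC_train.py uses exactly the 1/2 factor and the empirical Fisher is an assumption.

```python
from typing import List, Sequence


def diagonal_fisher(per_sample_grads: Sequence[Sequence[float]]) -> List[float]:
    """Empirical diagonal Fisher: mean of squared per-sample gradients, per parameter."""
    n = len(per_sample_grads)
    dim = len(per_sample_grads[0])
    return [sum(g[i] ** 2 for g in per_sample_grads) / n for i in range(dim)]


def ewc_penalty(theta: Sequence[float], theta_old: Sequence[float],
                fisher: Sequence[float], lam: float = 0.01) -> float:
    """(lam / 2) * sum_i F_i * (theta_i - theta_old_i) ** 2"""
    return 0.5 * lam * sum(f * (t - t0) ** 2 for f, t, t0 in zip(fisher, theta, theta_old))


def ewc_loss(task_loss: float, theta: Sequence[float], theta_old: Sequence[float],
             fisher: Sequence[float], lam: float = 0.01) -> float:
    """New-domain loss plus the penalty for drifting away from the base model."""
    return task_loss + ewc_penalty(theta, theta_old, fisher, lam)


if __name__ == "__main__":
    # Toy example: 2 samples, 2 parameters.
    fisher = diagonal_fisher([[1.0, 0.0], [3.0, 2.0]])
    print(fisher, ewc_loss(1.0, [1.0, 1.0], [0.0, 0.0], fisher))
```

Parameters whose gradients were large on the old domains (large F_i) are expensive to move, while parameters with F_i near 0 remain free to adapt to the new domain.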
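EWC training, second run: the script loads the saved Fisher matrix, fine-tunes on the new domain, and reports the evaluation metrics. Output: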
/content/drive/My Drive/AI/trade-dst
{'dataset': 'multiwoz', 'task': 'dst', 'path': 'save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398', 'sample': None, 'patience': 6, 'earlyStop': 'BLEU', 'all_vocab': 1, 'imbalance_sampler': 0, 'data_ratio': 100, 'unk_mask': 1, 'batch': 8, 'run_dev_testing': 0, 'vizualization': 0, 'genSample': 0, 'evalp': 1, 'addName': 'Exceptattraction', 'eval_batch': 0, 'use_gate': 1, 'load_embedding': 0, 'fix_embedding': 0, 'parallel_decode': 0, 'decoder': None, 'hidden': 400, 'learn': 0.001, 'drop': 0.2, 'limit': -10000, 'clip': 10, 'teacher_forcing_ratio': 0.5, 'lambda_ewc': 0.01, 'fisher_sample': 10000, 'all_model': False, 'domain_as_task': False, 'run_except_4d': 1, 'strict_domain': False, 'except_domain': 'attraction', 'only_domain': ''}
Load Fisher Matrixsave/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398fisher10000
True folder_name save/
Reading from data/train_dials.json
domain_counter {'hotel': 3381, 'train': 3103, 'attraction': 2717, 'restaurant': 3813, 'taxi': 1654}
Reading from data/dev_dials.json
domain_counter {'hotel': 416, 'train': 484, 'attraction': 401, 'restaurant': 438, 'taxi': 207}
Reading from data/test_dials.json
domain_counter {'taxi': 195, 'attraction': 395, 'restaurant': 437, 'train': 494, 'hotel': 394}
[Info] Loading saved lang files...
Reading from data/test_dials.json
domain_counter {'taxi': 195, 'attraction': 395, 'restaurant': 437, 'train': 494, 'hotel': 394}
Read 56177 pairs train
Read 7331 pairs dev
Read 3110 pairs test
Vocab_size: 18311
Vocab_size Training 15462
Vocab_size Belief 1005
Max. length of dialog words for RNN: 880
USE_CUDA=True
[Train Set & Dev Set Slots]: Number is 27 in total
['hotel-pricerange', 'hotel-type', 'hotel-parking', 'hotel-book stay', 'hotel-book day', 'hotel-book people', 'hotel-area', 'hotel-stars', 'hotel-internet', 'train-destination', 'train-day', 'train-departure', 'train-arriveby', 'train-book people', 'train-leaveat', 'restaurant-food', 'restaurant-pricerange', 'restaurant-area', 'restaurant-name', 'hotel-name', 'taxi-leaveat', 'taxi-destination', 'taxi-departure', 'restaurant-book time', 'restaurant-book day', 'restaurant-book people', 'taxi-arriveby']
[Test Set Slots]: Number is 3 in total
['attraction-area', 'attraction-name', 'attraction-type']
True folder_name save/
Reading from data/train_dials.json
domain_counter {'hotel': 33, 'train': 35, 'restaurant': 36, 'attraction': 29, 'taxi': 11}
Reading from data/dev_dials.json
domain_counter {'hotel': 416, 'train': 484, 'attraction': 401, 'restaurant': 438, 'taxi': 207}
Reading from data/test_dials.json
domain_counter {'taxi': 195, 'attraction': 395, 'restaurant': 437, 'train': 494, 'hotel': 394}
[Info] Loading saved lang files...
Read 232 pairs train
Read 3088 pairs dev
Read 3110 pairs test
Vocab_size: 18311
Vocab_size Training 15462
Vocab_size Belief 983
Max. length of dialog words for RNN: 660
USE_CUDA=True
[Train Set & Dev Set Slots]: Number is 3 in total
['attraction-area', 'attraction-name', 'attraction-type']
[Test Set Slots]: Number is 3 in total
['attraction-area', 'attraction-name', 'attraction-type']
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py:50: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.2 and num_layers=1
  "num_layers={}".format(dropout, num_layers))
MODEL save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398 LOADED
Epoch:0 L:1.87,LP:1.09,LG:0.78: 100% 29/29 [00:02<00:00, 14.12it/s]
STARTING EVALUATION
100% 386/386 [00:25<00:00, 15.34it/s]
{'Joint Acc': 0.2933937823834197, 'Turn Acc': 0.6260794473229723, 'Joint F1': 0.4239421416234893}
MODEL SAVED
Epoch:1 L:0.54,LP:0.32,LG:0.22: 100% 29/29 [00:02<00:00, 14.35it/s]
STARTING EVALUATION
100% 386/386 [00:25<00:00, 14.60it/s]
{'Joint Acc': 0.36560880829015546, 'Turn Acc': 0.6984024179620065, 'Joint F1': 0.5461139896373092}
MODEL SAVED
Epoch:2 L:0.29,LP:0.17,LG:0.12: 100% 29/29 [00:02<00:00, 14.36it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.76it/s]
{'Joint Acc': 0.38277202072538863, 'Turn Acc': 0.7129749568221083, 'Joint F1': 0.5885470639032855}
MODEL SAVED
Epoch:3 L:0.29,LP:0.16,LG:0.13: 100% 29/29 [00:02<00:00, 14.06it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.69it/s]
{'Joint Acc': 0.41418393782383417, 'Turn Acc': 0.7161053540587244, 'Joint F1': 0.5956066493955128}
MODEL SAVED
Epoch:4 L:0.21,LP:0.11,LG:0.10: 100% 29/29 [00:02<00:00, 14.56it/s]
STARTING EVALUATION
100% 386/386 [00:25<00:00, 15.02it/s]
{'Joint Acc': 0.3873056994818653, 'Turn Acc': 0.7184801381692614, 'Joint F1': 0.6121545768566526}
Epoch:5 L:0.18,LP:0.09,LG:0.09: 100% 29/29 [00:02<00:00, 14.45it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.38it/s]
{'Joint Acc': 0.4119170984455959, 'Turn Acc': 0.7315414507772071, 'Joint F1': 0.6302892918825584}
Epoch 6: reducing learning rate of group 0 to 5.0000e-04.
Epoch:6 L:0.16,LP:0.09,LG:0.07: 100% 29/29 [00:01<00:00, 13.95it/s]
STARTING EVALUATION
100% 386/386 [00:25<00:00, 14.86it/s]
{'Joint Acc': 0.4167746113989637, 'Turn Acc': 0.7354274611398983, 'Joint F1': 0.6218588082901584}
MODEL SAVED
Epoch:7 L:0.11,LP:0.05,LG:0.06: 100% 29/29 [00:02<00:00, 14.16it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 16.08it/s]
{'Joint Acc': 0.4226036269430052, 'Turn Acc': 0.7434153713298817, 'Joint F1': 0.6380505181347186}
MODEL SAVED
Epoch:8 L:0.09,LP:0.05,LG:0.04: 100% 29/29 [00:02<00:00, 14.17it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.79it/s]
{'Joint Acc': 0.4226036269430052, 'Turn Acc': 0.7401770293609702, 'Joint F1': 0.638331174438691}
MODEL SAVED
Epoch:9 L:0.06,LP:0.03,LG:0.03: 100% 29/29 [00:02<00:00, 14.29it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.65it/s]
{'Joint Acc': 0.4339378238341969, 'Turn Acc': 0.7495682210708151, 'Joint F1': 0.6442141623488813}
MODEL SAVED
Epoch:10 L:0.10,LP:0.06,LG:0.04: 100% 29/29 [00:02<00:00, 14.42it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.69it/s]
{'Joint Acc': 0.4319948186528497, 'Turn Acc': 0.7461139896373087, 'Joint F1': 0.6425841968911951}
Epoch:11 L:0.08,LP:0.03,LG:0.04: 100% 29/29 [00:01<00:00, 14.61it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.79it/s]
{'Joint Acc': 0.4274611398963731, 'Turn Acc': 0.7475172711571708, 'Joint F1': 0.6424978411053575}
Epoch 12: reducing learning rate of group 0 to 2.5000e-04.
Epoch:12 L:0.06,LP:0.03,LG:0.03: 100% 29/29 [00:01<00:00, 14.56it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.77it/s]
{'Joint Acc': 0.4261658031088083, 'Turn Acc': 0.7471934369602793, 'Joint F1': 0.6415263385146842}
Epoch:13 L:0.09,LP:0.03,LG:0.06: 100% 29/29 [00:01<00:00, 14.50it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.52it/s]
{'Joint Acc': 0.42487046632124353, 'Turn Acc': 0.7452504317789324, 'Joint F1': 0.6384822970639064}
Epoch 14: reducing learning rate of group 0 to 1.2500e-04.
Epoch:14 L:0.07,LP:0.03,LG:0.03: 100% 29/29 [00:01<00:00, 14.99it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 16.09it/s]
{'Joint Acc': 0.4264896373056995, 'Turn Acc': 0.7464378238341997, 'Joint F1': 0.636798359240072}
Epoch:15 L:0.07,LP:0.04,LG:0.03: 100% 29/29 [00:01<00:00, 14.07it/s]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.46it/s]
{'Joint Acc': 0.43102331606217614, 'Turn Acc': 0.7489205526770323, 'Joint F1': 0.6395941278065662}
Epoch 16: reducing learning rate of group 0 to 1.0000e-04.
Ran out of patient, early stop...
[Info] After Fine Tune ...
[Info] Test Set on 4 domains...
STARTING EVALUATION
100% 916/916 [08:14<00:00, 2.01it/s]
{'Joint Acc': 0.31098976109215015, 'Turn Acc': 0.9461913790923792, 'Joint F1': 0.7731457294688385}
[Info] Test Set on 1 domain attraction ...
STARTING EVALUATION
100% 389/389 [00:26<00:00, 14.88it/s]
{'Joint Acc': 0.4202572347266881, 'Turn Acc': 0.7412647374062197, 'Joint F1': 0.6160342979635625}
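The "Epoch N: reducing learning rate of group 0 to ..." lines and the final "Ran out of patient, early stop..." line can be replayed from the per-epoch dev Joint Acc values of the Naive and EWC runs above. The pure-Python sketch below (not the repo's code) implements two assumed rules. First, a PyTorch ReduceLROnPlateau-style scheduler in "max" mode that halves the learning rate, with a floor of 1e-4, after two consecutive epochs without strict improvement (patience 1). These settings are inferred from the logs, not read from the repo, and PyTorch's small default relative threshold is ignored. Second, early stopping with the logged 'patience': 6, where an equal score also counts as an improvement (EWC epoch 8 prints MODEL SAVED with the same Joint Acc as epoch 7). Under these rules, the sketch reproduces the epochs at which both runs reduce the learning rate and stop. The GEM run below stops after only three non-improving epochs, so GEM_train.py presumably uses a different stopping rule, which this sketch does not model.

```python
from typing import List, Tuple


def replay_schedule(dev_joint_acc: List[float], lr: float = 1e-3, factor: float = 0.5,
                    lr_patience: int = 1, min_lr: float = 1e-4,
                    stop_patience: int = 6) -> Tuple[List[Tuple[int, float]], int, int]:
    """Return (lr_drops, stop_epoch, best_epoch) under the assumed rules.

    lr_drops holds (printed_epoch, new_lr); printed_epoch = epoch index + 1, matching
    the "Epoch N: reducing learning rate" lines, which count scheduler steps.
    """
    sched_best, bad = float("-inf"), 0  # scheduler: needs a strict improvement
    stop_best, cnt = 0.0, 0             # early stopping: ">=" resets the counter
    best_epoch, drops = -1, []
    for epoch, acc in enumerate(dev_joint_acc):
        # 1) plateau scheduler step (mode="max")
        if acc > sched_best:
            sched_best, bad = acc, 0
        else:
            bad += 1
        if bad > lr_patience:
            new_lr = max(lr * factor, min_lr)
            if lr - new_lr > 1e-8:  # only report an actual change
                lr = new_lr
                drops.append((epoch + 1, lr))
            bad = 0
        # 2) checkpointing / early stopping
        if acc >= stop_best:
            stop_best, cnt, best_epoch = acc, 0, epoch  # "MODEL SAVED"
        else:
            cnt += 1
            if cnt == stop_patience:
                return drops, epoch, best_epoch  # "Ran out of patient, early stop..."
    return drops, len(dev_joint_acc) - 1, best_epoch


# Dev-set Joint Acc per epoch, copied from the Naive and EWC training logs above.
NAIVE = [0.3248056994818653, 0.37176165803108807, 0.37370466321243523, 0.3636658031088083,
         0.36075129533678757, 0.37208549222797926, 0.3873056994818653, 0.3604274611398964,
         0.3707901554404145, 0.37694300518134716, 0.3753238341968912, 0.3711139896373057,
         0.3756476683937824]
EWC = [0.2933937823834197, 0.36560880829015546, 0.38277202072538863, 0.41418393782383417,
       0.3873056994818653, 0.4119170984455959, 0.4167746113989637, 0.4226036269430052,
       0.4226036269430052, 0.4339378238341969, 0.4319948186528497, 0.4274611398963731,
       0.4261658031088083, 0.42487046632124353, 0.4264896373056995, 0.43102331606217614]

if __name__ == "__main__":
    for name, accs in (("Naive", NAIVE), ("EWC", EWC)):
        drops, stop, best = replay_schedule(accs)
        print(name, [e for e, _ in drops], "stop after epoch", stop, "best epoch", best)
```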
GEM (fine-tuning method 2: gradient episodic memory)
❱❱❱ python3 GEM_train.py -bsz=8 -dr=0.2 -lr=0.001 -path=${save_path_except_domain} -exceptd=${except_domain}
e.g.: python3 GEM_train.py -bsz=8 -dr=0.2 -lr=0.001 -path='save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398' -exceptd=attraction
GEM training output:
{'dataset': 'multiwoz', 'task': 'dst', 'path': 'save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398', 'sample': None, 'patience': 6, 'earlyStop': 'BLEU', 'all_vocab': 1, 'imbalance_sampler': 0, 'data_ratio': 100, 'unk_mask': 1, 'batch': 8, 'run_dev_testing': 0, 'vizualization': 0, 'genSample': 0, 'evalp': 1, 'addName': 'Exceptattraction', 'eval_batch': 0, 'use_gate': 1, 'load_embedding': 0, 'fix_embedding': 0, 'parallel_decode': 0, 'decoder': None, 'hidden': 400, 'learn': 0.001, 'drop': 0.2, 'limit': -10000, 'clip': 10, 'teacher_forcing_ratio': 0.5, 'lambda_ewc': 0.01, 'fisher_sample': 0, 'all_model': False, 'domain_as_task': False, 'run_except_4d': 1, 'strict_domain': False, 'except_domain': 'attraction', 'only_domain': ''}
True folder_name save/
Reading from data/train_dials.json
domain_counter {'hotel': 3381, 'train': 3103, 'attraction': 2717, 'restaurant': 3813, 'taxi': 1654}
Reading from data/dev_dials.json
domain_counter {'hotel': 416, 'train': 484, 'attraction': 401, 'restaurant': 438, 'taxi': 207}
Reading from data/test_dials.json
domain_counter {'taxi': 195, 'attraction': 395, 'restaurant': 437, 'train': 494, 'hotel': 394}
[Info] Loading saved lang files...
Reading from data/test_dials.json
domain_counter {'taxi': 195, 'attraction': 395, 'restaurant': 437, 'train': 494, 'hotel': 394}
Read 56177 pairs train
Read 7331 pairs dev
Read 3110 pairs test
Vocab_size: 18311
Vocab_size Training 15462
Vocab_size Belief 1005
Max. length of dialog words for RNN: 880
USE_CUDA=True
[Train Set & Dev Set Slots]: Number is 27 in total
['hotel-pricerange', 'hotel-type', 'hotel-parking', 'hotel-book stay', 'hotel-book day', 'hotel-book people', 'hotel-area', 'hotel-stars', 'hotel-internet', 'train-destination', 'train-day', 'train-departure', 'train-arriveby', 'train-book people', 'train-leaveat', 'restaurant-food', 'restaurant-pricerange', 'restaurant-area', 'restaurant-name', 'hotel-name', 'taxi-leaveat', 'taxi-destination', 'taxi-departure', 'restaurant-book time', 'restaurant-book day', 'restaurant-book people', 'taxi-arriveby']
[Test Set Slots]: Number is 3 in total
['attraction-area', 'attraction-name', 'attraction-type']
True folder_name save/
Reading from data/train_dials.json
domain_counter {'hotel': 33, 'train': 35, 'restaurant': 36, 'attraction': 29, 'taxi': 11}
Reading from data/dev_dials.json
domain_counter {'hotel': 416, 'train': 484, 'attraction': 401, 'restaurant': 438, 'taxi': 207}
Reading from data/test_dials.json
domain_counter {'taxi': 195, 'attraction': 395, 'restaurant': 437, 'train': 494, 'hotel': 394}
[Info] Loading saved lang files...
Reading from data/test_dials.json
domain_counter {'taxi': 195, 'attraction': 395, 'restaurant': 437, 'train': 494, 'hotel': 394}
Read 590 pairs train
Read 7331 pairs dev
Read 3110 pairs test
Vocab_size: 18311
Vocab_size Training 15462
Vocab_size Belief 1005
Max. length of dialog words for RNN: 660
USE_CUDA=True
[Train Set & Dev Set Slots]: Number is 27 in total
['hotel-pricerange', 'hotel-type', 'hotel-parking', 'hotel-book stay', 'hotel-book day', 'hotel-book people', 'hotel-area', 'hotel-stars', 'hotel-internet', 'train-destination', 'train-day', 'train-departure', 'train-arriveby', 'train-book people', 'train-leaveat', 'restaurant-food', 'restaurant-pricerange', 'restaurant-area', 'restaurant-name', 'hotel-name', 'taxi-leaveat', 'taxi-destination', 'taxi-departure', 'restaurant-book time', 'restaurant-book day', 'restaurant-book people', 'taxi-arriveby']
[Test Set Slots]: Number is 3 in total
['attraction-area', 'attraction-name', 'attraction-type']
True folder_name save/
Reading from data/train_dials.json
domain_counter {'hotel': 33, 'train': 35, 'restaurant': 36, 'attraction': 29, 'taxi': 11}
Reading from data/dev_dials.json
domain_counter {'hotel': 416, 'train': 484, 'attraction': 401, 'restaurant': 438, 'taxi': 207}
Reading from data/test_dials.json
domain_counter {'taxi': 195, 'attraction': 395, 'restaurant': 437, 'train': 494, 'hotel': 394}
[Info] Loading saved lang files...
Read 232 pairs train
Read 3088 pairs dev
Read 3110 pairs test
Vocab_size: 18311
Vocab_size Training 15462
Vocab_size Belief 983
Max. length of dialog words for RNN: 660
USE_CUDA=True
[Train Set & Dev Set Slots]: Number is 3 in total
['attraction-area', 'attraction-name', 'attraction-type']
[Test Set Slots]: Number is 3 in total
['attraction-area', 'attraction-name', 'attraction-type']
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py:50: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.2 and num_layers=1
  "num_layers={}".format(dropout, num_layers))
MODEL save/TRADE-Exceptattractionmultiwozdst/HDD400BSZ32DR0.2ACC-0.5398 LOADED
4 domains test set length used EVAL 7328
4 domains train set length used for GEM 640
1 domains train set length 232
Epoch:0 L:0.89,LP:0.32,LG:0.57: 100% 29/29 [01:16<00:00, 2.68s/it]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.36it/s]
{'Joint Acc': 0.29436528497409326, 'Turn Acc': 0.6319084628670144, 'Joint F1': 0.4437284110535409}
MODEL SAVED
Epoch:1 L:0.52,LP:0.36,LG:0.17: 100% 29/29 [01:20<00:00, 2.79s/it]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.44it/s]
{'Joint Acc': 0.36690414507772023, 'Turn Acc': 0.6906303972366185, 'Joint F1': 0.5358592400690866}
MODEL SAVED
Epoch:2 L:0.20,LP:0.11,LG:0.09: 100% 29/29 [01:18<00:00, 2.62s/it]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.42it/s]
{'Joint Acc': 0.38309585492227977, 'Turn Acc': 0.6974309153713338, 'Joint F1': 0.5791342832469792}
MODEL SAVED
Epoch:3 L:0.13,LP:0.08,LG:0.04: 100% 29/29 [01:19<00:00, 2.67s/it]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.18it/s]
{'Joint Acc': 0.43879533678756477, 'Turn Acc': 0.7376943005181376, 'Joint F1': 0.6300194300518157}
MODEL SAVED
Epoch:4 L:0.26,LP:0.04,LG:0.22: 100% 29/29 [01:17<00:00, 2.70s/it]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.70it/s]
{'Joint Acc': 0.4417098445595855, 'Turn Acc': 0.747301381692576, 'Joint F1': 0.637435233160625}
MODEL SAVED
Epoch:5 L:0.12,LP:0.09,LG:0.03: 100% 29/29 [01:21<00:00, 2.87s/it]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.49it/s]
{'Joint Acc': 0.42843264248704666, 'Turn Acc': 0.7400690846286737, 'Joint F1': 0.6442573402417988}
Epoch:6 L:0.29,LP:0.06,LG:0.23: 100% 29/29 [01:17<00:00, 2.64s/it]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.64it/s]
{'Joint Acc': 0.42357512953367876, 'Turn Acc': 0.7339162348877418, 'Joint F1': 0.6311420552677062}
Epoch 7: reducing learning rate of group 0 to 5.0000e-04.
Epoch:7 L:0.03,LP:0.01,LG:0.01: 100% 29/29 [01:17<00:00, 2.56s/it]
STARTING EVALUATION
100% 386/386 [00:24<00:00, 15.45it/s]
{'Joint Acc': 0.4177461139896373, 'Turn Acc': 0.732512953367878, 'Joint F1': 0.6171740069084664}
Ran out of patient, early stop...
[Info] After Fine Tune ...
[Info] Test Set on 4 domains...
STARTING EVALUATION
100% 916/916 [07:39<00:00, 2.09it/s]
{'Joint Acc': 0.46484641638225255, 'Turn Acc': 0.9617039565162195, 'Joint F1': 0.8624498952934783}
[Info] Test Set on 1 domain attraction ...
STARTING EVALUATION
100% 389/389 [00:25<00:00, 15.56it/s]
{'Joint Acc': 0.4112540192926045, 'Turn Acc': 0.7296891747052539, 'Joint F1': 0.6005466237942153}
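GEM keeps an episodic memory of old-domain examples (the log reports "4 domains train set length used for GEM 640") and constrains each update so that the loss on that memory does not increase, i.e. the new gradient g must satisfy ⟨g, g_mem⟩ ≥ 0. With a single memory gradient, the constrained problem has a closed form: if the constraint is violated, g is projected onto the boundary of that half-space. The sketch shows this single-constraint case in pure Python. With several constraints (for example one per old domain), GEM solves a small quadratic program instead; which variant GEM_train.py uses is an assumption I have not checked. The extra gradient computation on the memory is presumably why each GEM epoch here takes about 1 min 20 s, versus 1 to 2 s for Naive and EWC.

```python
from typing import List, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def gem_project(g: Sequence[float], g_mem: Sequence[float]) -> List[float]:
    """Single-constraint GEM: closest vector to g with <g, g_mem> >= 0."""
    d = dot(g, g_mem)
    if d >= 0:
        return list(g)  # no interference with the old domains: keep g unchanged
    # Remove the component of g that points against the memory gradient.
    scale = d / dot(g_mem, g_mem)
    return [gi - scale * mi for gi, mi in zip(g, g_mem)]


if __name__ == "__main__":
    g_new = [2.0, -3.0]  # toy gradient from the new-domain batch
    g_old = [1.0, 1.0]   # toy gradient on the episodic memory
    projected = gem_project(g_new, g_old)
    print(projected, dot(projected, g_old))
```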
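(2) Results
Test-set joint goal accuracy from the logs above (single runs, with attraction as the new domain):
Base model before fine-tuning (zero-shot): 4 original domains 53.92%, attraction 19.58%
Naive: 4 original domains 39.71%, attraction 36.75%
EWC: 4 original domains 31.10%, attraction 42.03%
GEM: 4 original domains 46.48%, attraction 41.13%
Compared with naive fine-tuning, the constrained fine-tuning methods (EWC, GEM) generally did better on the new domain. On the four original domains, GEM forgot the least, while EWC ended up below naive fine-tuning in this run.
II. Terminology
1. Fisher information matrix
References: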
https://www.cnblogs.com/DDJ-XLG/p/4943236.html
https://blog.csdn.net/GarfieldEr007/article/details/86773096
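Fisher information measures how sharply the log-likelihood reacts to a parameter: I(θ) = E[(∂/∂θ log p(x | θ))²], the expected squared score. For a Bernoulli(p) variable it has the closed form 1/(p(1 − p)). The sketch computes the expectation exactly by summing over both outcomes and can be checked against that formula; for a neural network, EWC uses the diagonal of the analogous matrix over all parameters, estimated from sampled gradients.

```python
def bernoulli_score(x: int, p: float) -> float:
    """d/dp log P(x | p) for P(x | p) = p**x * (1 - p)**(1 - x), x in {0, 1}."""
    return x / p - (1 - x) / (1 - p)


def fisher_information_bernoulli(p: float) -> float:
    """I(p) = E_x[score(x, p)**2], with the expectation taken exactly over x in {0, 1}."""
    if not 0.0 < p < 1.0:
        raise ValueError("p must be in (0, 1)")
    return p * bernoulli_score(1, p) ** 2 + (1 - p) * bernoulli_score(0, p) ** 2


if __name__ == "__main__":
    for p in (0.1, 0.5, 0.9):
        print(p, fisher_information_bernoulli(p), 1 / (p * (1 - p)))
```

The information is smallest at p = 0.5 and grows toward the extremes, where a small change in p changes the likelihood of the rare outcome a lot.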
2.
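Author: 西伯尔
Source: http://www.cnblogs.com/sybil-hxl/
The copyright of this article is shared by the author and cnblogs (博客园). Reposting is welcome, but unless the author agrees otherwise, this statement must be retained and a link to the original article must be given in a prominent position on the page; the author reserves the right to pursue legal liability otherwise.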