OpenKiwi Study Notes

Printing the call stack in Python:

import traceback
traceback.print_stack()
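
A minimal sketch of how this helps when reading an unfamiliar code base: calling `traceback.print_stack()` inside a function prints the chain of callers that led there (to stderr by default). The function names `outer` and `inner` are made up for illustration; `traceback.format_stack()` returns the same frames as a list of strings.

```python
import traceback


def inner():
    # Print the current call stack (goes to stderr by default).
    traceback.print_stack()
    # format_stack() returns the same frames as a list of strings.
    return traceback.format_stack()


def outer():
    return inner()


frames = outer()
# Joined text of all frames; it names every function that was active.
stack_text = "".join(frames)
```

Placing the call inside a method you are curious about shows which code path reached it.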

For CLI usage, the general command is:

kiwi (train|pretrain|predict|evaluate|search) CONFIG_FILE

Importing the BERT model:

from transformers import (
    BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
    DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
    AutoTokenizer,
    BertConfig,
    BertModel,
)

Components of a QE system:

Furthermore, all of our QE systems share a similar architecture. They are composed of

- Encoder
- Decoder
- Output
- (Optionally) TLM Output

Encoder: Embeds the input and creates features to be used for downstream tasks, e.g. Predictor,
BERT, etc.

Decoder: Responsible for learning feature transformations better suited for the
downstream task, e.g. MLP, LSTM, etc.

Output: Simple feedforward layers that take decoder features and transform them into the
prediction required by the downstream task. Something along the same lines as the common
"classification heads" being used with transformers.

TLM Output: A simple output layer that trains for the specific TLM objective. It can be
useful to continue finetuning the predictor during training of the complete QE system.

QE system inheritance:

All QE systems inherit from :class:`kiwi.systems.qe_system.QESystem`.

Use ``kiwi train`` to train these systems.

Currently available are:

+--------------------------------------------------------------+
| :class:`kiwi.systems.nuqe.NuQE`                              |
+--------------------------------------------------------------+
| :class:`kiwi.systems.predictor_estimator.PredictorEstimator` |
+--------------------------------------------------------------+
| :class:`kiwi.systems.bert.Bert`                              |
+--------------------------------------------------------------+
| :class:`kiwi.systems.xlm.XLM`                                |
+--------------------------------------------------------------+
| :class:`kiwi.systems.xlmroberta.XLMRoberta`                  |
+--------------------------------------------------------------+

TLM inheritance:

TLM --- :mod:`kiwi.systems.tlm_system`
--------------------------------------

All TLM systems inherit from :class:`kiwi.systems.tlm_system.TLMSystem`.

Use ``kiwi pretrain`` to train these systems. These systems can then be used as the
encoder part in QE systems by using the `load_encoder` flag.

Currently available are:

+-------------------------------------------+
| :class:`kiwi.systems.predictor.Predictor` |
+-------------------------------------------+

Configuration classes:

kiwi.lib.train.Configuration
kiwi.lib.train.RunConfig
kiwi.lib.train.TrainerConfig
kiwi.data.datasets.wmt_qe_dataset.WMTQEDataset.Config
kiwi.systems.qe_system.QESystem.Config

With BERT as the encoder, the overall model structure is:

Bert(
  (encoder): BertEncoder(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(119547, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (intermediate): BertIntermediate(
              (dense): Linear(in_features=768, out_features=3072, bias=True)
            )
            (output): BertOutput(
              (dense): Linear(in_features=3072, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
     ................
          (11): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (intermediate): BertIntermediate(
              (dense): Linear(in_features=768, out_features=3072, bias=True)
            )
            (output): BertOutput(
              (dense): Linear(in_features=3072, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
        )
      )
      (pooler): BertPooler(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (activation): Tanh()
      )
    )
    (scalar_mix): ScalarMixWithDropout(
      (scalar_parameters): ParameterList(
          (0): Parameter containing: [torch.FloatTensor of size 1]
          (1): Parameter containing: [torch.FloatTensor of size 1]
          (2): Parameter containing: [torch.FloatTensor of size 1]
          (3): Parameter containing: [torch.FloatTensor of size 1]
          (4): Parameter containing: [torch.FloatTensor of size 1]
          (5): Parameter containing: [torch.FloatTensor of size 1]
          (6): Parameter containing: [torch.FloatTensor of size 1]
          (7): Parameter containing: [torch.FloatTensor of size 1]
          (8): Parameter containing: [torch.FloatTensor of size 1]
          (9): Parameter containing: [torch.FloatTensor of size 1]
          (10): Parameter containing: [torch.FloatTensor of size 1]
          (11): Parameter containing: [torch.FloatTensor of size 1]
          (12): Parameter containing: [torch.FloatTensor of size 1]
      )
    )
    (output_embeddings): Embedding(119547, 768, padding_idx=0)
  )
  (decoder): LinearDecoder(
    (linear_outs): ModuleDict(
      (target): Sequential(
        (0): Linear(in_features=768, out_features=768, bias=True)
        (1): Tanh()
      )
      (source): Sequential(
        (0): Linear(in_features=768, out_features=768, bias=True)
        (1): Tanh()
      )
      (target_sentence): Sequential(
        (0): Linear(in_features=768, out_features=768, bias=True)
        (1): Tanh()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=768, out_features=768, bias=True)
        (4): Tanh()
        (5): Dropout(p=0.1, inplace=False)
      )
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (outputs): QEOutputs(
    (word_outputs): ModuleDict(
      (target_tags): WordLevelOutput(
        (linear): Linear(in_features=768, out_features=2, bias=True)
        (loss_fn): CrossEntropyLoss()
      )
      (gap_tags): GapTagsOutput(
        (linear): Linear(in_features=1536, out_features=2, bias=True)
        (loss_fn): CrossEntropyLoss()
      )
      (source_tags): WordLevelOutput(
        (linear): Linear(in_features=768, out_features=2, bias=True)
        (loss_fn): CrossEntropyLoss()
      )
    )
    (sentence_outputs): ModuleDict(
      (sentence_scores): SentenceScoreRegression(
        (sentence_pred): Sequential(
          (linear_0): Linear(in_features=768, out_features=384, bias=True)
          (activation_0): Tanh()
          (dropout_0): Dropout(p=0.0, inplace=False)
          (linear_1): Linear(in_features=384, out_features=1, bias=True)
        )
        (loss_fn): MSELoss()
      )
    )
  )
  (tlm_outputs): TLMOutputs(
    (masked_word_outputs): ModuleDict()
  )
)
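
The layer shapes in this printout are enough to work out parameter counts by hand. The sketch below is plain arithmetic on the numbers shown above, using the standard counts for fully connected, embedding and layer-norm layers; the helper names are made up for illustration. Note that the `gap_tags` head takes 1536 = 2 x 768 input features; my reading (not stated in the printout) is that each gap is scored from the features of two adjacent target tokens.

```python
def linear_params(in_features, out_features, bias=True):
    """Weights plus optional bias of a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


def embedding_params(num_embeddings, embedding_dim):
    """One vector of size embedding_dim per vocabulary entry."""
    return num_embeddings * embedding_dim


def layer_norm_params(size):
    """One scale and one shift per feature (elementwise_affine=True)."""
    return 2 * size


hidden, intermediate = 768, 3072

word_embeddings = embedding_params(119547, hidden)

bert_layer = (
    4 * linear_params(hidden, hidden)      # query, key, value, attention output dense
    + layer_norm_params(hidden)            # attention output LayerNorm
    + linear_params(hidden, intermediate)  # intermediate dense
    + linear_params(intermediate, hidden)  # output dense
    + layer_norm_params(hidden)            # output LayerNorm
)

gap_tags_head = linear_params(2 * hidden, 2)

print(word_embeddings)  # 91812096
print(bert_layer)       # 7087872
print(gap_tags_head)    # 3074
```

The multilingual vocabulary embedding alone is far larger than any single transformer layer.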

cli.py:print(arguments):

{'--example': False,
 '--help': False,
 '--quiet': False,
 '--verbose': False,
 '--version': False,
 'CONFIG_FILE': 'config/bert.yaml',
 'OVERWRITES': [],
 'evaluate': False,
 'predict': False,
 'pretrain': False,
 'search': False,
 'train': True}
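
Exactly one of the five sub-command flags in this dictionary is `True`. A minimal sketch of how such a dictionary could be turned into the name of the selected pipeline is shown below; `selected_pipeline` is a hypothetical helper written for these notes, not the code in `cli.py`.

```python
PIPELINES = ("train", "pretrain", "predict", "evaluate", "search")


def selected_pipeline(arguments):
    """Return the single sub-command whose flag is set in the parsed arguments."""
    chosen = [name for name in PIPELINES if arguments.get(name)]
    if len(chosen) != 1:
        raise ValueError(f"expected exactly one pipeline, got {chosen}")
    return chosen[0]


# Same values as the dictionary printed above.
arguments = {
    "--example": False, "--help": False, "--quiet": False,
    "--verbose": False, "--version": False,
    "CONFIG_FILE": "config/bert.yaml", "OVERWRITES": [],
    "evaluate": False, "predict": False, "pretrain": False,
    "search": False, "train": True,
}

print(selected_pipeline(arguments), arguments["CONFIG_FILE"])  # train config/bert.yaml
```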

cli.py:config_dict:

{
	'run': {
		'experiment_name': 'BERT WMT20 EN-ZH',
		'seed': 42,
		'use_mlflow': False
	},
	'trainer': {
		'deterministic': True,
		'gpus': 1,
		'epochs': 10,
		'main_metric': ['WMT19_MCC', 'PEARSON'],
		'gradient_max_norm': 1.0,
		'gradient_accumulation_steps': 1,
		'amp_level': 'O2',
		'precision': 16,
		'log_interval': 100,
		'checkpoint': {
			'validation_steps': 0.2,
			'early_stop_patience': 10
		}
	},
	'system': {
		'class_name': 'Bert',
		'batch_size': 2,
		'num_data_workers': 1,
		'model': {
			'encoder': {
				'model_name': 'bert-base-multilingual-cased',
				'use_mlp': False,
				'freeze': False
			},
			'decoder': {
				'hidden_size': 768,
				'bottleneck_size': 768,
				'dropout': 0.1
			},
			'outputs': {
				'word_level': {
					'target': True,
					'gaps': True,
					'source': True,
					'class_weights': {
						'target_tags': {
							'BAD': 3.0
						},
						'gap_tags': {
							'BAD': 5.0
						},
						'source_tags': {
							'BAD': 3.0
						}
					}
				},
				'sentence_level': {
					'hter': True,
					'use_distribution': False,
					'binary': False
				},
				'n_layers_output': 2,
				'sentence_loss_weight': 1
			},
			'tlm_outputs': {
				'fine_tune': False
			}
		},
		'optimizer': {
			'class_name': 'adamw',
			'learning_rate': 1e-05,
			'warmup_steps': 0.1,
			'training_steps': 12000
		},
		'data_processing': {
			'share_input_fields_encoders': True
		}
	},
	'data': {
		'train': {
			'input': {
				'source': 'data/WMT20/en-zh/train/train.src',
				'target': 'data/WMT20/en-zh/train/train.mt',
				'alignments': 'data/WMT20/en-zh/train/train.src-mt.alignments',
				'post_edit': 'data/WMT20/en-zh/train/train.pe'
			},
			'output': {
				'source_tags': 'data/WMT20/en-zh/train/train.source_tags',
				'target_tags': 'data/WMT20/en-zh/train/train.tags',
				'sentence_scores': 'data/WMT20/en-zh/train/train.hter'
			}
		},
		'valid': {
			'input': {
				'source': 'data/WMT20/en-zh/dev/dev.src',
				'target': 'data/WMT20/en-zh/dev/dev.mt',
				'alignments': 'data/WMT20/en-zh/dev/dev.src-mt.alignments',
				'post_edit': 'data/WMT20/en-zh/dev/dev.pe'
			},
			'output': {
				'source_tags': 'data/WMT20/en-zh/dev/dev.source_tags',
				'target_tags': 'data/WMT20/en-zh/dev/dev.tags',
				'sentence_scores': 'data/WMT20/en-zh/dev/dev.hter'
			}
		},
		'test': {
			'input': {
				'source': 'data/WMT20/en-zh/test-blind/test.src',
				'target': 'data/WMT20/en-zh/test-blind/test.mt',
				'alignments': 'data/WMT20/en-zh/test-blind/test.src-mt.alignments'
			}
		}
	},
	'verbose': False,
	'quiet': False
}
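
Because the configuration is a deeply nested dictionary, it is convenient to address a value by a dotted path while exploring it. The helper below is a small sketch for that purpose; `get_by_path` is a name invented for these notes (not an OpenKiwi function), and `config_dict` here is only an excerpt of the dictionary above.

```python
from functools import reduce


def get_by_path(config, dotted_path):
    """Walk a nested dict following keys separated by dots."""
    return reduce(lambda node, key: node[key], dotted_path.split("."), config)


# Excerpt of the config_dict printed above.
config_dict = {
    "trainer": {"epochs": 10, "checkpoint": {"validation_steps": 0.2}},
    "system": {
        "class_name": "Bert",
        "batch_size": 2,
        "model": {
            "encoder": {"model_name": "bert-base-multilingual-cased"},
            "outputs": {
                "word_level": {
                    "class_weights": {"gap_tags": {"BAD": 5.0}},
                },
            },
        },
    },
}

print(get_by_path(config_dict, "system.model.encoder.model_name"))
print(get_by_path(config_dict, "system.model.outputs.word_level.class_weights.gap_tags.BAD"))
```

A missing key raises `KeyError`, which is usually what you want when checking that a config file has the expected shape.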

QESystem:

    def forward(self, batch_inputs):
        encoder_features = self.encoder(batch_inputs)
        features = self.decoder(encoder_features, batch_inputs)
        outputs = self.outputs(features, batch_inputs)

        # For fine-tuning the encoder
        if self.tlm_outputs:
            outputs.update(self.tlm_outputs(encoder_features, batch_inputs))

        return outputs
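
The data flow of `forward` can be tried out without any model weights. The sketch below keeps the same control flow but replaces every component with a trivial stand-in function; `ToyQESystem` and all the stub functions are hypothetical and only illustrate how the optional TLM outputs are merged into the result.

```python
class ToyQESystem:
    """Same control flow as the forward method above, with pluggable stubs."""

    def __init__(self, encoder, decoder, outputs, tlm_outputs=None):
        self.encoder = encoder
        self.decoder = decoder
        self.outputs = outputs
        self.tlm_outputs = tlm_outputs

    def forward(self, batch_inputs):
        encoder_features = self.encoder(batch_inputs)
        features = self.decoder(encoder_features, batch_inputs)
        outputs = self.outputs(features, batch_inputs)
        # Only runs when a TLM head is configured.
        if self.tlm_outputs:
            outputs.update(self.tlm_outputs(encoder_features, batch_inputs))
        return outputs


# Stand-in components: each one just transforms a list of numbers.
def toy_encoder(batch):
    return [float(x) for x in batch["target"]]


def toy_decoder(encoder_features, batch):
    return [2 * x for x in encoder_features]


def toy_outputs(features, batch):
    return {"target_tags": [x > 4 for x in features]}


def toy_tlm_outputs(encoder_features, batch):
    return {"tlm": sum(encoder_features)}


batch = {"target": [1, 2, 3]}
without_tlm = ToyQESystem(toy_encoder, toy_decoder, toy_outputs).forward(batch)
with_tlm = ToyQESystem(toy_encoder, toy_decoder, toy_outputs, toy_tlm_outputs).forward(batch)

print(without_tlm)  # {'target_tags': [False, False, True]}
print(with_tlm)     # {'target_tags': [False, False, True], 'tlm': 6.0}
```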

pytorch_lightning execution flow:

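The PL flow is simple: like a production line, it runs in a fixed order.

This part of the code runs only once.

1. `__init__()` (initialize the LightningModule)
2. `prepare_data()` (prepare the data, including downloading, preprocessing, and so on)
3. `configure_optimizers()` (configure the optimizer)

Next, the validation code is tested. Doing this up front means you do not have to sit through a long training run before discovering that the validation code has a bug.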

1. `val_dataloader()`
2. `validation_step()`
3. `validation_epoch_end()` 
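
The ordering above can be imitated with a toy driver that records which hook is called when. This is a plain-Python sketch of the order described in these notes, not PyTorch Lightning itself; `ToyModule`, `toy_fit` and `num_sanity_batches` are hypothetical names (in the real library the number of sanity-check batches is set by the Trainer's `num_sanity_val_steps` argument).

```python
class ToyModule:
    """Records the order in which its hooks are called."""

    def __init__(self):
        self.calls = ["__init__"]

    def prepare_data(self):
        self.calls.append("prepare_data")

    def configure_optimizers(self):
        self.calls.append("configure_optimizers")

    def val_dataloader(self):
        self.calls.append("val_dataloader")
        return [[1, 2], [3, 4], [5, 6]]  # three tiny "batches"

    def validation_step(self, batch, batch_idx):
        self.calls.append("validation_step")
        return sum(batch)

    def validation_epoch_end(self, outputs):
        self.calls.append("validation_epoch_end")


def toy_fit(module, num_sanity_batches=2):
    # One-time setup.
    module.prepare_data()
    module.configure_optimizers()
    # Sanity check: run a few validation batches before any training.
    loader = module.val_dataloader()
    outputs = [
        module.validation_step(batch, idx)
        for idx, batch in enumerate(loader[:num_sanity_batches])
    ]
    module.validation_epoch_end(outputs)
    # The training loop would start here.
    return module.calls


calls = toy_fit(ToyModule())
print(calls)
```

Because the validation hooks run first on only a couple of batches, a bug in `validation_step` surfaces within seconds.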

Batch data:

{
	'source': BatchedSentence(
	    tensor = tensor([
		    [
		        15846, 10491, 82978, 10226, 75312, 10571, 10105, 106095, 11942,
			    38587, 10108, 10955, 11586, 119, 102, 0, 0, 0,
			    0, 0, 0, 0, 0, 0, 0, 0, 0,
			    0, 0, 0, 0, 0, 0, 0, 0, 0,
			    0, 0
			],
		    [
		        10135, 10386, 11288, 10207, 117, 17668, 11945, 39091, 10226,
			    10751, 14310, 107, 17316, 10230, 10108, 11589, 107, 117,
			    17846, 15736, 10188, 13209, 14951, 13240, 117, 11406, 14320,
			    10270, 10108, 10226, 84977, 11305, 68999, 10731, 11202, 31419,
			    119, 102
		    ]
		], device = 'cuda:0'), 
		lengths = tensor([15, 38], device = 'cuda:0'), 
		bounds = tensor([
		    [0, 1, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
		    [0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 26, 27, 28, 29, 30, 34, 36, 37]
	    ], device = 'cuda:0'), 
	    bounds_lengths = tensor([13, 31], device = 'cuda:0'), 
	    strict_masks = tensor([
			[
			    True, True, True, True, True, True, True, True, True, True,
			    True, True, True, True, False, False, False, False, False, False,
				False, False, False, False, False, False, False, False, False, False,
				False, False, False, False, False, False, False, False
			],
			[   True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, False
			]
		], device = 'cuda:0'), 
		number_of_tokens = tensor([12, 30], device = 'cuda:0', dtype = torch.int32)
	),

	'target': BatchedSentence(
	    tensor = tensor([
			[   101, 4877, 113183, 113227, 118188, 3031, 2128, 114696, 217,
				2674, 115512, 118188, 5718, 7724, 111978, 6348, 114696, 2079,
				5740, 115551, 2196, 5718, 2773, 111915, 2206, 119, 102,
				0, 0, 0, 0, 0, 0, 0, 0, 0,
				0, 0, 0, 0, 0, 0, 0, 0, 0,
				0, 0, 0, 0, 0, 0, 0
			],
			[   101, 10207, 3642, 10186, 4460, 10386, 4348, 10064, 4536,
				4580, 114286, 7735, 117459, 2196, 5718, 84977, 11305, 68999,
				10731, 4237, 113162, 6063, 10270, 8272, 10064, 4163, 112293,
				2146, 2196, 5718, 4333, 5718, 107, 3976, 5718, 4614,
				113664, 107, 10064, 2460, 111870, 4461, 3228, 118573, 113826,
				217, 4608, 5718, 4784, 112939, 1882, 102
			]
		], device = 'cuda:0'), 
		lengths = tensor([27, 52], device = 'cuda:0'), 
		bounds = tensor([
		    [0, 1, 5, 6, 8, 9, 12, 13, 15, 17, 18, 20, 21, 22, 24, 25, 26, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
		    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 13, 14, 15, 19, 21, 22, 23, 24, 25, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 41, 42, 45, 46,47, 48, 50, 51]
		], device = 'cuda:0'), 
		bounds_lengths = tensor([17, 40], device = 'cuda:0'), 
		strict_masks = tensor([
		    [
		        False, True, True, True, True, True, True, True, True, True,
			    True, True, True, True, True, True, True, True, True, True,
			    True, True, True, True, True, True, False, False, False, False,
			    False, False, False, False, False, False, False, False, False, False,
			    False, False, False, False, False, False, False, False, False, False,
			    False, False
			],
		    [
		        False, True, True, True, True, True, True, True, True, True,
			    True, True, True, True, True, True, True, True, True, True,
			    True, True, True, True, True, True, True, True, True, True,
			    True, True, True, True, True, True, True, True, True, True,
			    True, True, True, True, True, True, True, True, True, True,
			    True, False
			]
		], device = 'cuda:0'), 
		number_of_tokens = tensor([15, 38], device = 'cuda:0', dtype = torch.int32)
	),

	'alignments': tensor([
		    [
			    [1, 0, 0, ..., 0, 0, 0],
			    [0, 1, 0, ..., 0, 0, 0],
			    [0, 0, 0, ..., 0, 0, 0],
			    ..., [0, 0, 0, ..., 0, 0, 0],
			    [0, 0, 0, ..., 0, 0, 0],
			    [0, 0, 0, ..., 0, 0, 0]
		    ],
		    [
			    [1, 0, 0, ..., 0, 0, 0],
			    [0, 0, 0, ..., 0, 0, 0],
			    [0, 0, 0, ..., 0, 0, 0],
			    ..., [0, 0, 0, ..., 0, 0, 0],
			    [0, 0, 0, ..., 0, 1, 0],
			    [0, 0, 0, ..., 0, 0, 0]
		    ]
	    ], device = 'cuda:0', dtype = torch.int32),

	'pe': BatchedSentence(
	    tensor = tensor([
		    [
		        101, 4877, 113183, 113227, 118188, 10060, 15846, 10061, 3031,
		        2128, 114696, 217, 2674, 115512, 118188, 10060, 10955, 11586,
			    10061, 3740, 117490, 2775, 5718, 2212, 114236, 2468, 7475,
			    117244, 5740, 115551, 2196, 5718, 2773, 112165, 1882, 102,
			    0, 0, 0, 0, 0, 0, 0, 0, 0,
			    0, 0, 0, 0, 0
			],
		    [ 
		        101, 10207, 3642, 10186, 4460, 10386, 4348, 10064, 2234,
			    114346, 4520, 7735, 117459, 2196, 5718, 84977, 11305, 68999,
			    10731, 4237, 113162, 6063, 10270, 8272, 8505, 114003, 2196,
			    5718, 4333, 4447, 115521, 100, 17316, 10230, 10108, 11589,
			    100, 10064, 4792, 114213, 5611, 13209, 14951, 13240, 2726,
			    111847, 5162, 112652, 1882, 102
			]
		], device = 'cuda:0'), 
		lengths = tensor([36, 50], device = 'cuda:0'), 
		bounds = tensor([
		    [0, 1, 5, 6, 7, 8, 9, 11, 12, 15, 16, 17, 18, 19, 21, 22, 23, 25, 26, 28, 30, 31, 32, 34, 35, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1
		    ],
		    [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 13, 14, 15, 19, 21, 22, 23, 24, 26, 27, 28, 29, 31, 32, 34, 35, 36, 37, 38, 40, 41, 43, 44, 46, 48, 49
		    ]
		], device = 'cuda:0'), bounds_lengths = tensor([25, 37], device = 'cuda:0'), 
		strict_masks = tensor([
			[   False, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, False, False, False, False, False,
				False, False, False, False, False, False, False, False, False, False
			],
			[   False, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, True, True, False
			]
		],
		device = 'cuda:0'), 
		number_of_tokens = tensor([23, 35], device = 'cuda:0', dtype = torch.int32)
	),

	'source_tags': BatchedSentence(
	    tensor = tensor([
		    [0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
		    [1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1]
		], device = 'cuda:0'), 
		lengths = tensor([12, 30], device = 'cuda:0'), 
		bounds = tensor([
		    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
		    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
		], device = 'cuda:0'), 
		bounds_lengths = tensor([12, 30], device = 'cuda:0'), 
		strict_masks = tensor([
			[   True, True, True, True, True, True, True, True, True, True,
				True, True, False, False, False, False, False, False, False, False,
				False, False, False, False, False, False, False, False, False, False
			],
			[   True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, True, True, True
			]
		], device = 'cuda:0'), 
		number_of_tokens = tensor([12, 30], device = 'cuda:0', dtype = torch.int32)
	),

	'target_tags': BatchedSentence(
	    tensor = tensor([
		    [1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
		    [1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]
		], device = 'cuda:0'), 
		lengths = tensor([15, 38], device = 'cuda:0'), 
		bounds = tensor([
		    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
		    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37]
		], device = 'cuda:0'), 
		bounds_lengths = tensor([15, 38], device = 'cuda:0'), 
		strict_masks = tensor([
			[   True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, False, False, False, False, False,
				False, False, False, False, False, False, False, False, False, False,
				False, False, False, False, False, False, False, False
			],
			[   True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, True
			]
		], device = 'cuda:0'), 
		number_of_tokens = tensor([15, 38], device = 'cuda:0', dtype = torch.int32)
	),

	'sentence_scores': tensor([0.6522, 0.5143], device = 'cuda:0'),

	'gap_tags': BatchedSentence(
	    tensor = tensor([
		    [1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
		    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
	    ], device = 'cuda:0'), 
	    lengths = tensor([16, 39], device = 'cuda:0'), 
	    bounds = tensor([
		    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
		    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38]
		], device = 'cuda:0'), 
		bounds_lengths = tensor([16, 39], device = 'cuda:0'), 
		strict_masks = tensor([
			[   True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, False, False, False, False,
				False, False, False, False, False, False, False, False, False, False,
				False, False, False, False, False, False, False, False, False
			],
			[   True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, True, True, True,
				True, True, True, True, True, True, True, True, True
			]
		], device = 'cuda:0'), 
		number_of_tokens = tensor([16, 39], device = 'cuda:0', dtype = torch.int32)),
	
	'binary': tensor([1, 1], device = 'cuda:0')
}
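
Judging from the dump above, `tensor` holds word-piece ids, while `bounds` appears to hold the index of the first word piece of each word, padded with -1. Under that assumption (my reading of the data, not something confirmed here by the OpenKiwi source), the sketch below picks the first piece of every word for the first `source` sentence; `first_piece_per_word` is a hypothetical helper.

```python
def first_piece_per_word(piece_ids, bounds, pad=-1):
    """Select the word-piece id at each word boundary, skipping padding."""
    return [piece_ids[i] for i in bounds if i != pad]


# First sentence of batch['source'] above, with the trailing zero padding cut off.
piece_ids = [15846, 10491, 82978, 10226, 75312, 10571, 10105, 106095, 11942,
             38587, 10108, 10955, 11586, 119, 102]
bounds = [0, 1, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, -1, -1, -1]

word_starts = first_piece_per_word(piece_ids, bounds)
print(len(word_starts))  # 13
print(word_starts)
```

The result has 13 entries, which agrees with `bounds_lengths = 13` for this sentence; the skipped positions 2 and 8 would then be continuation pieces of multi-piece words.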

Example result:

2021-08-27T08:15:17Z INFO      kiwi.training.callbacks:117: Best validation so far was in epoch 0:
	val_loss: 85.9468, 
	val_WMT19_MCC: 0.5628, 
	val_loss_target_tags: 30.9671,
	val_loss_gap_tags: 22.8837, 

	val_loss_source_tags: 32.0960, 
	val_WMT19_F1_MULT: 0.5807, 
	val_F1_BAD: 0.7120, 
	val_F1_OK: 0.3879, 
	val_WMT19_CORRECT: 0.7842,

	val_target_tags_F1_MULT: 0.3510, 
	val_target_tags_MCC: 0.3523,
	val_target_tags_CORRECT: 0.6544, 

	val_gap_tags_F1_MULT: 0.2005,
	val_gap_tags_MCC: 0.1625, 
	val_gap_tags_CORRECT: 0.9068,

	val_source_tags_F1_MULT: 0.2762, 
	val_source_tags_MCC: 0.2858,
	val_source_tags_CORRECT: 0.6083 
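
To make the metric names easier to read, here is a sketch that computes F1 per class, their product and MCC for a toy pair of tag sequences. It assumes the usual definitions (F1_MULT as the product of the F1 scores of the OK and BAD classes, MCC as the Matthews correlation coefficient) and is not OpenKiwi's own metric code; the toy tags are invented. As a consistency check on the log above, val_F1_BAD times val_F1_OK (0.7120 x 0.3879) rounds to 0.2762, the value logged as val_source_tags_F1_MULT.

```python
from math import sqrt


def confusion(gold, pred, positive):
    """Return (tp, fp, fn, tn) treating `positive` as the positive class."""
    tp = sum(g == positive and p == positive for g, p in zip(gold, pred))
    fp = sum(g != positive and p == positive for g, p in zip(gold, pred))
    fn = sum(g == positive and p != positive for g, p in zip(gold, pred))
    tn = sum(g != positive and p != positive for g, p in zip(gold, pred))
    return tp, fp, fn, tn


def f1(gold, pred, positive):
    tp, fp, fn, _ = confusion(gold, pred, positive)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def mcc(gold, pred, positive="BAD"):
    tp, fp, fn, tn = confusion(gold, pred, positive)
    denom = sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0


# Invented toy tags, only to exercise the functions.
gold = ["OK", "OK", "BAD", "BAD", "OK", "BAD"]
pred = ["OK", "BAD", "BAD", "OK", "OK", "BAD"]

f1_bad = f1(gold, pred, "BAD")
f1_ok = f1(gold, pred, "OK")
f1_mult = f1_bad * f1_ok
toy_mcc = mcc(gold, pred)

print(round(f1_bad, 4), round(f1_ok, 4), round(f1_mult, 4), round(toy_mcc, 4))
print(round(0.7120 * 0.3879, 4))  # 0.2762
```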

To be continued...

posted @ 2021-08-24 22:52  _yanghh