# /*
#  * @Author: Yue.Fan 
#  * @Date: 2022-03-23 11:35:37 
#  * @Last Modified by:   Yue.Fan 
#  * @Last Modified time: 2022-03-23 11:35:37 
#  */
# from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertModel
from pytorch_pretrained_bert.configuration_bert import BertConfig
from pytorch_pretrained_bert.modeling_bert import BertLayer, BertPreTrainedModel, BertModel
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from torchcrf import CRF

class CailModel(BertPreTrainedModel):
    """BERT reader for CAIL-style reading comprehension.

    The reader scores every token as an answer start/end and appends three
    extra classes after the token positions: index ``seq_len`` = no answer,
    ``seq_len + 1`` = YES, ``seq_len + 2`` = NO. It also predicts how many
    answer spans an example has (``switch_logits``) and tags every token as
    inside/outside an answer with a CRF.

    Args:
        config: BERT configuration shared by the encoder and every head.
        answer_verification: when True, token scores are re-weighted by a learned
            rationale distribution and YES/NO scores come from an attention-pooled
            vector over ``token_type_ids == 1`` tokens; when False a single linear
            head over the pooled output produces the unk/YES/NO scores.
        hidden_dropout_prob: NOTE(review): not used; the dropout rate is read from
            ``config.hidden_dropout_prob``. Kept for interface compatibility.
        need_birnn: run a BiLSTM over the last BERT layer before the I/O tagger.
        rnn_dim: hidden size of each direction of that BiLSTM.
    """

    def __init__(self, config, answer_verification=True, hidden_dropout_prob=0.3, need_birnn=False, rnn_dim=128):
        super(CailModel, self).__init__(config)
        self.bert = BertModel(config)
        # Answer-count classes: 0..3 answers (originally args.max_n_answers + 1).
        self.num_answers = 4
        # Start/end scorer over the concatenation of the last four BERT layers.
        self.qa_outputs = nn.Linear(config.hidden_size * 4, 2)
        # Answer-count classifier over the pooled output.
        self.qa_classifier = nn.Linear(config.hidden_size, self.num_answers)

        self.answer_verification = answer_verification
        head_num = config.num_attention_heads // 4

        # One-layer transformer block; constructed but not used in forward.
        self.coref_config = BertConfig(num_hidden_layers=1, num_attention_heads=head_num,
                                       hidden_size=config.hidden_size, intermediate_size=256 * head_num)
        self.coref_layer = BertLayer(self.coref_config)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        out_dim = config.hidden_size
        self.need_birnn = need_birnn
        # Without need_birnn the I/O tagger reads the last BERT layer directly.
        if need_birnn:
            self.birnn = nn.LSTM(config.hidden_size, rnn_dim, num_layers=1, bidirectional=True, batch_first=True)
            # Alternative recurrent encoder; constructed but not used in forward.
            self.gru = nn.GRU(config.hidden_size, rnn_dim, num_layers=1, bidirectional=True, batch_first=True)
            out_dim = rnn_dim * 2

        # Per-token binary I/O emissions, scored by a linear-chain CRF over 2 tags.
        self.hidden2tag = nn.Linear(out_dim, 2)
        self.crf = CRF(2, batch_first=True)

        if self.answer_verification:
            self.retionale_outputs = nn.Linear(config.hidden_size * 4, 1)
            self.unk_ouputs = nn.Linear(config.hidden_size, 1)
            self.doc_att = nn.Linear(config.hidden_size * 4, 1)
            self.yes_no_ouputs = nn.Linear(config.hidden_size * 4, 2)
            # Constructed but not used in forward.
            self.ouputs_cls_3 = nn.Linear(config.hidden_size * 4, 3)
            # Weight of the rationale loss in the total training loss.
            self.beta = 100
        # unk/YES/NO head over the pooled output. forward uses it when
        # answer_verification is False, so it must exist in that mode as well
        # (it used to be created only inside the branch above).
        self.unk_yes_no_outputs = nn.Linear(config.hidden_size, 3)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None,
                unk_mask=None, yes_mask=None, no_mask=None, answer_masks=None, answer_nums=None, label_ids=None):
        """Score spans/unk/YES/NO, count answers and tag tokens; return a loss when targets are given.

        Args:
            input_ids: token ids, presumably (batch, seq_len) — TODO confirm.
            token_type_ids: segment ids; 1 marks the tokens the doc attention may look
                at and the rationale-loss targets. Defaults to all zeros.
            attention_mask: 1 for real tokens, 0 for padding. Defaults to all ones.
            start_positions, end_positions: gold indices into the extended logits
                (token positions, then unk / YES / NO). In answer_verification mode
                one column per answer slot, (batch, k); otherwise one per example.
            unk_mask, yes_mask, no_mask: accepted for interface compatibility; unused.
            answer_masks: (batch, k), 1 for real answer slots, 0 for padding slots.
            answer_nums: (batch,) number of answers, target of the answer-count head.
            label_ids: (batch, seq_len) I/O tag per token for the CRF loss; when it is
                None the CRF term is left out of the loss.

        Returns:
            The scalar training loss when start and end positions are given;
            otherwise ``(start_logits, end_logits, unk_logits, yes_logits, no_logits,
            switch_logits, io_tags)`` where ``io_tags`` is a (batch, seq_len) long
            tensor of decoded tags, right-padded with 0.
        """
        # BERT's own defaults, materialised because the heads below need both masks.
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
        # NOTE(review): assumes the project's BertModel returns, as its first output, a
        # pair whose second item is the sequence of per-layer hidden states, each
        # (batch, seq_len, hidden) — TODO confirm against modeling_bert.
        all_layers = sequence_output[1]
        io_hidden = all_layers[-1]  # last layer feeds the I/O tagger
        # (batch, seq_len, 4 * hidden): concatenation of the last four layers.
        sequence_output = torch.cat((all_layers[-4], all_layers[-3], all_layers[-2], all_layers[-1]), -1)

        # Token-level I/O emissions: needed for the CRF loss and for decoding.
        if self.need_birnn:
            io_hidden, _ = self.birnn(io_hidden)
        emissions = self.hidden2tag(self.dropout(io_hidden))

        # Number-of-answers classification from the pooled output.
        switch_logits = self.qa_classifier(pooled_output)

        if self.answer_verification:
            (start_logits, end_logits, unk_logits, yes_logits, no_logits,
             rationale_probs) = self._verified_heads(sequence_output, pooled_output, token_type_ids, attention_mask)
        else:
            rationale_probs = None
            logits = self.qa_outputs(sequence_output)
            start_logits, end_logits = logits.split(1, dim=-1)
            start_logits = start_logits.squeeze(-1)
            end_logits = end_logits.squeeze(-1)
            unk_logits, yes_logits, no_logits = self.unk_yes_no_outputs(pooled_output).split(1, dim=-1)

        # (batch, seq_len + 3): token scores followed by the unk, YES and NO scores.
        new_start_logits = torch.cat([start_logits, unk_logits, yes_logits, no_logits], 1)
        new_end_logits = torch.cat([end_logits, unk_logits, yes_logits, no_logits], 1)

        if start_positions is not None and end_positions is not None:
            if self.answer_verification:
                return self._verification_loss(new_start_logits, new_end_logits, switch_logits, emissions,
                                               rationale_probs, start_positions, end_positions, answer_masks,
                                               answer_nums, label_ids, token_type_ids, attention_mask)
            # Plain mode: one gold answer per example.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Out-of-range positions are mapped to ignored_index and skipped by the loss.
            ignored_index = new_start_logits.size(1)
            # Clamp from 0 like the verification loss (clamping from 1 remapped
            # position-0 targets onto token 1).
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)
            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            return (loss_fct(new_start_logits, start_positions) + loss_fct(new_end_logits, end_positions)) / 2

        io_tags = self._decode_io_tags(emissions, attention_mask)
        return start_logits, end_logits, unk_logits, yes_logits, no_logits, switch_logits, io_tags

    def _verified_heads(self, sequence_output, pooled_output, token_type_ids, attention_mask):
        """Rationale-weighted span scores, unk score and attention-pooled YES/NO scores."""
        batch_size, seq_length, hidden_size = sequence_output.size()
        flat = sequence_output.view(batch_size * seq_length, hidden_size)
        # Softmax over the sequence: how much each token supports the answer.
        rationale_probs = F.softmax(self.retionale_outputs(flat).view(batch_size, seq_length), dim=-1)
        # Hidden states re-weighted by the rationale distribution.
        final_hidden = sequence_output * rationale_probs.unsqueeze(2)
        flat_final = final_hidden.view(batch_size * seq_length, hidden_size)

        start_logits, end_logits = self.qa_outputs(flat_final).view(batch_size, seq_length, 2).split(1, dim=-1)
        # Padding positions get zero rationale weight and hence a zero span score.
        rationale_probs = rationale_probs * attention_mask.float()
        start_logits = start_logits.squeeze(-1) * rationale_probs
        end_logits = end_logits.squeeze(-1) * rationale_probs

        unk_logits = self.unk_ouputs(pooled_output)

        # Attention restricted to token_type_ids == 1 tokens. masked_fill with the
        # dtype minimum replaces the undefined VERY_NEGATIVE_NUMBER and stays finite
        # under fp16.
        attention = self.doc_att(flat_final).view(batch_size, seq_length)
        attention = attention.masked_fill(token_type_ids == 0, torch.finfo(attention.dtype).min)
        attention = F.softmax(attention, dim=1).unsqueeze(2)
        attention_pooled_output = (attention * final_hidden).sum(1)
        yes_logits, no_logits = self.yes_no_ouputs(attention_pooled_output).split(1, dim=-1)
        return start_logits, end_logits, unk_logits, yes_logits, no_logits, rationale_probs

    def _verification_loss(self, new_start_logits, new_end_logits, switch_logits, emissions, rationale_probs,
                           start_positions, end_positions, answer_masks, answer_nums, label_ids,
                           token_type_ids, attention_mask):
        """Training loss for answer_verification mode: multi-span + answer count + CRF + rationale."""
        batch_size = new_start_logits.size(0)
        # Accept (batch,), (batch, k) or (batch, k, 1) and always work with (batch, k) slots.
        start_positions = start_positions.reshape(batch_size, -1)
        end_positions = end_positions.reshape(batch_size, -1)
        answer_masks = answer_masks.reshape(batch_size, -1).float()
        answer_nums = answer_nums.reshape(-1)

        # Out-of-range positions are mapped to ignored_index and skipped by the loss.
        ignored_index = new_start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)

        # reduction='none' keeps one loss per example, so padded answer slots
        # (answer_masks == 0) are zeroed per example instead of scaling a batch mean.
        loss_fct = CrossEntropyLoss(ignore_index=ignored_index, reduction='none')
        span_masks = torch.unbind(answer_masks, dim=1)
        start_losses = [loss_fct(new_start_logits, pos) * mask
                        for pos, mask in zip(torch.unbind(start_positions, dim=1), span_masks)]
        end_losses = [loss_fct(new_end_logits, pos) * mask
                      for pos, mask in zip(torch.unbind(end_positions, dim=1), span_masks)]
        s_e_loss = sum(start_losses + end_losses)
        switch_loss = loss_fct(switch_logits, answer_nums)

        # Negated CRF log-likelihood; a bool mask is what torch.where inside the CRF expects.
        if label_ids is not None:
            loss_IO = -self.crf(emissions, label_ids, mask=attention_mask.bool())
        else:
            loss_IO = 0.0

        # Focal loss (alpha=0.25, gamma=2) on the rationale distribution, with
        # token_type_ids as the targets.
        rationale_positions = token_type_ids.float()
        alpha, gamma = 0.25, 2.0
        rationale_loss = (-alpha * (1 - rationale_probs) ** gamma * rationale_positions
                          * torch.log(rationale_probs + 1e-8)
                          - (1 - alpha) * rationale_probs ** gamma * (1 - rationale_positions)
                          * torch.log(1 - rationale_probs + 1e-8))
        # Averaged over token_type_ids == 1 tokens only; clamp avoids 0/0 when there are none.
        rationale_loss = (rationale_loss * rationale_positions).sum() / rationale_positions.sum().clamp(min=1.0)

        return torch.mean(s_e_loss + switch_loss + loss_IO) + rationale_loss * self.beta

    def _decode_io_tags(self, emissions, attention_mask):
        """Viterbi-decode the I/O tags and right-pad every sequence with tag 0 to seq_len."""
        seq_length = emissions.size(1)
        decoded = self.crf.decode(emissions, mask=attention_mask.bool())
        padded = [tags + [0] * (seq_length - len(tags)) for tags in decoded]
        # Keep the result on the model's device instead of forcing CUDA.
        return torch.tensor(padded, dtype=torch.long, device=emissions.device)

class MultiLinearLayer(nn.Module):
    """Stack of ``layers`` linear layers: (layers - 1) Linear+ReLU blocks, then a projection.

    Args:
        layers: total number of Linear layers (values below 1 behave like 1).
        hidden_size: input width and width of every hidden layer.
        output_size: width of the final projection.
        activation: NOTE(review): accepted for interface compatibility but unused;
            the hidden activation is always ReLU.
    """

    def __init__(self, layers, hidden_size, output_size, activation=None):
        super(MultiLinearLayer, self).__init__()
        self.net = nn.Sequential()

        for i in range(layers - 1):
            self.net.add_module(str(i) + 'linear', nn.Linear(hidden_size, hidden_size))
            self.net.add_module(str(i) + 'relu', nn.ReLU(inplace=True))

        self.net.add_module('linear', nn.Linear(hidden_size, output_size))

    def forward(self, x):
        """Apply the stack to ``x``, whose last dimension must equal ``hidden_size``."""
        return self.net(x)

if __name__ == '__main__':
    # Smoke test: build a CailModel from a local BERT config and compute the
    # training loss for one hand-written example (batch of 1).
    import torch
    # Token ids of the example; 101 / 102 are presumably [CLS] / [SEP] in the
    # Chinese BERT vocabulary — TODO confirm.
    input_ids = torch.tensor([[101, 839, 5442, 6158, 6843, 2518, 1525, 763, 1278, 7368, 8043, 102, 5307, 2144, 4415, 3389, 3209, 131,
                123, 121, 122, 125, 2399, 129, 3299, 127, 3189, 677, 1286, 5276, 128, 4157, 117, 1333, 1440, 7942, 166,
                121, 3341, 1168, 6158, 1440, 5529, 166, 124, 5307, 5852, 4638, 3717, 3799, 2421, 1079, 6579, 743, 697,
                1259, 3717, 3799, 117, 4507, 1333, 1440, 1350, 1071, 707, 3198, 7416, 3341, 4638, 676, 6762, 6756, 1923,
                5632, 6121, 6566, 6569, 3021, 6817, 3717, 3799, 511, 1762, 3021, 6817, 6814, 4923, 704, 117, 6158, 1440,
                2421, 1079, 1831, 3123, 4638, 3717, 3799, 948, 1847, 678, 3341, 2199, 1333, 1440, 4790, 839, 511, 2496,
                1921, 677, 1286, 117, 1333, 1440, 6158, 6843, 2518, 727, 3926, 2356, 5018, 676, 782, 3696, 1278, 7368,
                117, 5307, 7305, 6402, 3466, 3389, 6402, 3171, 711, 100, 5587, 123, 510, 124, 3491, 860, 7755, 2835,
                510, 2340, 1079, 6674, 7755, 2835, 100, 117, 1066, 3118, 1139, 1278, 4545, 6589, 127, 128, 122, 1039,
                511, 1728, 4567, 2658, 698, 7028, 117, 2496, 1921, 6760, 1057, 3946, 2336, 1278, 4906, 1920, 2110, 7353,
                2247, 5018, 753, 1278, 7368, 6822, 6121, 857, 7368, 3780, 4545, 117, 754, 123, 121, 122, 125, 2399, 129,
                3299, 122, 122, 3189, 1762, 1059, 7937, 678, 6121, 100, 5587, 3491, 7755, 2835, 1147, 1908, 1121, 1327,
                1079, 1743, 2137, 3318, 100, 1469, 100, 2340, 1079, 6674, 7755, 2835, 1079, 1743, 2137, 3318, 100, 117,
                754, 123, 121, 122, 125, 2399, 129, 3299, 123, 122, 3189, 1139, 7368, 117, 1066, 6369, 3118, 1139, 857,
                7368, 6589, 4500, 126, 125, 126, 121, 126, 119, 126, 128, 1039, 511, 5307, 3315, 7368, 1999, 2805, 3946,
                2336, 1921, 3633, 1385, 3791, 7063, 2137, 2792, 7063, 2137, 117, 1333, 1440, 4638, 3655, 4565, 4923,
                2428, 711, 736, 5277, 117, 5852, 1075, 3309, 7361, 6397, 2137, 711, 124, 702, 3299, 113, 794, 1358, 839,
                722, 3189, 6629, 6369, 5050, 114, 117, 753, 3309, 2797, 3318, 113, 2858, 7370, 1079, 1743, 2137, 114,
                4638, 5852, 1075, 3309, 7361, 6397, 2137, 711, 1288, 702, 3299, 117, 1400, 5330, 3780, 4545, 6589, 5276,
                7444, 122, 121, 121, 121, 121, 1039, 2772, 2902, 2141, 7354, 1394, 4415, 1355, 4495, 6589, 4500, 711,
                1114, 511, 1333, 1440, 857, 7368, 3780, 4545, 1350, 1139, 7368, 1400, 117, 6158, 1440, 5529, 166, 124,
                1350, 1071, 1036, 2094, 3295, 1343, 2968, 3307, 2400, 6843, 677, 5852, 1075, 1501, 511, 1352, 3175,
                2218, 6608, 985, 752, 2139, 3187, 3791, 6809, 2768, 671, 5636, 2692, 6224, 117, 3125, 3868, 6401, 511,
                809, 677, 752, 2141, 117, 3300, 1333, 1440, 6716, 819, 6395, 510, 697, 6158, 1440, 2787, 5093, 6395,
                3209, 510, 7305, 6402, 4567, 1325, 1350, 1355, 4873, 1063, 819, 510, 857, 7368, 6589, 4500, 1355, 4873,
                1350, 3926, 1296, 510, 1139, 7368, 6381, 2497, 510, 1278, 4545, 6395, 3209, 741, 510, 1385, 3791, 7063,
                2137, 2692, 6224, 741, 510, 7063, 2137, 6589, 1355, 4873, 1350, 2431, 2144, 5011, 2497, 1762, 3428, 858,
                6395, 117, 3315, 7368, 750, 809, 6371, 2137, 511, 1333, 1440, 2990, 897, 4638, 6228, 7574, 6598, 3160,
                117, 1377, 809, 6395, 102]])
    # Attention mask: every position is a real token (no padding).
    input_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1]])
    # Segment ids: 0 for the first 12 tokens (up to and including the first 102),
    # 1 for the rest. The model treats segment 1 as the tokens its doc attention
    # and rationale loss use.
    segment_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1]])
    # NOTE(review): paragraph_len and is_impossible (below) are not used by this script.
    paragraph_len = 499
    # Per-token I/O tags for the CRF loss; the 1s appear to cover positions
    # 118-126 and 174-185, matching the two gold spans below.
    label_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
               1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
    # Three answer slots per example: two real spans and one padding slot
    # (masked out by answer_masks below).
    start_positions = torch.tensor([[118, 174, 0]])
    end_positions = torch.tensor([[126, 185, 0]])
    is_impossible =False
    # Passed to CailModel.forward, which does not read them.
    unk_mask = torch.tensor([[0]])
    yes_mask = torch.tensor([[0]])
    no_mask = torch.tensor([[0]])
    # 1 marks a real answer slot; answer_nums is the target of the answer-count head.
    answer_masks = torch.tensor([[1, 1, 0]])
    answer_nums = torch.tensor([2])

    # Minimal stand-in for the training script's command-line arguments.
    class Args:
        bert_config_file = 'model_hub/chinese-bert-wwm-ext/config.json'
        need_birnn = False
        rnn_dim = 128
    args = Args()
    config = BertConfig.from_json_file(args.bert_config_file)
    model = CailModel(config, need_birnn=args.need_birnn, rnn_dim=args.rnn_dim)
    # print(model)
    # Gold positions are given, so forward returns the training loss; it is not printed.
    loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions,
                 unk_mask, yes_mask, no_mask, answer_masks, answer_nums, label_ids)


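# Notes:
#   • query和context的注意力、context自己之间的注意力。
#     (Attention between the query and the context, and self-attention within the context.)
#   • 【答案的开始的loss、答案的结束的loss、没有答案的loss、答案为yes的loss、答案为no的loss】、答案数目的loss、每一个token是否属于答案的loss、token_type的loss。
#     (Losses: [answer-start loss, answer-end loss, no-answer loss, YES-answer loss, NO-answer loss],
#      answer-count loss, per-token "is this token in an answer" (I/O) loss, and the token_type (rationale) loss.)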
