RALLM: Retrieval-Augmented LLM Architecture
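
The code below wires a retrieval context encoder (Chinese BERT, Chinese GPT-2, or Longformer variants) in front of a Baichuan2 chat model: retrieved passages are encoded, scored token by token, compressed to their top-K hidden states, refined by a lightweight Q-Former, projected into the LLM embedding space, and spliced into the prompt embeddings. The second file, modeling_perceive_sampler.py, provides the BLIP-2-style Q-Former (a BERT variant with cross-attention over external features) used for that refinement step.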

 

import copy
import os
import sys

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, dir_path)
import contextlib
import torch
import torch.utils.checkpoint
from torch.nn import LayerNorm
from torch import nn
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from modeling_perceive_sampler import BertConfig, BertLMHeadModel
from transformers.utils import logging
logger = logging.get_logger(__name__)
import transformers
from transformers import (
    PreTrainedModel, AutoTokenizer, AutoModelForMaskedLM, AutoModel,
    BertTokenizer, GPT2LMHeadModel, PretrainedConfig, GPT2Model,
    GPT2Tokenizer, LongformerTokenizer, LongformerModel,
)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # use the first GPU
import argparse
import math


class RALLM(nn.Module):
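    """Retrieval-Augmented LLM.

    A context encoder (BERT / Chinese GPT-2 / Longformer) encodes retrieved
    passages, a linear scoring head selects the top-K token states, a small
    Q-Former refines them against the full passage, and a projection maps the
    result into the Baichuan2 embedding space, where it is spliced into the
    prompt between the flag_context_start / flag_context_end embeddings.
    """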

    def __init__(self,args):
        super(RALLM,self).__init__()

        self.is_compress = args.is_compress    
        self.use_lora = args.use_lora  
        print('Init LLM ... ')

        if args.LLM_model == "Baichuan2_13B":
            self.LLM_model_name = "Baichuan2-13B-Chat"
            self.LLM_hidden_size = 5120
        elif args.LLM_model == "Baichuan2_7B":
            self.LLM_model_name = "baichuan2_7B"
            self.LLM_hidden_size = 4096
        else:
            raise ValueError(f"Unsupported LLM_model: {args.LLM_model}")

        
        self.LLM_model = transformers.AutoModelForCausalLM.from_pretrained(
            self.LLM_model_name,
            device_map=f"cuda:{args.local_rank}",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            # cache_dir=training_args.cache_dir,
        )
        self.LLM_tokenizer = transformers.AutoTokenizer.from_pretrained(
            self.LLM_model_name,
            use_fast=False,
            trust_remote_code=True,
            model_max_length=4096,
            # cache_dir=training_args.cache_dir,
        )

        self.flag_context_start = nn.Parameter(torch.zeros([1, 1, self.LLM_hidden_size]))#.to(self.device)
        self.flag_context_end = nn.Parameter(torch.zeros([1, 1, self.LLM_hidden_size]))#.to(self.device)
        self.flag_context_start.requires_grad = False
        self.flag_context_end.requires_grad = False

        self.device = self.LLM_model.device
        self.user_token = self.LLM_tokenizer._convert_id_to_token(195)
        self.assisent_token = self.LLM_tokenizer._convert_id_to_token(196)
        self.eoa = self.LLM_tokenizer._convert_id_to_token(2)

        print("user_token:",self.user_token,"assisent_token:",self.assisent_token,"eoa:",self.eoa)
        print('Done')

        print('Init context encoder ... ')

        self.init_context_encoder(args)
 
        print('Done')

    def init_Qformer(self,num_query_token,num_features):
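        """Build the Q-Former and strip the pieces it does not need: word/position
        embeddings, the text-branch feed-forward blocks, and the LM head (queries
        are passed in as embeddings and only the query branch is used)."""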
        self.Qformer  = self.init_qformer(num_query_token, num_features,cross_attention_freq=1)
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.Qformer.cls = None


    @classmethod
    def init_qformer(cls,
                        num_query_token,
                        vision_width,
                        cross_attention_freq=2,
                        pretrain=True):
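        """Create a 2-layer BertLMHeadModel configured as a Q-Former: its hidden
        size matches the context-encoder width and cross-attention over the
        encoder features is inserted every `cross_attention_freq` layers.
        The `pretrain` flag is currently unused."""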
        encoder_config = BertConfig()
        encoder_config.num_hidden_layers = 2
        encoder_config.hidden_size = vision_width
        encoder_config.encoder_width = vision_width
        encoder_config.num_attention_heads = vision_width//64
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)

        return Qformer


    def init_context_encoder(self,args):
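        """Build the context-encoding stack: the encoder/tokenizer selected by
        args.encoder, the token scoring head, projections into the LLM embedding
        space, a LayerNorm over the encoder features, and the Q-Former."""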
        
        num_query_token = args.query_tokens = 0  # query embeddings come from the top-K selected hidden states, not learned query tokens

        if args.encoder == "bert_base":
            self.context_tokenizer = AutoTokenizer.from_pretrained("bert_base_chinese")
            self.context_encoder = AutoModelForMaskedLM.from_pretrained("bert_base_chinese", output_hidden_states=True)
            num_features = 768

        elif args.encoder == "bert_large":
            self.context_tokenizer = AutoTokenizer.from_pretrained("bert_large_chinese", max_length=2000)
            self.context_encoder = AutoModelForMaskedLM.from_pretrained("bert_large_chinese", output_hidden_states=True)
            num_features = 1024

        elif args.encoder == "gpt2_xlarge":
            self.context_tokenizer = BertTokenizer.from_pretrained("gpt2_chinese_xlarge")
            self.context_encoder = GPT2LMHeadModel.from_pretrained("gpt2_chinese_xlarge")
            num_features = 1600

        elif args.encoder == "gpt2_large":
            self.context_tokenizer = BertTokenizer.from_pretrained("gpt2_chinese_large")
            self.context_encoder = GPT2LMHeadModel.from_pretrained("gpt2_chinese_large")
            num_features = 1280

        elif args.encoder == "gpt2_large_en":
            self.context_tokenizer = GPT2Tokenizer.from_pretrained("/data2/xinyuuliu/Baichuan2_qformer_bert/gpt2-large-EN")
            self.context_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.context_encoder = GPT2Model.from_pretrained("/data2/xinyuuliu/Baichuan2_qformer_bert/gpt2-large-EN")
            num_features = 1280

        elif args.encoder == "longformer":
            self.context_tokenizer = LongformerTokenizer.from_pretrained('longformer')
            self.context_encoder = LongformerModel.from_pretrained('longformer')
            num_features = 768

        elif args.encoder == "longformer_large":
            self.context_tokenizer = LongformerTokenizer.from_pretrained('longformer-large')
            self.context_encoder = LongformerModel.from_pretrained('longformer-large')
            num_features = 1024

        else:
            raise ValueError(f"Unsupported encoder: {args.encoder}")



        # bert_tokenizer = AutoTokenizer.from_pretrained("bert_base_chinese",max_length=2000)
        # bert_encoder = AutoModelForMaskedLM.from_pretrained("longformer_zh",output_hidden_states=True) #.to(device)
        self.context_encoder = self.context_encoder.to(self.device)

        self.context_score = torch.nn.ModuleList([
                torch.nn.Linear(num_features, 64),
                torch.nn.Tanh(),
                torch.nn.Linear(64, 1),
            ])  # scoring head: maps each encoder hidden state (num_features-dim) to a scalar relevance score

        self.context2llm_proj = torch.nn.Linear(num_features, self.LLM_hidden_size)  # project encoder hidden states into the LLM embedding space
        self.llm_proj = torch.nn.Linear(num_features, self.LLM_hidden_size)
        # model.embed2qformer_proj = torch.nn.Linear(num_features, 768)

        self.ln_features = LayerNorm(num_features) 
        self.init_Qformer(num_query_token,num_features)

        # del model.internlm_proj
        # del model.Qformer
        # torch.cuda.empty_cache()  # free GPU memory

        # if device:
        # model = self.model.to(self.device)


    def encode_text(self, text, add_special_tokens=False):

        input_ids = self.LLM_tokenizer.encode(text, add_special_tokens=add_special_tokens)
        input_ids = torch.LongTensor([input_ids]).to(self.device)

        if self.use_lora:
            text_embeds = self.LLM_model.base_model.model.model.embed_tokens(input_ids)
        else:
            text_embeds = self.LLM_model.model.embed_tokens(input_ids)
        return text_embeds


    def calculate_compressibility(self,x,k=0):
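        """Compression ratio as a function of sequence length x.

        With the default k=0 this reduces to 111.111 / (x + 111.111): short
        passages are kept nearly intact, long ones are compressed harder.
        Downstream, K = ceil(x * ratio) token states are kept.
        """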
        return (x * k*(9 / 1000) + 1) * 111.111 / (x + 111.111)


    # batch-encode a list of sentences into padded input_ids
    def batch_input_sentences(self,sentences):
        input_ids_list = [self.context_tokenizer.encode(sentence,return_tensors="pt",padding='max_length', max_length=2500, truncation=True) for sentence in sentences]
        max_length = max(len(input_ids[0]) for input_ids in input_ids_list)
        input_ids_padded = [torch.cat([input_ids, torch.zeros(1, max_length - input_ids.size(1), dtype=torch.long)], dim=1) for input_ids in input_ids_list]
        input_ids_tensor = torch.cat(input_ids_padded, dim=0)
        return input_ids_tensor

    def encode_context(self, text_list):
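        """Encode retrieved contexts into LLM-ready embedding sequences.

        For each entry in text_list: run the context encoder, score every token
        with the scoring head, keep the top-K token states (K comes from
        calculate_compressibility when compression is enabled), re-sort them into
        document order, refine them with the Q-Former against the full sequence,
        project them into the LLM embedding space, and wrap them between the
        flag_context_start / flag_context_end embeddings.

        Returns (list of context embeddings, list of attention masks), or
        (None, None) when text_list is None.
        """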
        if text_list is None:
            return None, None

        inputs_LLMs = []
        input_atts = []
        # print(text_list)
        for text in text_list:
            # input_ids = self.context_tokenizer.encode(text, add_special_tokens=True, return_tensors="pt")
            # encode the text list and pad to the longest sequence
            # encoded_ids = self.context_tokenizer(text, padding=True, return_tensors="pt", truncation=True)
            input_ids = self.batch_input_sentences(text)
            input_ids =input_ids.to(self.device)
            # input_ids = encoded_ids.data["input_ids"].to(self.device)
            # attention_mask = encoded_ids.data["attention_mask"].to(self.device)
            outputs = self.context_encoder(input_ids,output_hidden_states=True)
            # take the embedding-layer output and the last hidden state
            embedding,last_hidden_state = outputs.hidden_states[0],outputs.hidden_states[-1] #outputs.logits

            x = last_hidden_state
            for layer in self.context_score:
                x = layer(x)

            output = x
            # output = self.context_score(last_hidden_state)  # apply the scoring head
            
            batch,seq_len,ebd_dim = last_hidden_state.size()
            
            # compressibility = -0.0009 * seq_len + 1  # linear compression: short sequences are compressed less, long ones more; abandoned because x*f(x) is not monotonic
            # compressibility = 111.111/(seq_len+111.111)  # redesigned non-linear compression: sequences of ~10 tokens or fewer are essentially uncompressed
            
            if self.is_compress:
                compressibility = self.calculate_compressibility(seq_len,0)
                K = math.ceil(seq_len*compressibility)
            else:
                K = seq_len

            # use torch.topk to get the indices of the top-K scoring tokens
            topk_indices = torch.topk(output, K,dim=1).indices
            # print(topk_indices)
            topk_indices, sorted_indices = torch.sort(topk_indices, dim=1)   # re-sort indices to restore the original token order
            # print(topk_indices)

            # gather the last_hidden_state entries at the top-K indices
            topk_selected_last_hidden_state = torch.gather(last_hidden_state, 1, topk_indices.expand(-1, -1, ebd_dim))
            # print(last_hidden_state)
            # print(topk_selected_last_hidden_state)
            topk_selected_embedding = torch.gather(embedding, 1, topk_indices.expand(-1, -1, ebd_dim))
            # bert_text_atts = torch.gather(attention_mask, 1, torch.squeeze(topk_indices, dim=2))

            bert_text_embeds = self.ln_features(last_hidden_state)
            bert_text_atts = torch.ones(bert_text_embeds.size()[:-1],dtype=torch.long).to(self.device)
            # query_tokens = self.query_tokens.expand(bert_text_atts.shape[0], -1,-1)
            # query_tokens = topk_selected_embedding
            query_tokens = topk_selected_last_hidden_state

            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=bert_text_embeds,
                encoder_attention_mask=bert_text_atts,
                return_dict=True,
            )

            # topk_context_hidden_state = self.context2llm_proj(topk_selected_last_hidden_state)

            inputs_LLM = self.llm_proj(query_output.last_hidden_state)

            inputs_LLM = torch.cat([
                self.flag_context_start.expand(batch, -1, -1),
                # topk_context_hidden_state,
                inputs_LLM,
                self.flag_context_end.expand(batch, -1, -1)
            ],dim=1).view(-1, self.LLM_hidden_size)

            input_att = torch.cat([torch.ones((batch,1)).to(self.device),bert_text_atts,torch.ones((batch,1)).to(self.device)],dim=1).view(-1)
            # print(inputs_LLM.shape)
            inputs_LLMs.append(inputs_LLM)
            input_atts.append(input_att)
        # context_inputs = torch.stack(inputs_LLMs)
        return inputs_LLMs,input_atts



    def wrap_prompt(self,
                    text_embeds,
                    context_embeds=None,
                    history=None,
                    add_special=True):
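        """Assemble generation-time prompt embeddings in the order
        <user> text_embeds context_embeds <assistant>, optionally prefixed with
        the embeddings of previous turns in `history`."""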
        # prompt segments are the same with or without history / add_special
        prompt_segs = [self.user_token, self.assisent_token]  # <user> ... <assistant>
        prompt_seg_embeds = []
        for i, seg in enumerate(prompt_segs):
            if history is not None:
                add_special_tokens = False
            else:
                add_special_tokens = i == 0
            seg_embeds = self.encode_text(
                seg, add_special_tokens=add_special_tokens)
            prompt_seg_embeds.append(seg_embeds)
        if context_embeds is None:
            context_embeds = text_embeds.new_empty(text_embeds.size(0), 0,
                                                text_embeds.size(-1))
        else:
            # take the first context's embeddings and add a batch dimension at dim 0
            context_embeds = context_embeds[0].unsqueeze(0)
        prompt_seg_embeds = [
            prompt_seg_embeds[0], text_embeds,context_embeds,  prompt_seg_embeds[1]
        ]
        prompt_embeds = torch.cat(prompt_seg_embeds, dim=1)
        if history is not None:
            prompt_embeds = torch.cat([*history, prompt_embeds], dim=1)
        return prompt_embeds


    def generate(self, text, context=None, **kwargs):
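        """Single-turn generation: strip special markers from the question,
        encode the retrieved context, build the prompt embeddings, and decode
        with the LLM's default generation config."""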
        text = text.replace("<context>","").replace(self.user_token,"").replace(self.assisent_token,"")
        text_embeds = self.encode_text(text)
        context_embeds,_ = self.encode_context(context)
        prompt_embeds = self.wrap_prompt(text_embeds, context_embeds)
        # out_embeds = self.LLM_model.generate(input_ids=None,
        #     inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs))
        # out_text = self.decode_text(out_embeds)
        outputs = self.LLM_model.generate(input_ids=None,inputs_embeds=prompt_embeds, generation_config=self.LLM_model.generation_config)
        response = self.LLM_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response

    def chat(self, text, context=None, history=None, **kwargs):
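        """Multi-turn chat: generate an answer, then re-embed the (prompt, answer)
        pair and append it to `history` so follow-up turns can attend to it."""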
        text_embeds = self.encode_text(text)
        context_embeds, _ = self.encode_context(context)
        prompt_embeds = self.wrap_prompt(text_embeds,
                                         context_embeds,
                                         history=history)
        outputs = self.LLM_model.generate(
            input_ids=None,
            inputs_embeds=prompt_embeds,
            generation_config=self.LLM_model.generation_config)
        out_text = self.LLM_tokenizer.decode(outputs[0], skip_special_tokens=True)

        # re-embed the generated answer and append (prompt, answer) to the history
        clean_out_text_token_ids = self.LLM_tokenizer(
            out_text, return_tensors='pt').input_ids.to(self.device)
        if self.use_lora:
            clean_out_text_embeds = self.LLM_model.base_model.model.model.embed_tokens(
                clean_out_text_token_ids)
        else:
            clean_out_text_embeds = self.LLM_model.model.embed_tokens(
                clean_out_text_token_ids)
        clean_prompt_embeds = self.wrap_prompt(text_embeds,
                                               context_embeds,
                                               add_special=False)
        cur_history = torch.cat([clean_prompt_embeds, clean_out_text_embeds],
                                dim=1)
        if history is None:
            history = []
        history.append(cur_history)
        return out_text, history

    def align_text(self, samples, has_context=False):  ### add eoa; returns the text after <context>

        text_new = []
        if has_context:  ### keep only the text after <context>; the context itself is injected as embeddings
            text = [
                t.split("<context>")[-1] for t in samples["text_input"]
            ]
        else:
            text = [t for t in samples["text_input"]]

        text = [t + self.eoa  for t in text]
        for i in range(len(text)):
            temp = text[i]
            # temp = temp.replace('<|Bot|>', self.eoh + ' <|Bot|>')
            # if temp.find(self.eoh) > temp.find(self.eoa):
            #     temp = temp.replace(self.eoa, '', 1)
            text_new.append(temp)
        return text_new

    def prompt_wrap(self, context_embeds,context_atts, prompt_list):
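        """Training-time wrapping: embed the prompt text before <context>, prepend
        it to each context embedding, and build matching attention masks plus
        label tensors that mask the context span out of the loss with -100."""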
        batch_size = len(context_embeds)
        p_before = [prompt.split('<context>')[0] for prompt in prompt_list]
        p_before_tokens = self.LLM_tokenizer(p_before,
                                        padding=True,
                                        truncation=True,
                                            return_tensors="pt",
                                            add_special_tokens=True).to(
                                                self.device)

        if self.use_lora:
            p_before_embeds = self.LLM_model.base_model.model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
        else:
            p_before_embeds = self.LLM_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)

        # wrapped_context_embeds = torch.cat([p_before_embeds, context_embeds], dim=1)
        # wrapped_context_embeds = torch.cat([p_before_embeds]+context_embeds, dim=1)
        wrapped_context_embeds = []
        wrapped_atts_context = []
        wrapped_target = []
        for i, (context_embed,context_att) in enumerate(zip(context_embeds,context_atts)):
            # concatenate each prompt-prefix embedding with its context embedding along the sequence dimension
            concatenated = torch.cat((p_before_embeds[i], context_embed), dim=0)
            wrapped_context_embeds.append(concatenated)
            # concatenated_att = torch.cat((torch.ones(p_before_embeds[i].size()[:-1],dtype=torch.long).to(self.device),context_att),dim=0)
            wrapped_atts_context.append(torch.ones(concatenated.size()[:-1],dtype=torch.long).to(self.device))
            # wrapped_atts_context.append(concatenated_att)
            target = torch.ones(concatenated.size()[:-1], dtype=torch.long) * -100
            target[0] = 2
            target = target.to(self.device)
            wrapped_target.append(target)

        # wrapped_atts_context = torch.ones(wrapped_context_embeds.size()[:-1],
        #                                 dtype=torch.long).to(self.device)

        # wrapped_target = torch.ones(
        #     batch_size, wrapped_context_embeds.shape[1], dtype=torch.long).to(
        #         self.device) * -100

        return wrapped_context_embeds, wrapped_atts_context, wrapped_target

    def text2emb(self, text):
        to_regress_tokens = self.LLM_tokenizer(text,
                                           return_tensors="pt",
                                           padding="longest",
                                           truncation=True,
                                           max_length=4096,
                                           add_special_tokens=False).to(
                                               self.device)

        targets = self.mask_human_targets(to_regress_tokens.input_ids)
        targets = targets.to(self.device)

        return to_regress_tokens, targets


    def mask_human_targets(self, input_ids, pure=False):
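        """Build training labels: everything up to and including each assistant
        marker (token id 196) is set to -100, and padding (token id 0) is also
        masked, so the loss only covers the model's responses."""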
        target_batch = []
        for bs in range(input_ids.shape[0]):
            cur_idx = 0
            ids = input_ids[bs]
            targets = copy.deepcopy(ids)
            last_eoa = 0 
            last_eoh = 0
            for i, temp_id in enumerate(ids):
                if temp_id == 196:  #### assistant marker: mask everything up to and including it
                    targets[cur_idx:i+1] = -100

            target_batch.append(targets.unsqueeze(0))

        target_batch = torch.cat(target_batch, dim=0)
        target_batch[target_batch==0]=-100
        # print(input_ids)
        # print(target_batch)
        return target_batch

    def forward(self,
                input_ids=None,
                attention_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                context = None,
                text_input = None,
                **kwargs):
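        """Training forward pass: embed the target text, optionally prepend the
        wrapped context embeddings, pad the batch to a common length, and return
        the LLM output with language-modeling loss on the masked targets."""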

        # samples = kwargs #.get('samples')
        # has_context = 'context' in samples.keys()
        if context:
            has_context = True
        else:
            has_context = False

        samples = {"text_input":text_input,"context":context}

        ### encode text
        text = self.align_text(samples=samples, has_context=has_context)  # get the text that follows <context>
        to_regress_tokens, targets = self.text2emb(text)  # tokenize and build the training targets

        if self.use_lora:
            to_regress_embeds = self.LLM_model.base_model.model.model.embed_tokens(to_regress_tokens.input_ids)
        else:
            to_regress_embeds = self.LLM_model.model.embed_tokens(to_regress_tokens.input_ids)


        attention_mask = to_regress_tokens.attention_mask

        if has_context:
            prompt = samples["text_input"]

            ### encode context
            context = samples["context"]
            context_embeds,context_atts = self.encode_context(context)
            context_embeds, atts_context, wrapped_target = self.prompt_wrap(
                context_embeds,context_atts, prompt)
            ### combine context and text embeddings

            to_regress_embeds_ = []
            attention_mask_ = []
            targets_ = []
            for i, (tensor0,tensor1,tensor2) in enumerate(zip(to_regress_embeds,attention_mask,targets)):
                # concatenate each wrapped context embedding with the corresponding text embeddings along the sequence dimension
                to_regress_embed = torch.cat((context_embeds[i], tensor0), dim=0)
                to_regress_embeds_.append(to_regress_embed)
                attention_m = torch.cat((atts_context[i], tensor1), dim=0)
                attention_mask_.append(attention_m)
                target = torch.cat((wrapped_target[i], tensor2), dim=0)
                targets_.append(target)


            # to_regress_embeds = torch.cat([context_embeds, to_regress_embeds],
            #                                 dim=1)
            # attention_mask = torch.cat([atts_context, attention_mask], dim=1)
            # targets = torch.cat([wrapped_target, targets], dim=1)

            # find the maximum sequence length in the batch
            max_len = max(t.size(0) for t in to_regress_embeds_)

            # pad each sample to max_len
            padded_to_regress_embeds_ = []
            padded_attention_mask_ = []
            padded_targets_ = []
            for (t,a,l) in zip(to_regress_embeds_,attention_mask_,targets_):
                if t.size(0) < max_len:
                    # how much padding is needed
                    padding_size = max_len - t.size(0)
                    # pad along the sequence dimension
                    padded_regress = torch.nn.functional.pad(t, (0, 0, 0, padding_size))
                    padded_attention = torch.nn.functional.pad(a, (0, padding_size), value=0)
                    padded_target = torch.nn.functional.pad(l, (0, padding_size), value=-100)

                    padded_to_regress_embeds_.append(padded_regress)
                    padded_attention_mask_.append(padded_attention)
                    padded_targets_.append(padded_target)
                else:
                    padded_to_regress_embeds_.append(t)
                    padded_attention_mask_.append(a)
                    padded_targets_.append(l)


            # stack into batched tensors
            to_regress_embeds = torch.stack(padded_to_regress_embeds_)
            attention_mask = torch.stack(padded_attention_mask_)
            targets = torch.stack(padded_targets_)


        outputs = self.LLM_model(
            inputs_embeds=to_regress_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        return outputs
    






if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--output", default="output", type=str)
    parser.add_argument("--encoder", default="gpt2_large", type=str)
    parser.add_argument("--query_tokens", default=32, type=int)
    parser.add_argument("--load_path", default="/data2/xinyuuliu/InternLM-XComposer/output_rerank", type=str)
    parser.add_argument("--local_rank", default="0", type=str)


    args = parser.parse_args()


    model = RALLM(args)

    print(model)

    # model.encode_context("我爱北京天安门")
    # model.encode_text("我爱北京天安门")
    # #<ContextHere>
    # query = "Q:请重复内容:<cont_s><ContextHere><cont_e> \n A:"
    # context = ["电饭煲不知道怎么选?想要吃一碗香喷喷的米饭,除了米要好之外,还需要一款性能优秀的电饭煲,所以大家在选购电饭煲的时候,一定要多花点心思看看攻略避免踩雷。我前前后后给亲朋好友选购过不下5台电饭煲,也算是积攒了不少选购经验,今天特意总结了一下想分享给大家。1、容量选择市面上电饭煲容量普遍在3L-5L之间,这个范围的容量足够满足绝大部分家庭使用,3L一般可以满足1-3人的家庭,4L一般可以满足2-5人的家庭,5L一般可以满足2-8人的家庭,如果人口超过8人建议直接选择5L以上的容量,使用会更方便。"]

    # model.interleav_wrap(query,context)

 

modeling_perceive_sampler.py

"""
 * Copyright (c) 2023, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
 * Based on huggingface code base
 * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
"""

import math
from typing import Tuple

import torch
import torch.utils.checkpoint
from torch import Tensor, device
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
)
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.models.bert.configuration_bert import BertConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size,
                                            config.hidden_size,
                                            padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
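        # query_embeds, if given, are prepended to the token embeddings;
        # when input_ids is None the query embeddings are used on their own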
        if input_ids is not None:
            seq_length = input_ids.size()[1]
        else:
            seq_length = 0

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length:
                                             seq_length +
                                             past_key_values_length].clone()

        if input_ids is not None:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids)
                embeddings = embeddings + position_embeddings

            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
        else:
            embeddings = query_embeds

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertSelfAttention(nn.Module):
    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
                config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" %
                (config.hidden_size, config.num_attention_heads))

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size /
                                       config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")
        if (self.position_embedding_type == "relative_key"
                or self.position_embedding_type == "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1,
                self.attention_head_size)
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))

        if (self.position_embedding_type == "relative_key"
                or self.position_embedding_type == "relative_key_query"):
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length,
                                          dtype=torch.long,
                                          device=hidden_states.device).view(
                                              -1, 1)
            position_ids_r = torch.arange(seq_length,
                                          dtype=torch.long,
                                          device=hidden_states.device).view(
                                              1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = (attention_scores +
                                    relative_position_scores_query +
                                    relative_position_scores_key)

        attention_scores = attention_scores / math.sqrt(
            self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.all_head_size, )
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = ((context_layer, attention_probs) if output_attentions else
                   (context_layer, ))

        outputs = outputs + (past_key_value, )
        return outputs


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.self.num_attention_heads,
            self.self.attention_head_size,
            self.pruned_heads,
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(
            heads)
        self.self.all_head_size = (self.self.attention_head_size *
                                   self.self.num_attention_heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,
                   ) + self_outputs[1:]  # add attentions if we output them
        return outputs


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        if (self.config.add_cross_attention
                and layer_num % self.config.cross_attention_freq == 0):
            self.crossattention = BertAttention(
                config, is_cross_attention=self.config.add_cross_attention)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
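        # the first `query_length` positions are Q-Former query tokens: only they
        # go through cross-attention, and they use a separate feed-forward branch
        # (intermediate_query / output_query) from the text positions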
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = (past_key_value[:2]
                                    if past_key_value is not None else None)
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                assert (
                    encoder_hidden_states is not None
                ), "encoder_hidden_states must be given for cross-attention layers"
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = (
                    outputs + cross_attention_outputs[1:-1]
                )  # add cross attentions if we output attention weights

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )
            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text],
                                         dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output, ) + outputs

        outputs = outputs + (present_key_value, )

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config, i) for i in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (() if output_attentions
                                and self.config.add_cross_attention else None)

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states, )

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[
                i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing",
                       False) and self.training:

                if use_cache:
                    logger.warn(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value,
                                      output_attentions, query_length)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1], )
            if output_attentions:
                all_self_attentions = all_self_attentions + (
                    layer_outputs[1], )
                all_cross_attentions = all_cross_attentions + (
                    layer_outputs[2], )

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

        if not return_dict:
            return tuple(v for v in [
                hidden_states,
                next_decoder_cache,
                all_hidden_states,
                all_self_attentions,
                all_cross_attentions,
            ] if v is not None)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size,
                                 config.vocab_size,
                                 bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0,
                                       std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. To behave as a decoder the model needs to be
    initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an
    :obj:`encoder_hidden_states` is then expected as an input to the forward pass.
    """
    def __init__(self, config, add_pooling_layer=False):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int],
        device: device,
        is_decoder: bool,
        has_query: bool = False,
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = (seq_ids[None, None, :].repeat(
                    batch_size, seq_length, 1) <= seq_ids[None, :, None])

                # add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[
                        1] - causal_mask.shape[1]
                    if has_query:  # UniLM style attention mask
                        causal_mask = torch.cat(
                            [
                                torch.zeros(
                                    (batch_size, prefix_seq_len, seq_length),
                                    device=device,
                                    dtype=causal_mask.dtype,
                                ),
                                causal_mask,
                            ],
                            axis=1,
                        )
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, causal_mask.shape[1],
                                 prefix_seq_len),
                                device=device,
                                dtype=causal_mask.dtype,
                            ),
                            causal_mask,
                        ],
                        axis=-1,
                    )
                extended_attention_mask = (causal_mask[:, None, :, :] *
                                           attention_mask[:, None, None, :])
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})"
                .format(input_shape, attention_mask.shape))

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        # use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is None:
            assert (
                query_embeds is not None
            ), "You have to specify query_embeds when input_ids is None"

        # past_key_values_length
        past_key_values_length = (past_key_values[0][0].shape[2] -
                                  self.config.query_length
                                  if past_key_values is not None else 0)

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
            past_key_values_length=past_key_values_length,
        )

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones(
                ((batch_size, seq_length + past_key_values_length)),
                device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if is_decoder:
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask,
                input_ids.shape,
                device,
                is_decoder,
                has_query=(query_embeds is not None),
            )
        else:
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask, input_shape, device, is_decoder)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
                    0].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size,
                                    encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [
                    self.invert_attention_mask(mask)
                    for mask in encoder_attention_mask
                ]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape,
                                                    device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask,
                                       self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = (self.pooler(sequence_output)
                         if self.pooler is not None else None)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


class BertLMHeadModel(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [
        r"position_ids", r"predictions.decoder.bias"
    ]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)
        if labels is not None:
            use_cache = False
        if past_key_values is not None:
            query_embeds = None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        if query_embeds is not None:
            sequence_output = outputs[0][:, query_embeds.shape[1]:, :]

        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :
                                                          -1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction,
                                        label_smoothing=0.1)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if reduction == "none":
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            output = (prediction_scores, ) + outputs[2:]
            return ((lm_loss, ) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      query_embeds,
                                      past=None,
                                      attention_mask=None,
                                      **model_kwargs):
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids":
            input_ids,
            "query_embeds":
            query_embeds,
            "attention_mask":
            attention_mask,
            "past_key_values":
            past,
            "encoder_hidden_states":
            model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask":
            model_kwargs.get("encoder_attention_mask", None),
            "is_decoder":
            True,
        }

    def _reorder_cache(self, past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx)
                for past_state in layer_past), )
        return reordered_past


class BertForMaskedLM(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [
        r"position_ids", r"predictions.decoder.bias"
    ]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        """

        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        sequence_output = outputs[0]
        if query_embeds is not None:
            sequence_output = outputs[0][:, query_embeds.shape[1]:, :]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1))

        if not return_dict:
            output = (prediction_scores, ) + outputs[2:]
            return (((masked_lm_loss, ) +
                     output) if masked_lm_loss is not None else output)

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
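
To make the data flow above concrete, here is a minimal, hedged sketch of how a Q-Former built from this BertLMHeadModel is typically queried: a batch of learnable query vectors is passed as query_embeds, cross-attends to the context encoder's hidden states, and only the query positions of the output are kept as the fixed-length compressed context. The qformer object, the shapes and the tensors below are illustrative assumptions, not the project's exact training code.

import torch

# Illustrative shapes (assumptions): 2 samples, 32 query tokens, hidden size 768
batch_size, num_query, hidden = 2, 32, 768
query_tokens = torch.zeros(batch_size, num_query, hidden)     # learnable queries (assumed)
context_hidden = torch.randn(batch_size, 512, hidden)         # context-encoder outputs (assumed)
context_mask = torch.ones(batch_size, 512, dtype=torch.long)  # 1 = attend, 0 = padding

# `qformer` is assumed to be an already-constructed BertLMHeadModel like the one above.
out = qformer.bert(
    query_embeds=query_tokens,                # no input_ids: query positions only
    encoder_hidden_states=context_hidden,     # keys/values for the cross-attention
    encoder_attention_mask=context_mask,
    return_dict=True,
)
compressed_context = out.last_hidden_state[:, :num_query, :]  # fixed-length summary of the context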

 

dataset_batch.py

# -*- coding: utf-8 -*-
from torch.utils.data import Dataset
import torch
import json
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
import random


class QADataset(Dataset):
    def __init__(self, data_path,train) -> None:
        super().__init__()
 
        self.data = []

        
        data = pd.read_csv(data_path).dropna()
        print(data.columns)
        condition = (data['answer'].str.len() <= 1000) & (data['summary'].str.len() <= 500)

        filtered_data = data[condition]

        with open("data/corpus.tsv","r") as f_read:
            corpus = [i.split()[-1] for i in f_read.readlines()]

        retell_prompts = ["请复述这段被压缩的内容",
                        "复述这段被压缩的内容",
                        "请将被压缩的内容复述出来",]

        summary_prompts = ["请总结被压缩的信息",
                    "还原被压缩信息的主要内容",
                    "请写出被压缩信息的主要内容",
                    "请对之前压缩的信息进行概括",
                    "请提炼出之前被压缩信息的核心要点",
                    "请归纳一下之前被压缩的内容的主旨"]

        if train:
            # Filter articles that satisfy the length constraints
            # filtered_data1000 = list(filter(self.filter_by_length1000, data["answer"]))

            for idx in range(5000):
            
                # if not line or line == "" or len(line) < 50 or len(line) > 2000:
                #     continue
                
                # Randomly pick how many passages to concatenate (1 to 10)
                repeat_count = random.randint(1, 10)

                flag_context = "<context> "*repeat_count

                prompt = random.choice(retell_prompts)
                selected_articles = random.sample(corpus, repeat_count)
                selected_articles_ = "[SEP]".join(selected_articles)

                text = f'<reserved_106>{prompt}{flag_context}<reserved_107>{selected_articles_}'
                test_data = {"context":selected_articles,"text_input":text,"label":selected_articles_}

                self.data.append(
                    test_data
                )

            # for idx in range(5000):
            #     repeat_count = random.randint(1, 1)

            #     flag_context = "<context> "*repeat_count

            #     selected_articles = random.sample(filtered_data150, repeat_count)
            #     selected_articles_ = " ".join(selected_articles)

            #     text = f'<|User|>:请复述这段话{flag_context} <|Bot|>:{selected_articles_}'
            #     test_data = {"samples":{"context":selected_articles,"text_input":[text]}}

            #     self.data.append(
            #         test_data
            #     )    

            for idx,(answer,summary) in enumerate(zip(filtered_data["answer"],filtered_data["summary"])):
                
                answer = [answer[:1000]]
                flag_context = "<context> "

                prompt = random.choice(summary_prompts)

                # user_token: <reserved_106> assisent_token: <reserved_107> eoa: </s>
                text = f'<reserved_106>{prompt}{flag_context}<reserved_107>{summary}'
                test_data = {"context":answer,"text_input":text,"label":summary}

                self.data.append(
                    test_data
                )   

            # for idx in range(10000):
            #     repeat_count = random.randint(1, 1)

            #     flag_context = "<context> "*repeat_count

            #     selected_articles = random.sample(filtered_data1500, repeat_count)
            #     selected_articles_ = " ".join(selected_articles)

            #     text = f'<|User|>:请复述这段话{flag_context} <|Bot|>:{selected_articles_}'
            #     test_data = {"samples":{"context":selected_articles,"text_input":[text]}}

            #     self.data.append(
            #         test_data
            #     )


            print("data load , size:", len(self.data))

        else:
            for idx in range(100):
            
                # if not line or line == "" or len(line) < 50 or len(line) > 2000:
                #     continue
                
                # Randomly pick how many passages to concatenate (3 to 5)
                repeat_count = random.randint(3, 5)

                flag_context = "<context> "*repeat_count

                prompt = random.choice(retell_prompts)
                selected_articles = random.sample(corpus, repeat_count)
                selected_articles_ = "[SEP]".join(selected_articles)

                text = f'<reserved_106>{prompt}{flag_context}<reserved_107>'
                test_data = {"context":selected_articles,"text_input":text,"label":selected_articles_}

                self.data.append(
                    test_data
                )


    # Helper filters for article length
    @staticmethod
    def filter_by_length150(article):
        return 180 <= len(article) <= 200

    @staticmethod
    def filter_by_length1000(article):
        return 50 <= len(article) <= 1000

    @staticmethod
    def filter_by_length1500(article):
        return 500 <= len(article) <= 1500

    def __getitem__(self, index):
        item_data = self.data[index]

        return item_data

    def __len__(self):
        return len(self.data)


if __name__ == "__main__":
    data_path = "QA_5000_summary.csv"
    dataset = QADataset(data_path,train=True)
    # print(dataset[0])
    val_params = {
        "batch_size": 2,
        "shuffle": False,
        "num_workers": 0,
    }

    def collate_fn(batch):
        """
        对batch数据进行处理
        :param batch: [一个getitem的结果,getitem的结果,getitem的结果]
        :return: 元组
        """
        # 初始化一个空字典来存储合并后的结果
        merged_dict = {}

        # 遍历列表中的每个字典
        for d in batch:
            # 遍历每个字典中的键值对
            for key, value in d.items():
                # 如果键已经存在于merged_dict中,将值合并为一个字符串,用逗号分隔
                if key in merged_dict:
                    merged_dict[key].append(value)
                else:
                    # 如果键不存在于merged_dict中,直接添加到merged_dict中
                    merged_dict[key] = [value]
        # 输出合并后的结果
        # print(merged_dict)

        return merged_dict

    val_loader = DataLoader(dataset, **val_params,collate_fn=collate_fn)

    for i in val_loader:
        print(i)
        break
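
Running this file directly prints one collated batch. For reference, a batch produced by collate_fn with batch_size=2 looks roughly like the sketch below; the passages and the summary are placeholders, not real corpus entries.

# Illustrative structure of one collated batch (placeholder texts)
batch = {
    "context":    [["passage A", "passage B"], ["passage C"]],   # raw passages per sample
    "text_input": ["<reserved_106>请复述这段被压缩的内容<context> <context> <reserved_107>passage A[SEP]passage B",
                   "<reserved_106>请总结被压缩的信息<context> <reserved_107>a short summary"],
    "label":      ["passage A[SEP]passage B", "a short summary"],
}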
    

train_batch.py

# -*- coding: utf-8 -*-
import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
# from dataset_batch_en import QADataset
# from dataset_rerank import QADataset
# from dataset_rerank_en_gpt import QADataset
from dataset_rerank_en import QADataset

from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm
import torch
import os, time, sys
import numpy as np
from modeling_RALLM import RALLM
import argparse
import deepspeed
from torch.nn.parallel import DataParallel


# Restrict CUDA device visibility, e.g. use only one GPU
# os.environ["CUDA_VISIBLE_DEVICES"] = "3"

parser = argparse.ArgumentParser()
parser.add_argument("--is_compress", default=True, type=bool)
parser.add_argument("--compressibility_factor", default=0, type=float,dest="0-1")
parser.add_argument("--LLM_model", default="Baichuan2_7B", type=str)
parser.add_argument("--output", default="output_english_longformer_rerank100k_2", type=str)
parser.add_argument("--encoder", default="longformer", type=str)
parser.add_argument("--query_tokens", default=98, type=int)
parser.add_argument("--load_path", default="output_english_longformer_msmarco2019/checkpoint-87500", type=str)
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--num_train_epochs", default=10, type=int)
parser.add_argument("--learning_rate", default=5e-3, type=int)
parser.add_argument("--weight_decay", default=0.005, type=int)
parser.add_argument("--per_device_train_batch_size", default=6, type=int)
parser.add_argument("--max_length", default=4096, type=int)
parser.add_argument("--use_lora", default=True, type=bool)
parser.add_argument("--use_lora_gpt2", default=False, type=bool)
parser.add_argument("--train_dataset", default="data/qa_3w_summary.csv", type=str)
parser.add_argument("--epochs", default=1, type=int)
parser.add_argument("--batch_size", default=1, type=int)



args = parser.parse_args()


def train(epoch, model, loader, optimizer,scaler, gradient_accumulation_steps,model_output_dir):
    model.train()

    time1 = time.time()
    losses = []
    train_bar = tqdm(loader,total=len(loader))
    for index, data in enumerate(train_bar):
        optimizer.zero_grad()
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            # print(data)
            outputs = model(model, **data)
            loss = outputs.loss
        # Backward pass: compute gradients for the current step
        loss.requires_grad_(True)
        losses.append(loss.item())

        # Scale the loss, take an optimizer step, then update the loss scale
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if (index+1) % 5000 == 0:
            model_output_dir_ = os.path.join(model_output_dir,f"epoch{epoch}")
            model_save_path = os.path.join(model_output_dir_,"index_{}".format(index))
            os.makedirs(model_save_path, exist_ok=True)

            torch.save(model.state_dict(), os.path.join(model_save_path,"LLM_model_{:.6f}.pth".format(np.mean(losses))))

        train_bar.set_description("epoch:{} idx:{} loss:{:.6f}".format(epoch,index,np.mean(losses)))



def validate( model,  loader):
    model.eval()

    predictions = []
    actuals = []
    
    with torch.no_grad():
        with torch.autocast(device_type="cuda",dtype=torch.float16):
            for _, data in enumerate(tqdm(loader, file=sys.stdout, desc="Validation Data")):

                text = data["text_input"]
                context = data["context"]
                label = data["label"]
                print(text)
                print("len context:",len(context))
                for text_,context_ in zip(text,context):
                    preds = model.generate(
                        text=text_,context = [context_]
                    )
                    print(preds)
                    print(label)
                    predictions.append(preds)
                actuals.extend(label)
    return predictions, actuals


def main():
    epochs = args.epochs
    batch_size = args.batch_size
    lr = 1e-5
    gradient_accumulation_steps = 16
    model_output_dir = args.output
    # train_path = "qa_3w_summary.csv"
    train_path = args.train_dataset
    val_path = args.train_dataset

    device = torch.device(f"cuda:{args.local_rank}")
    model = RALLM(args)
    model = model.to(device)

    if args.use_lora:
        print("使用lora训练模型"+"*"*10)
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack",],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        model.LLM_model.enable_input_require_grads()
        model.LLM_model = get_peft_model(model.LLM_model, peft_config)

    if args.use_lora_gpt2:
        print("使用lora训练模型"+"*"*10)
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["wte","c_attn",],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        # model.LLM_model.enable_input_require_grads()
        model.context_encoder = get_peft_model(model.context_encoder, peft_config)


    print(model)

    torch.cuda.empty_cache()  # free GPU memory

    if args.load_path:
        base_load_path = args.load_path
        # List the filenames of all checkpoint shards to load

        if base_load_path.endswith(".pth"):
            state_dict = torch.load(base_load_path, map_location=device)
        else:
            file_list = ['pytorch_model.bin']
            # Start from an empty state dict
            state_dict = {}
            # Load every shard and merge it in
            for file_name in file_list:
                # Load the parameters of a single shard
                part_state_dict = torch.load(os.path.join(base_load_path, file_name), map_location=device)

                # Merge the shard into the full state dict
                state_dict.update(part_state_dict)
        # Load the merged state dict into the model (non-strict: missing/extra keys are tolerated)
        print("state_dict:")
        print(state_dict.keys())
    
        model.load_state_dict(state_dict,strict=False)


    for param in model.context_encoder.parameters():
        param.requires_grad = False
    # layers_to_modify = [30,31,32,33,34,35]
    # # Iterate over all named parameters in the model
    # for name, param in model.context_encoder.named_parameters():
    #     # Check if the parameter belongs to the specified layers
    #     if any(f"context_encoder.h.{layer}." in name for layer in layers_to_modify):
    #         # Set requires_grad to True or False depending on whether you want the parameter to be trainable or not
    #         param.requires_grad = True  # or False if you want to freeze the layer
    for param in model.ln_features.parameters():
        param.requires_grad = True
    for param in model.Qformer.parameters():
        param.requires_grad = True

    # Iterate over every layer and freeze its parameters
    # for param in model.LLM_model.parameters():
    #     param.requires_grad = False

    # Freeze everything except the lora_A and lora_B layers
    # trained = []
    # untrained = []
    # for name, param in model.LLM_model.named_parameters():
    #     # if 'v_proj.lora_A' in name or 'v_proj.lora_B' in name or 'q_proj.lora_B' in name or 'q_proj.lora_B' in name:
    #     if 'lora_A' in name or 'lora_B' in name:
    #         param.requires_grad = True
    #         trained.append(name)
    #     else:
    #         param.requires_grad = False
    #         untrained.append(name)


    # Print trainable and non-trainable parameters
    trainable_params = []
    non_trainable_params = []

    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.append(name)
        else:
            non_trainable_params.append(name)

    print("Trainable Parameters:")
    print("\n".join(trainable_params))

    print("\nNon-Trainable Parameters:")
    print("\n".join(non_trainable_params))

    # setup peft
    # peft_config = LoraConfig(
    #     task_type=TaskType.CAUSAL_LM,
    #     target_modules=["q_proj","v_proj"], #W_pack. query_key_value
    #     inference_mode=False,
    #     r=lora_rank,
    #     lora_alpha=lora_alpha,
    #     lora_dropout=0.1
    # )
    # model = get_peft_model(model, peft_config)


    # model.is_parallelizable = True
    # model.model_parallel = True
    # model.print_trainable_parameters()
    # Cast to half precision
    # model.LLM_model = model.LLM_model.half()
    # model.float()

    scaler = torch.cuda.amp.GradScaler()
    def collate_fn(batch):
        """
        对batch数据进行处理
        :param batch: [一个getitem的结果,getitem的结果,getitem的结果]
        :return: 元组
        """
        # 初始化一个空字典来存储合并后的结果
        merged_dict = {}

        # 遍历列表中的每个字典
        for d in batch:
            # 遍历每个字典中的键值对
            for key, value in d.items():
                # 如果键已经存在于merged_dict中,将值合并为一个字符串,用逗号分隔
                if key in merged_dict:
                    merged_dict[key].append(value)
                else:
                    # 如果键不存在于merged_dict中,直接添加到merged_dict中
                    merged_dict[key] = [value]
        # 输出合并后的结果
        # print(merged_dict)

        return merged_dict

    print("Start Load Train Data...")
    train_params = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": 0,
    }
    training_set = QADataset(train_path,train=True)
    training_loader = DataLoader(training_set, **train_params,collate_fn=collate_fn)
    print("Start Load Validation Data...")
    val_params = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": 0,
    }
    val_set = QADataset(val_path,train=False)
    val_loader = DataLoader(val_set, **val_params,collate_fn=collate_fn)

    # optimizer = torch.optim.AdamW([{'params': model.bert_encoder.parameters(), 'lr': 1e-5},
    #                                 {'params': model.Qformer.parameters(), 'lr': 1e-3},
    #                                 {'params': model.ln_features.parameters(), 'lr': 1e-3},
    #                                 {'params': model.internlm_model.parameters(), 'lr': 1e-5},
    #                                 {'params': query_tokens_clone, 'lr': 1e-3}] #
    #                                 )

    optimizer = torch.optim.AdamW([{'params': model.parameters(), 'lr': lr}])

    # device_ids = [1,3,6,7]
    # model = DataParallel(model, device_ids=device_ids)


    print("Start Training...")
    for epoch in range(epochs):
        # train(epoch, model, training_loader, optimizer,scaler, gradient_accumulation_steps,model_output_dir)
        # print("Save Model To ",    加)
        # model.save_pretrained(model_output_dir)
        # 验证
        # print("Start Validation...")
        with torch.no_grad():
            predictions, actuals = validate(model, val_loader)
            # Save the validation results
            final_df = pd.DataFrame({"Generated Text": predictions, "Actual Text": actuals})
            val_data_path = os.path.join(model_output_dir, f"predictions_{epoch}.csv")
            final_df.to_csv(val_data_path)
            print("Validation Data To ", val_data_path)


if __name__ == '__main__':
    main()
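
One detail worth noting: train() receives gradient_accumulation_steps but never uses it. When accumulation is wanted, the usual GradScaler pattern divides the loss by the number of micro-batches and only steps the optimizer every N batches. The sketch below is an illustration under that assumption (it mirrors the model(model, **data) call convention used in train() above); it is not the author's training loop.

import torch

def train_with_accumulation(model, loader, optimizer, scaler, accum_steps=16):
    """Hedged sketch of gradient accumulation with torch.cuda.amp.GradScaler."""
    model.train()
    optimizer.zero_grad()
    for step, data in enumerate(loader):
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss = model(model, **data).loss / accum_steps   # average over micro-batches
        scaler.scale(loss).backward()                        # gradients accumulate across calls
        if (step + 1) % accum_steps == 0:
            scaler.step(optimizer)                           # unscale gradients and step
            scaler.update()
            optimizer.zero_grad()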

test_chat.py

# -*- coding: utf-8 -*-
import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer,AutoModelForMaskedLM,BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline,AutoModelForCausalLM,AutoConfig
from transformers.generation.utils import GenerationConfig

from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm
import torch
import os, time, sys
import numpy as np
from modeling_RALLM import RALLM
import argparse
from torch import autocast

parser = argparse.ArgumentParser()
parser.add_argument("--is_compress", default=False, type=bool)
parser.add_argument("--compressibility_factor", default=0.1, type=float,dest="0-1")
parser.add_argument("--LLM_model", default="Baichuan2_7B", type=str)
parser.add_argument("--output", default="output_corpus_2", type=str)
parser.add_argument("--encoder", default="gpt2_large", type=str)
parser.add_argument("--query_tokens", default=98, type=int)
parser.add_argument("--load_path", default="output_corpus_lora/checkpoint-200004", type=str)
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--num_train_epochs", default=10, type=int)
parser.add_argument("--learning_rate", default=5e-5, type=int)
parser.add_argument("--weight_decay", default=0.005, type=int)
parser.add_argument("--per_device_train_batch_size", default=6, type=int)
parser.add_argument("--max_length", default=4096, type=int)
parser.add_argument("--use_lora", default=False, type=bool)
parser.add_argument("--use_lora_gpt2", default=True, type=bool)



args = parser.parse_args()

def chat( model):

    while True:
        context1 = input("Enter context 1: ")
        context2 = input("Enter context 2: ")
    #     context = """NVIDIA的A6000显卡是一款面向专业领域的高性能显卡。关于它的双精度(Double Precision)、单精度(Single Precision)和半精度(Half Precision)的算力,我们可以参考官方提供的规格参数。截至我最后更新的信息(2023年4月),以下是A6000显卡的相关算力数据:双精度(Double Precision): A6000显卡在双精度计算方面的性能通常不如单精度和半精度,因为双精度计算需要更多的计算资源和带宽。具体数值因显卡的不同批次和制造工艺的微小差异可能有所不同。单精度(Single Precision): A6000在单精度计算方面的性能通常很高,适合于大多数图形处理和一些科学计算任务。单精度计算是大多数显卡的主要优势。半精度(Half Precision): 半精度计算主要用于某些机器学习和深度学习应用,能提供更高的吞吐量。A6000显卡在半精度计算方面的性能通常很高。
    #     """
        flag_context = "<context> "*2
        text = f'<reserved_106>请复述这段被压缩的内容{flag_context} <reserved_107>'
        data = {"context":[[context1,context2]],"text_input":text}

        model.eval()
        with torch.no_grad():
            with autocast(device_type="cuda",dtype=torch.float16):
                text = data["text_input"]
                context = data["context"]
                preds = model.generate(
                    text=text,context = context
                )
        print("输出:",preds) 


def main():
   
    model = RALLM(args)
    device = torch.device(f"cuda:{args.local_rank}")
    model.to(device)


    if args.use_lora:
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack",],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        model.LLM_model.enable_input_require_grads()
        model.LLM_model = get_peft_model(model.LLM_model, peft_config)


    if args.use_lora_gpt2:
        print("使用lora训练模型"+"*"*10)
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["c_attn",],
            inference_mode=False,
            r=64,
            lora_alpha=256,
            lora_dropout=0.1,
        )
        # model.LLM_model.enable_input_require_grads()
        model.context_encoder = get_peft_model(model.context_encoder, peft_config)


    print(model)


    base_load_path = "output_qa3w_lora_gpt2_base_corpus"
    # List the filenames of all checkpoint shards to load
    file_list = ['pytorch_model.bin']

    # Start from an empty state dict
    state_dict = {}

    # Load every shard and merge it in
    for file_name in file_list:
        # Load the parameters of a single shard
        part_state_dict = torch.load(os.path.join(base_load_path, file_name), map_location=f"cuda:{args.local_rank}")

        # Merge the shard into the full state dict
        state_dict.update(part_state_dict)

    # Load the merged state dict into the model
    model.load_state_dict(state_dict)
    model.LLM_model.generation_config = GenerationConfig.from_pretrained(base_load_path)



    # Load model parameters from a single .pth checkpoint (disabled)
    # load_path = '/data2/xinyuuliu/InternLM-XComposer/output12/epoch9/index_29999/LLM_model_0.109371.pth'
    # checkpoint = torch.load(load_path,map_location="cuda:0") #,map_location="cuda:3"
    # # Load the parameters into the model
    # model.load_state_dict(checkpoint)
    # Cast to half precision
    # model.LLM_model = model.LLM_model.half()
    model = model.half()
    # model.float()

    chat(model)
 


if __name__ == '__main__':
    main()


 

ds_config.json

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },

    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "none",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 10,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false

}
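
The "auto" entries in this config are filled in by the Hugging Face Trainer from its TrainingArguments when the file is passed via the deepspeed argument, which is what fine-tune.py does below. A minimal sketch, assuming transformers with DeepSpeed support is installed and ds_config.json sits in the working directory:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output_demo",              # placeholder output directory (assumption)
    per_device_train_batch_size=1,         # fills train_micro_batch_size_per_gpu = "auto"
    gradient_accumulation_steps=1,         # fills gradient_accumulation_steps = "auto"
    bf16=True,                             # fills bf16.enabled = "auto"
    deepspeed="ds_config.json",            # the config shown above
)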

fine-tune.py

# This code is based on the revised code from fastchat based on tatsu-lab/stanford_alpaca.
from dataclasses import dataclass, field
import json
import math
import logging
import os
import random
from typing import Dict, Optional, List
import torch
from torch.utils.data import Dataset
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
import transformers
from transformers import Trainer, deepspeed
from transformers.trainer_pt_utils import LabelSmoother
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate.utils import DistributedType
from torchvision import transforms
from typing import Dict, Optional, Sequence, List
from modeling_RALLM import RALLM
# from dataset_batch import QADataset
from dataset_rerank_en import QADataset
# from dataset_rerank_en_gpt import QADataset
# from dataset_rerank import QADataset
import argparse
from transformers import AutoModel, AutoTokenizer,AutoModelForMaskedLM,BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline,AutoModelForCausalLM,AutoConfig
from torch.optim import AdamW


IGNORE_TOKEN_ID = LabelSmoother.ignore_index

parser = argparse.ArgumentParser()

parser.add_argument("--is_compress", default=True, type=bool)
parser.add_argument("--compressibility_factor", default=0.1, type=float,dest="0-1")
parser.add_argument("--LLM_model", default="Baichuan2_7B", type=str)
parser.add_argument("--output", default="output_corpus_lora2", type=str)
parser.add_argument("--encoder", default="gpt2_large", type=str)
parser.add_argument("--query_tokens", default=98, type=int)
parser.add_argument("--load_path", default="output_english_longformer_rerank100k/checkpoint-112356", type=str)
parser.add_argument("--local_rank", default=-1, type=int)
parser.add_argument("--num_train_epochs", default=10, type=int)
parser.add_argument("--learning_rate", default=5e-5, type=float)
parser.add_argument("--weight_decay", default=0.01, type=float)
parser.add_argument("--per_device_train_batch_size", default=6, type=int)
parser.add_argument("--max_length", default=4096, type=int)
parser.add_argument("--use_lora", default=True, type=bool)
parser.add_argument("--use_lora_gpt2", default=False, type=bool)
parser.add_argument("--train_dataset", default="data/qa_3w_summary.csv", type=str)



args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.local_rank)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    local_rank: int = field(default=None)


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""
    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """
        对batch数据进行处理
        :param batch: [一个getitem的结果,getitem的结果,getitem的结果]
        :return: 元组
        """

        # 初始化一个空字典来存储合并后的结果
        merged_dict = {}

        # 遍历列表中的每个字典
        for d in instances:
            # 遍历每个字典中的键值对
            for key, value in d.items():
                # 如果键已经存在于merged_dict中,将值合并为一个字符串,用逗号分隔
                if key in merged_dict:
                    merged_dict[key].append(value)
                else:
                    # 如果键不存在于merged_dict中,直接添加到merged_dict中
                    merged_dict[key] = [value]
        # 输出合并后的结果
        # print(merged_dict)

        return merged_dict



def train():
    global model

    train_path = args.train_dataset
    # train_path = "data/news_summary_30w.csv"
    # val_path = "QA_5000_summary.csv"
    
    device_map = None
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1


    torch.cuda.device(0)
    torch.cuda.empty_cache()  # free GPU memory

    # init model and tokenizer
    model = RALLM(args)

    device = torch.device(f"cuda:{args.local_rank}")
    model.to(device)
    # torch.cuda.device(0)
    torch.cuda.empty_cache()  # free GPU memory

    if args.use_lora:
        print("使用lora训练模型"+"*"*10)
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack",],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        model.LLM_model.enable_input_require_grads()
        model.LLM_model = get_peft_model(model.LLM_model, peft_config)


    if args.use_lora_gpt2:
        print("使用lora训练模型"+"*"*10)
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            # target_modules=["wte","c_attn",],
            target_modules=["query","key","value","query_global","key_global","value_global"],
            inference_mode=False,
            r=128,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        # model.LLM_model.enable_input_require_grads()
        model.context_encoder = get_peft_model(model.context_encoder, peft_config)



    print(model)

    for param in model.context_encoder.parameters():
        param.requires_grad = False
    # layers_to_modify = [27,28,29,30,31,32,33,34, 35]
    #     # Iterate over all named parameters in the model
    # for name, param in model.context_encoder.named_parameters():
    #     # Check if the parameter belongs to the specified layers
    #     if any(f"h.{layer}."  in name for layer in layers_to_modify):
    #         # Set requires_grad to True or False depending on whether you want the parameter to be trainable or not
    #         param.requires_grad = True  # or False if you want to freeze the layer
    #     # if f"ln_f"  in name:
    #     #     # Set requires_grad to True or False depending on whether you want the parameter to be trainable or not
    #     #     param.requires_grad = True  # or False if you want to freeze the layer
    #     if f"wte"  in name:
    #     #     # Set requires_grad to True or False depending on whether you want the parameter to be trainable or not
    #         param.requires_grad = True  # or False if you want to freeze the layer
    #     if f"wpe"  in name:
    #     #     # Set requires_grad to True or False depending on whether you want the parameter to be trainable or not
    #         param.requires_grad = True  # or False if you want to freeze the layer

    for param in model.ln_features.parameters():
        param.requires_grad = True


    # Iterate over every layer and freeze its parameters
    # for param in model.LLM_model.parameters():
    #     param.requires_grad = False

    # Freeze everything except the lora_A and lora_B layers
    trained = []
    untrained = []
    # for name, param in model.LLM_model.named_parameters():
    #     # if 'v_proj.lora_A' in name or 'v_proj.lora_B' in name or 'q_proj.lora_B' in name or 'q_proj.lora_B' in name:
    #     # if 'lora_A' in name or 'lora_B' in name or "layers.30" in name or "layers.31" in name or "embed_tokens" in name: 
    #     if 'lora_A' in name or 'lora_B' in name or "embed_tokens" in name: 

    #         param.requires_grad = True
    #         trained.append(name)
    #     else:
    #         param.requires_grad = False
    #         untrained.append(name)

    # print("可训练的大模型层",trained)
    # print("不可训练的大模型层",untrained)

    # Print trainable and non-trainable parameters
    trainable_params = []
    non_trainable_params = []

    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.append(name)
        else:
            non_trainable_params.append(name)

    print("Trainable Parameters:")
    print("\n".join(trainable_params))

    print("\nNon-Trainable Parameters:")
    print("\n".join(non_trainable_params))


    if args.load_path:
        base_load_path = args.load_path
        # List the filenames of all checkpoint shards to load

        if base_load_path.endswith(".pth"):
            state_dict = torch.load(base_load_path, map_location=device)
        else:
            file_list = ['pytorch_model.bin']
            # Start from an empty state dict
            state_dict = {}
            # Load every shard and merge it in
            for file_name in file_list:
                # Load the parameters of a single shard
                part_state_dict = torch.load(os.path.join(base_load_path, file_name), map_location=device)

                # Merge the shard into the full state dict
                state_dict.update(part_state_dict)
        # Load the merged state dict into the model (non-strict: missing/extra keys are tolerated)
        print("state_dict:")
        print(state_dict.keys())
        
        model.load_state_dict(state_dict,strict=False)

    # # Separate model.Qformer's parameters from all other parameters
    # qformer_params = set(model.Qformer.parameters())
    # other_params = [p for p in model.parameters() if p not in qformer_params]

    # # Build parameter groups
    # param_groups = [
    #     {'params': list(qformer_params), 'lr': 1e-3},
    #     {'params': other_params, 'lr': 1e-5}
    # ]

    # Create an AdamW optimizer from the parameter groups
    # optimizer = AdamW(param_groups)
    

    training_set = QADataset(train_path,train=True)
    # val_set = QADataset(val_path,train=False)

    print(training_set[0])

    # Training arguments
    training_args = TrainingArguments(
        local_rank=args.local_rank,
        output_dir=args.output,                                         # output directory
        num_train_epochs=args.num_train_epochs,                         # number of training epochs
        per_device_train_batch_size=args.per_device_train_batch_size,   # per-device batch size
        warmup_steps=500,                                               # warmup steps
        weight_decay=0.01,                                              # weight decay
        logging_dir='./logs',                                           # logging directory
        deepspeed="ds_config.json",
        gradient_accumulation_steps=1,
        save_strategy="epoch",
        learning_rate=5e-5,
        # lr_scheduler_type='linear',
        # logging_steps= 100,
    )

    data_collator = DataCollatorForSupervisedDataset()

    # Start the trainer
    trainer = Trainer(
        model = model,
        tokenizer = model.LLM_tokenizer,
        train_dataset=training_set,
        # eval_dataset=val_set,
        data_collator=data_collator,
        args = training_args,
        # optimizers=(optimizer, None)  # 自定义优化器

    )

    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=args.output)



if __name__ == "__main__":
    train()

#https://arxiv.org/pdf/2102.05951.pdf

fine-tune.sh

hostfile=""

# deepspeed --include localhost:1,2,3 --hostfile=$hostfile fine-tune.py  \
#     --report_to "none" \
#     --data_path "/data1/xinyuuliu/qa_data/professional_data/train_二阶段.json" \
#     --model_name_or_path "/data1/xinyuuliu/Baichuan2-13B-Chat" \
#     --output_dir "output_lora3_1_2" \
#     --model_max_length  4000\
#     --num_train_epochs 10 \
#     --per_device_train_batch_size 4 \
#     --gradient_accumulation_steps 1 \
#     --save_strategy epoch \
#     --learning_rate 2e-4 \
#     --lr_scheduler_type constant \
#     --adam_beta1 0.9 \
#     --adam_beta2 0.98 \
#     --adam_epsilon 1e-8 \
#     --max_grad_norm 1.0 \
#     --weight_decay 1e-4 \
#     --warmup_ratio 0.0 \
#     --logging_steps 1 \
#     --gradient_checkpointing True \
#     --deepspeed ds_config.json \
#     --bf16 True \
#     --tf32 True \
#     --use_lora True \
#     --load_lora_path /data1/xinyuuliu/Baichuan2-main/fine-tune/output_lora3_1/checkpoint-8260
    # --use_NEFT True
    # --use_frozen True
# export CUDA_LAUNCH_BLOCKING=1

# CUDA_VISIBLE_DEVICES="2,3,4,5,6,7"
deepspeed --include localhost:0,1,2,3,4,5,6,7 --master_port 29501 --hostfile=$hostfile fine-tune.py \
    --encoder longformer \
    --query_tokens 32 \
    --output output_english_longformer_msmarco2019 \
    --num_train_epochs 20 \
    --per_device_train_batch_size 1
    # --load_path /data2/xinyuuliu/Baichuan2_qformer_bert/output_30w/checkpoint-22488

 
